Educational Codeforces Round 115 (Rated for Div. 2)-C. Delete Two Elements
falconlee236
·2021. 11. 16. 20:12
문제 설명
모노폴리는 정수 $n$개로 이루어진 배열 $a$를 가지고 있다. $a$의 산술평균을 $k$라고 가정하자.($k$는 정수가 아닐 수 있다.)
모노폴리는 정확히 배열 $a$에 있는 숫자 2개를 지워서 정확히 $n - 2$개의 숫자들로 이루어진 평균을 여전히 $k$로 유지하고 싶다.
우리가 해야할 것은 만약 두 숫자를 지워도 여전히 평균이 $k$가 되는 숫자들의 인덱스 쌍 $[i, j](i < j)$ 의 개수를 구하는 것이다.
Input
첫번째 줄에는 테스트 케이스의 개수를 나타내는 정수 $t (1 \le t \le 10^4)$ 이 주어진다.
각 테스트케이스의 첫번째 줄에는 배열에 있는 원소의 수 $n (3 \le n \le 2 \cdot 10^5)$ 이 주어진다.
각 테스트케이스의 두번째 줄에는 배열에 있는 정수 $a_1, a_2, ...., a_n (0 \le a_i \le 10^9)$ 가 주어진다. 이때 $a_i$는 배열의 $i$번째 원소이다.
Output
각 테스트케이스 두 원소를 지워서 얻은 $(n - 2)$ 개로 이루어진 평균이 $n$개로 이루어진 평균과 같게하는 원소들의 인덱스 $[i, j] (i < j)$의 총 개수를 출력한다.
Example
input
4
4
8 8 8 8
3
50 20 10
5
1 4 7 3 5
7
1 2 3 4 5 6 7
output
6
0
2
3
문제 접근
사용한 알고리즘: 수학, 자료구조, 구현
걸린 시간 : 00:06
온몸 비틀면서 겨우겨우 binary_search, lower_bound, upper_bound를 이용해서 대회 종료 5분전에 풀었는데 문제 해설을 보니 well-known문제라고 한다. 실제로 고수들의 풀이를 보니까 다 비슷비슷 하더라. 이 문제에서 구하는 방식은 앞으로 고수들의 well-known풀이를 이용해서 쉽게쉽게 풀어보자.
일단 간단한 수학을 해야한다. $a$에 있는 모든 숫자 $n$개 들의 총 합을 $s$라고 하자. 그렇다면 평균 $k$는 $s/n$으로 구할 수 있다. 우리가 제거해야하는 두 숫자의 합을 $p$라고 하자. 그러면 우리는 다음식을 만족하는 $p$를 구해야한다. $s/n = (s - p)/(n - 2)$
이 식을 $p$에 대한 식으로 정리하면 $p = (2 * s)/n$ 이라는 간단한 식이 나온다. 여기서 배열 $a$에 있는 수는 모두 정수이기 때문에 두 정수의 합인 $p$ 또한 정수가 되어야 한다. 즉 $(2 * s) mod n \neq 0$ 이면 $p$를 구할 수 없기 때문에 무조건 답은 0이다.
우리가 마지막으로 해야할 일은 두 수를 더해서 $p$가 되는 쌍을 찾으면 된다. 여기서 고수들의 well-known풀이가 등장한다. 일단 associative data structure를 사용해야하는데, c++에서는 std::map을 이용한다. 인덱스 $i, j$는 $(i < j)$라는 조건을 만족해야 하기 때문에 중복을 구하면 안된다.
우리는 배열의 한 원소를 $a_i$라고 할때, $p - a_i$가 있는지 확인을 할것이다. 만약 $map[p - a_i]$이 존재하다면 $i$번째 위치 이전에 $p - a_i$가 $map[p - a_i]$개 존재한다는 뜻이기 때문에 그만큼 더한다. 만약 존재하지 않는다면 std::map은 0을 반환한다. 그리고 이후 map자료구조에 현재 index의 원소를 key값으로 하는 value값을 1올려준다. 이런식으로 $n$개까지 모두 순회하면 우리가 구해야하는 $[i, j]$ 의 쌍의 합을 중복없이 구할 수 있다.
정답 코드
#include <iostream>
#include <map>
using namespace std;
int main() {
ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int t; cin >> t;
while(t--){
int n; cin >> n;
int arr[n];
long long sum = 0;
for(int i = 0; i < n; i++){
cin >> arr[i];
sum += arr[i] * 1LL;
}
if((sum << 1) % n != 0){
cout << 0 << "\n";
continue;
}
map<long long, int> mp;
long long p = (sum << 1) / n;
long long ans = 0;
for(int i = 0; i < n; i++){
ans += mp[p - arr[i]];
mp[arr[i]]++;
}
cout << ans << "\n";
}
return 0;
}