오늘 풀어본 문제는 백준의 1067번 문제1이다. 문제 풀이에 사용한 언어는 C++ 이다.
이 문제의 내용과 조건은 다음과 같다.
$N$ 개의 수가 있는 $X$ 와 $Y$ 가 있다. 이때 $X$ 나 $Y$ 를 순환 이동시킬 수 있다. 순환 이동이란 마지막 원소를 제거하고 그 수를 맨 앞으로 다시 삽입하는 것을 말한다. 예를 들어, ${1, 2, 3}$ 을 순환 이동시키면 ${3, 1, 2}$ 가 될 것이고, ${3, 1, 2}$ 는 ${2, 3, 1}$ 이 된다. 순환 이동은 $0$ 번 또는 그 이상 할 수 있다. 이 모든 순환 이동을 한 후에 점수를 구하면 된다. 점수 $S$ 는 다음과 같이 구한다.
$S = X[0] \times Y[0] + X[1] \times Y[1] + \cdots + X[N-1] \times Y[N-1]$
이때 $S$ 를 최대로 하면 된다.
첫째 줄에 $N$ 이 주어진다. 둘째 줄에는 $X$ 에 들어있는 $N$ 개의 수가 주어진다. 셋째 줄에는 $Y$ 에 있는 수가 모두 주어진다. $N$ 은 $60,000$ 보다 작거나 같은 자연수이고, $X$ 와 $Y$ 에 들어있는 모든 수는 $100$ 보다 작은 자연수 또는 $0$ 이다.
첫째 줄에 $S$ 의 최댓값을 출력한다.
이 문제는 FFT를 이용해야 하는 문제이다. 정확하게 말하자면, ‘convolution’ 연산을 하는 과정에서 FFT를 이용해야 한다. 나이브(naive) 한 방법으로 컨볼루션 연산을 하면 시간복잡도가 $O(n^2)$ 이 되는데, FFT를 활용하게 되면 시간 복잡도를 $O(n \log n)$으로 줄일 수 있는 것이다.
물론 이 문제는 단순히 두 벡터의 컨볼루션 연산을 한 뒤 결과를 내보내면 되는 것이 아니라, 각 벡터에 이동을 적용하여 연산한 결과들의 최댓값을 얻어내야 한다. 이는 첫 번째 벡터를 그대로 이어붙이고, 두 번째 벡터를 뒤집어서 컨볼루션을 진행하면 된다.
예를 들어 $(1,\ 2,\ 3,\ 4)$ 와 $(6,\ 7,\ 8,\ 5)$ 에 대한 점수를 구하기 위해서는 다음 식들을 계산하고 대소를 비교해야 한다.
여기서 첫 번째 수들은 이동을 시키지 않을 것을 볼 수 있는데, 두 번째 수들만 이동시킨 결과만 계산해도 첫 번째 수들을 임의로 이동시킨 결과 중 하나와 일치하기 때문에 냅둔 것이다.
이제 앞에서 내가 말한 방식대로 첫 번째 숫자들을 그대로 이은 뒤 두 번째 숫자들을 뒤집으면 $(1,\ 2,\ 3,\ 4,\ 1,\ 2,\ 3,\ 4)$ 과 $(5,\ 8,\ 7,\ 6)$ 가 되고, 이 둘을 convolution 한 결과의 각 항들은 다음과 같다.
여기서 우리가 필요한 식들이 모두 나오게 된다! 이제 이것들을 모조리 비교하여 최댓값을 구해내면 그것이 우리가 찾는 점수가 되는 것이다.
코드는 다음과 같이 작성하였다.
#include <bits/stdc++.h>
using namespace std;
double const_pi(void) {
return std::atan(1) * 4;
}
void FFT(std::vector<std::complex<double>>& a, const std::complex<double>& w) {
int n = a.size();
if (n == 1) return;
std::vector<std::complex<double>> a_even(n / 2), a_odd(n / 2);
for (int i = 0; i < n / 2; i++) {
a_even[i] = a[2 * i];
a_odd[i] = a[2 * i + 1];
}
std::complex<double> w_sq = w * w;
FFT(a_even, w_sq);
FFT(a_odd, w_sq);
std::complex<double> w_i = 1;
for (int i = 0; i < n / 2; i++) {
a[i] = a_even[i] + w_i * a_odd[i];
a[i + n / 2] = a_even[i] - w_i * a_odd[i];
w_i *= w;
}
}
std::vector<std::complex<double>> convolution(std::vector<std::complex<double>> a, std::vector<std::complex<double>> b, bool getIntegerResult = false) {
int n = 1;
double pi = const_pi();
while (n <= a.size() || n <= b.size()) {
n <<= 1;
}
n <<= 1;
a.resize(n);
b.resize(n);
std::vector<std::complex<double>> c(n);
std::complex<double> w(cos(2 * pi / n), sin(2 * pi / n));
FFT(a, w);
FFT(b, w);
for (int i = 0; i < n; i++) {
c[i] = a[i] * b[i];
}
FFT(c, std::complex(w.real(), -w.imag()));
for (int i = 0; i < n; i++) {
c[i] /= std::complex<double>(n, 0);
if (getIntegerResult) {
c[i] = std::complex(round(c[i].real()), round(c[i].imag()));
}
}
return c;
}
int main(void) {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int N;
cin >> N;
vector<complex<double>> X(2 * N);
vector<complex<double>> Y(N);
for (int i = 0; i < N; i++) {
int tempNum;
cin >> tempNum;
X[i] = X[i + N] = complex<double>(tempNum, 0);
}
for (int i = 0; i < N; i++) {
int tempNum;
cin >> tempNum;
Y[N - i - 1] = complex<double>(tempNum, 0);
}
auto Z = convolution(X, Y, true);
int result = 0;
for (auto& i : Z) {
cout << i << " ";
result = max(result, static_cast<int>(i.real()));
}
cout << endl;
cout << result;
return 0;
}
실행 결과 ‘출력 초과’가 나왔다.
제출하는 순간 아차 싶었고, 역시 문제가 발생했다. 테스트를 위해서 Z 벡터의 값을 출력했었는데, 이를 지우지 않고 제출한 것이다. 해결 방법은 당연하게도 테스트용 출력 코드들을 지우는 것이었다.
코드는 다음과 같이 작성하였다.
#include <bits/stdc++.h>
using namespace std;
double const_pi(void) {
return std::atan(1) * 4;
}
void FFT(std::vector<std::complex<double>>& a, const std::complex<double>& w) {
int n = a.size();
if (n == 1) return;
std::vector<std::complex<double>> a_even(n / 2), a_odd(n / 2);
for (int i = 0; i < n / 2; i++) {
a_even[i] = a[2 * i];
a_odd[i] = a[2 * i + 1];
}
std::complex<double> w_sq = w * w;
FFT(a_even, w_sq);
FFT(a_odd, w_sq);
std::complex<double> w_i = 1;
for (int i = 0; i < n / 2; i++) {
a[i] = a_even[i] + w_i * a_odd[i];
a[i + n / 2] = a_even[i] - w_i * a_odd[i];
w_i *= w;
}
}
std::vector<std::complex<double>> convolution(std::vector<std::complex<double>> a, std::vector<std::complex<double>> b, bool getIntegerResult = false) {
int n = 1;
double pi = const_pi();
while (n <= a.size() || n <= b.size()) {
n <<= 1;
}
n <<= 1;
a.resize(n);
b.resize(n);
std::vector<std::complex<double>> c(n);
std::complex<double> w(cos(2 * pi / n), sin(2 * pi / n));
FFT(a, w);
FFT(b, w);
for (int i = 0; i < n; i++) {
c[i] = a[i] * b[i];
}
FFT(c, std::complex(w.real(), -w.imag()));
for (int i = 0; i < n; i++) {
c[i] /= std::complex<double>(n, 0);
if (getIntegerResult) {
c[i] = std::complex(round(c[i].real()), round(c[i].imag()));
}
}
return c;
}
int main(void) {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int N;
cin >> N;
vector<complex<double>> X(2 * N);
vector<complex<double>> Y(N);
for (int i = 0; i < N; i++) {
int tempNum;
cin >> tempNum;
X[i] = X[i + N] = complex<double>(tempNum, 0);
}
for (int i = 0; i < N; i++) {
int tempNum;
cin >> tempNum;
Y[N - i - 1] = complex<double>(tempNum, 0);
}
auto Z = convolution(X, Y, true);
int result = 0;
for (auto& i : Z) {
result = max(result, static_cast<int>(i.real()));
}
cout << result;
return 0;
}
그러자 모든 테스트 케이스를 통과하고 정답이 나오는 것을 확인할 수 있었다.
이젠 하다하다 CLASS 8 문제를 풀게되었다. 예상했겠지만, 세미나에서 배웠던 내용으로 FFT를 활용한 문제였다.
FFT를 써서 문제를 푼 것까지는 좋은데, FFT가 뭐하는 건지 솔직히 아직 제대로 이해하지 못했다. 지난 학기에 응용복소함수론 강의를 듣긴 했었는데, 아마 그 내용과 관련이 있나 싶기도 하다. 조만간 FFT는 제대로 공부해봐야 할 것 같다.
그리고 구현해둔 FFT 코드는 ‘자료구조’ 는 아니긴 하나, 이후에 유용하게 사용할 수 있을 것 같아 깃험 레포지토리2에 저장해두었다. 필요한 사람은 참고하면 된다.
오늘의 PS는 여기까지!
1: https://www.acmicpc.net/problem/1067
2: https://github.com/ChoiCube84/baekjoon-solutions/blob/main/C%2B%2B/custom_data_structures/custom_algorithms.hpp