Nguyễn Hoàng Vũ - Trường Đại học Công nghệ, ĐHQGHN
Reviewer:
Trần Xuân Bách - Đại học Chicago
Nguyễn Minh Nhật - Trường THPT chuyên Khoa học Tự nhiên, ĐHQGHN
Các bài toán tổ hợp ngày càng xuất hiện nhiều trong các cuộc thi lập trình thi đấu và thường xuyên nắm giữ các vị trí khó nhất. Bài viết này sẽ giới thiệu về một công cụ quan trọng để giải các bài toán tổ hợp, đó là Biến đổi Fourier nhanh - Fast Fourier transform, hay còn được viết tắt là FFT.
Phép cộng (trừ) hai số phức tương đương với phép cộng (trừ) hai vectơ biểu diễn chúng. Khi nhân hai số phức, ta nhân môđun của chúng và cộng acgumen của chúng.
Công thức tổng quát cho các căn đơn vị cấp n: zn,i=exp(n2kπi) với k từ 0 đến n−1. Kí hiệu ωn=zn,1 thì ta còn có thể viết dưới dạng ωn0,ωn1,…,ωnn−1.
Các căn đơn vị nằm trên đường tròn đơn vị tạo thành một đa giác đều n đỉnh.
Hình 2. Biểu diễn các căn đơn vị cấp 8 trên mặt phẳng
thì c chính là dãy hệ số của đa thức C(z)=A(z)⋅B(z).
Có thể thấy rằng việc tính đa thức C theo định nghĩa sẽ có độ phức tạp thời gian O(n2), không đủ nhanh khi độ dài n khá lớn, ta cần một hướng tiếp cận khác. Giả sử ta tính được giá trị của A và B tại m điểm khác nhau:
A(z0),A(z1),…,A(zm−1)B(z0),B(z1),…,B(zm−1)
thì giá trị của C tại các điểm tương ứng là: C(zj)=A(zj)⋅B(zj).
Ta có định lý quan trọng sau đây:
Định lý nội suy đa thức. Cho m cặp số phức (u0,v0),(u1,v1),…,(um−1,vm−1) thoả mãn zi=zj với mọi 0≤i<j<m, tồn tại duy nhất một đa thức P có bậc không quá m−1 thoả mãn P(ui)=vi với mọi 0≤i<m.
Ví dụ:
Có duy nhất một đường thẳng (đa thức bậc 1) đi qua hai điểm bất kì trên mặt phẳng.
Có duy nhất một parabol (đa thức bậc 2) đi qua ba điểm bất kì trên mặt phẳng.
Ý tưởng của thuật toán FFT là chọn ra một tập điểm z0,z1,…,zm−1 sao cho ta có thể tính nhanh giá trị của đa thức A và B trên đó, đồng thời có thể khôi phục được đa thức C dựa trên C(z0),C(z1),…,C(zm−1).
Ta có một đa thức A(z)=a0z0+a1z1+…+an−1zn−1. Không mất tính tổng quát, giả sử n là một luỹ thừa của 2 hay n=2k với k∈N. Nếu n không phải là một lũy thừa của 2, ta thêm các số hạng aizi bị thiếu và cho các hệ số ai bằng 0.
Tập số mà FFT chọn để tính là tập các căn đơn vị cấp n, tức {ωn0,ωn1,…,ωnn−1} (nhắc lại ωn=exp(n2πi)).
Định nghĩa. Cho một dãy a0,a1,…,an−1, Biến đổi Fourier nhanh - Fast Fourier Transform - FFT là bất kì thuật toán nào tính dãy A(ωn0),A(ωn1),…,A(ωnn−1) trong thời gian O(nlogn). Phép biến đổi này bản thân nó được gọi là Biến đổi Fourier rời rạc - Discrete Fourier Transform - DFT.
Các số mũ trên bảng đã được lấy dư cho 8. Ta sẽ đưa các cột màu đỏ (chỉ số chẵn) về bên trái, các cột màu xanh (chỉ số lẻ) về bên phải và chia ma trận thành bốn ma trận con như dưới đây. Chú ý rằng việc này cũng thay đổi thứ tự của vectơ hệ số.
Nhận thấy rằng các hệ số tương ứng ở hai ma trận con bên trái bằng nhau và các hệ số tương ứng ở hai ma trận con bên phải trái dấu, do tính chất ωn/2=−1.
Vậy ta chỉ cần tính AX và AY là đủ để tính kết quả. Mặt khác, có thể thấy tính AX và AY tương đương với tính DFT(a0,a2,a4,a6) và DFT(a1,a3,a5,a7).
Xét trường hợp tổng quát, với 0≤i<2n, ta có (chú ý i ở đây là chỉ số, không phải đơn vị ảo):
Chứng minh:
Đặt (b0,b1,…,bn−1)=DFT(a0,a1,…,an−1) và (c0,c1,…,cn−1)=DFT(b0,b1,…,bn−1).
Ta có:
cl=k∑ckωkl=k∑ωklj∑ajωjk=j∑ajk∑ωk(j+l)
Xét j+l≡0(modn), ta có:
k∑ωk(j+l)=ωj+l−1(ωj+l)n−1=0
bởi wj+l là một căn đơn vị.
Từ đây ta suy ra cl=naj với j+l≡0(modn).
Ngoài ra, nếu ta sử dụng ω−1=exp(n−2πi) thay cho ω ở lần DFT thứ hai, kết quả ta nhận được sẽ là (na0,na1,…,nan−1) (bạn đọc tự chứng minh).
Ta sẽ cài đặt chung cả biến đổi xuôi và ngược:
using cd = complex<long double>;
void fft(vector<cd> &a, bool invert) {
/// invert = true tương ứng với biến đổi ngược
int n = a.size();
if (n == 1) return;
vector<cd> a0, a1;
for (int i = 0; i < n / 2; i++) {
a0.push_back(a[2 * i]);
a1.push_back(a[2 * i + 1]);
}
fft(a0, invert); fft(a1, invert);
cd w = 1, wn = polar(1.0L, acos(-1.0L) / n * (invert ? -2 : 2));
/// polar(r, t) = r * exp(it) và acos(-1.0L) = pi
/// thay wn = wn^-1 ở biến đổi ngược
for (int i = 0; i < n / 2; i++) {
a[i] = a0[i] + w * a1[i];
a[i + n / 2] = a0[i] - w * a1[i];
/// Ta sẽ chia 2 ở mỗi tầng đệ quy thay cho việc chia n ở cuối
if (invert) {
a[i] /= 2; a[i + n / 2] /= 2;
}
w *= wn;
}
}
Dưới đây là cài đặt để tính tích chập của hai dãy số:
vector<int> conv(const vector<int> &a, const vector<int> &b) {
if (a.empty() || b.empty()) return {};
vector<cd> fa(a.begin(), a.end());
vector<cd> fb(b.begin(), b.end());
int n = 1;
while (n < int(a.size() + b.size()) - 1) n <<= 1;
fa.resize(n); fb.resize(n);
fft(fa, false); fft(fb, false);
for (int i = 0; i < n; i++)
fa[i] *= fb[i];
fft(fa, true);
vector<int> res(n);
for (int i = 0; i < n; i++)
res[i] = int(real(fa[i]) + 0.5);
return res;
}
Vậy nếu ta sắp xếp dãy a lại thành (a0,a4,a2,a6,a1,a5,a3,a7) thì mỗi lần gọi đệ quy ở tầng thứ i sẽ thực hiện trên một đoạn dài 2i.
Viết lại dãy chỉ số dưới dạng nhị phân:
(000,100,010,110,001,101,011,111)
Có thể thấy rằng nếu ta đảo ngược thứ tự bit thì, ví dụ 100→001 thì ta nhận được dãy tăng dần từ 0 đến n−1.
Gọi rev(i) là số nhận được sau khi đảo ngược thứ tự bit của i, ta có rev(i)=2rev(i/2)∣[(imod2)⋅n] với ∣ thể hiện phép toán bitwise OR (bạn đọc tự chứng minh).
Ta có cài đặt sau:
void fft(vector<cd> &a, bool invert) {
int n = a.size(), L = __builtin_ctz(n);
vector<int> rev(n);
for (int i = 0; i < n; i++) {
rev[i] = (rev[i >> 1] | (i & 1) << L) >> 1;
if (i < rev[i]) swap(a[i], a[rev[i]]);
}
for (int len = 2; len <= n; len <<= 1) {
cd wlen = polar(1.0L, acos(-1.0L) / len * (invert ? -2 : 2));
for (int i = 0; i < n; i += len) {
cd w = 1;
for (int j = 0; j < len / 2; j++) {
cd u = a[i + j];
cd v = a[i + j + len / 2] * w;
a[i + j] = u + v;
a[i + j + len / 2] = u - v;
w *= wlen;
}
}
}
if (invert) {
for (auto &x : a) x /= n;
}
}
Ở cài đặt phía trên ta có viết w *= wlen để tính luỹ thừa của căn đơn vị. Việc nhân nhiều lần sẽ ảnh hưởng rất lớn đến độ chính xác của thuật toán vì ta đang thực hiện tính toán trên số thực.
Nhận xét rằng với mỗi len ta chỉ cần tính ωlen0,ωlen1,…,ωlenlen/2−1.
Định nghĩa một mảng root sao cho với mỗi len và với mỗi 0≤j<2len ta có root(2len+j)=ωlenj.
Cài đặt để tính mảng root:
vector<cd> root(n);
root[1] = 1;
for (int k = 2; k < n; k *= 2) {
cd z = polar(1.0l, acos(-1.0l) / k * (invert ? 1 : -1));
for (int j = k / 2; j < k; j++) {
root[2 * j] = root[j];
root[2 * j + 1] = root[j] * z;
}
}
Sau khi có mảng root ta có thể sửa cài đặt FFT như sau:
for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k)
for (int j = 0; j < k; j++) {
cd z = root[j + k] * a[i + j + k];
a[i + j + k] = a[i + j] - z;
a[i + j] += z;
}
Thử nghiệm với n=220, sai số chỉ rơi vào khoảng 5.5511⋅10−16.
Sử dụng công thức này ta có thể tính tích chập chỉ dùng hai lần gọi hàm fft:
vector<int> conv(const vector<int> &a, const vector<int> &b) {
if (a.empty() || b.empty()) return {};
int n = 1;
while (n < int(a.size() + b.size()) - 1) n <<= 1;
vector<cd> in(n), out(n);
for (int i = 0; i < int(a.size()); i++)
in[i].real(a[i]);
for (int i = 0; i < int(b.size()); i++)
in[i].image(b[i]);
fft(in, false);
for (int i = 0; i < n; i++)
in[i] *= in[i];
for (int i = 0; i < n; i++) {
/// (n - i) mod n
int j = -i & (n - 1);
out[i] = in[i] - conj(in[j]);
}
fft(out, true)
vector<int> res(n);
/// ở trên ta không chia cho 4i nên kết quả sẽ nằm trong phần ảo
for (int i = 0; i < n; i++)
res[i] = int(imag(out[i]) / 4 + 0.5);
return res;
}
Xét bài toán nhân đa thức nhưng lần này ta muốn các hệ số của đa thức chia lấy dư cho một số nguyên tố p. Nếu ta sử dụng thuật toán FFT thông thường có thể gây ra sai số lớn vì hệ số của đa thức kết quả có thể rất lớn. Thuật toán NTT cho phép ta tính toán chỉ dùng số nguyên, từ đó kết quả luôn đảm bảo chính xác.
Thuật toán FFT dựa trên các tính chất của căn đơn vị. Các tính chất này cũng xuất hiện trên căn đơn vị trong số học modulo. Cụ thể ta gọi căn đơn vị cấp n modulo p là một số nguyên ωn thoả mãn:
(ωn)n=1(modp)(ωn)j=(ωn)k,∀0≤j<k<n
Điều kiện thứ hai có thể được viết lại thành: (ωn)k=1 với mọi 1≤k<n.
Các căn đơn vị cấp n khác được biểu diễn bằng một luỹ thừa của ωn.
Để áp dụng thuật toán FFT, ta cần các căn đơn vị cho các luỹ thừa nhỏ hơn của 2. Ta có thể chứng minh tính chất sau:
(ωn2)n/2=1(modp)(ωn2)k=1(modp),∀1≤k<2n
Từ đây ta có nếu ωn là căn đơn vị cấp n, thì ωn2 là căn đơn vị cấp 2n và do đó ta tính được căn đơn vị cho các luỹ thừa của 2 nhỏ hơn.
Để tính biến đổi ngược ta cần nghịch đảo moduloωn−1 tồn tại, điều này là hiển nhiên vì p là số nguyên tố.
Ta chứng minh được rằng với một số nguyên tố có dạng p=c2k+1, tồn tại ω2k=gc với g là một căn nguyên thuỷ modulo p. Lấy ví dụ với p=998244353=119⋅223+1 có căn nguyên thuỷ g=3 và ω223=3119modp=15311432.
Cài đặt với p=998244353 (các hằng số root, root_pw thể hiện căn đơn vị và số mũ tương ứng, root_1 là nghịch đảo của căn đơn vị, hàm inverse tính nghịch đảo modulo của một số):
const int mod = 998244353;
const int root = 15311432;
const int root_1 = 469870224;
const int root_pw = 1 << 23;
void fft(vector<int> & a, bool invert) {
int n = a.size(), L = __builtin_ctz(n);
vector<int> rev(n);
for (int i = 0; i < n; i++) {
rev[i] = (rev[i >> 1] | (i & 1) << L) >> 1;
if (i < rev[i]) swap(a[i], a[rev[i]]);
}
for (int len = 2; len <= n; len <<= 1) {
int wlen = invert ? root_1 : root;
for (int i = len; i < root_pw; i <<= 1)
wlen = (int)(1LL * wlen * wlen % mod);
for (int i = 0; i < n; i += len) {
int w = 1;
for (int j = 0; j < len / 2; j++) {
int u = a[i + j];
int v = 1ll * a[i + j + len / 2] * w % mod;
a[i + j] = u + v < mod ? u + v : u + v - mod;
a[i + j + len / 2] = u - v >= 0 ? u - v : u - v + mod;
w = 1ll * w * wlen % mod;
}
}
}
if (invert) {
int n_1 = inverse(n, mod);
for (int & x : a)
x = 1ll * x * n_1 % mod;
}
}
Nếu kết quả của phép nhân đa thức có hệ số nhỏ hơn M1⋅M2, với M1,M2 là hai số nguyên tố có dạng c2k+1, ta có thể thực hiện NTT trên hai modulo này và dùng định lý thặng dư Trung Hoa để khôi phục kết quả.