前置知识
- Fast Fourier Transform - HolyK
- 多项式系列操作
循环卷积优化
简介:目前网上关于多项式操作的文章和模板大多仍然是朴素的实现,常数巨大,这个技巧利用循环卷积优化多项式操作的常数。
(这个trick在N年前就有了,「循环卷积优化」这个名字是我瞎起的,如果有人知道这个trick的名字请联系我。)
多项式求逆
求 (B(x)) 满足 (A(x)B(x) equiv 1 pmod n)。
牛顿迭代得
朴素的实现需要做3次长度为 (2^{t+2}) 的FFT,把多余的部分舍去,常数较大。
发现 (B_{t+1}(x)) 的前 (2^t) 项和 (B_t(x)) 一样,所以只需要求后 (2^t) 项,即求 (A(x)B_t^2(x)) 的 (x^{2^t}dots x^{2^{t+1}-1}) 项系数。
设
这个结果的前 (2^t) 项确定是1了,所以只有后半部分是有用的。
由于 (deg B_t = 2^t),如果做长度为 (2^{t+1}) 的卷积,多余的部分会循环到前半部分,不会影响后半部分的结果。
同样的, (A(x)B_t^2(x) = A(x)B_t(x) imes B_t(x)) ,卷积多余的部分会循环到前 (2^t) 项,后半部分不会受到影响。
所以只需要做5次长度为 (2^{t+1}) 的FFT,在实际测试中(用100000的数据测试)常数约为正常写法的三分之二。
下面这个是递归写法:
Polynom inverse(Polynom a) {
int n = a.size();
assert((n & n - 1) == 0);
if (n == 1) return {fpow(a[0])};
int m = n >> 1;
Polynom b = inverse(Polynom(a.begin(), a.begin() + m)), c = b;
b.resize(n);
dft(a), dft(b);
for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i < m; i++) a[i] = 0;
for (int i = m; i < n; i++) a[i] = P - a[i];
dft(a);
for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i < m; i++) a[i] = c[i];
return a;
}
多项式开根
求 (B(x)) 使 $A(x)equiv B^2(x) pmod{x^n} $。
牛顿迭代得
只要求 ((B_t^2(x) - A(x))cdot{B_t^{-1}(x)}) 的后半部分即可。
容易发现 (B_t^2(x) - A(x)) 前 (2^t) 项是 (0),所以 (B_t^{-1}(x)) 的后 (2^t) 项并不会对答案的后半部分产生贡献,只求在模 (x^{2^t}) 意义下的 (B_t^{-1}(x)) 即可。
对于 (B_t^2(x)),由于前 (2^t) 项已知,所以只需要求长度为 (2^t) 的循环卷积,
实现时,不用每一层都调用多项式求逆,可以和开根一起迭代,具体看代码。
实际测试常数优化效果显著,可以在洛谷 P5205 的前几个看见我的提交。
Polynom sqrt(Polynom a) { // return-value: sqrt{a}
int len = a.size();
assert((len & len - 1) == 0);
assert(a[0] == 1); // warning: sqrtMod is needed if a[0] > 1.
Polynom b(len), binv{1}, bsqr{1}; // sqrt, sqrt_inv, sqrt_sqr
Polynom foo, bar; // temp
b[0] = 1;
auto shift = [](int x) { return (x & 1 ? x + P : x) >> 1; }; // quick div 2
for (int m = 1, n = 2; n <= len; m <<= 1, n <<= 1) {
foo.resize(n), bar = binv;
for (int i = 0; i < m; i++) {
foo[i + m] = sub(sum(a[i], a[i + m]), bsqr[i]);
foo[i] = 0;
}
binv.resize(n);
dft(foo), dft(binv);
for (int i = 0; i < n; i++) foo[i] = 1LL * foo[i] * binv[i] % P;
idft(foo);
for (int i = m; i < n; i++) b[i] = shift(foo[i]);
// inv
if (n == len) break;
for (int i = 0; i < n; i++) foo[i] = b[i];
bar.resize(n), binv = bar;
dft(foo), dft(bar);
bsqr.resize(n);
for (int i = 0; i < n; i++) bsqr[i] = 1LL * foo[i] * foo[i] % P;
idft(bsqr);
for (int i = 0; i < n; i++) foo[i] = 1LL * foo[i] * bar[i] % P;
idft(foo);
for (int i = 0; i < m; i++) foo[i] = 0;
for (int i = m; i < n; i++) foo[i] = P - foo[i];
dft(foo);
for (int i = 0; i < n; i++) foo[i] = 1LL * foo[i] * bar[i] % P;
idft(foo);
for (int i = m; i < n; i++) binv[i] = foo[i];
}
return b;
}