114514 年前写的题解,搬过来了
0.前言
必备内容:NTT,多项式求逆
建议到博客中查看
1.正文
定义 \(F_R(x)\) 为把 \(F(x)\) 的系数翻转后得到的多项式
若 \(F(x)\) 次数为 \(k\),那么 \(F_R(x)=F\left(\frac{1}{x}\right)\cdot x^k\)
设 \(f(x)=g(x)\cdot q(x)+r(x)\),其中 \(f\) 次数为 \(n\),\(g\) 次数为 \(m\),\(q\) 次数为 \(n-m\),\(r\) 次数不超过 \(m-1\)(不足时高位补零,按 \(m-1\) 次处理)
我们可以把 \(\frac{1}{x}\) 代入函数
得
\[f\left(\frac{1}{x}\right)=g\left(\frac{1}{x}\right)\cdot q\left(\frac{1}{x}\right)+r\left(\frac{1}{x}\right)
\]
\[f\left(\frac{1}{x}\right)\cdot x^n=g\left(\frac{1}{x}\right)\cdot x^m\cdot q\left(\frac{1}{x}\right)\cdot x^{n-m}+r\left(\frac{1}{x}\right)\cdot x^{m-1}\cdot x^{n-m+1}
\]
\[f_R(x)=g_R(x)\cdot q_R(x)+r_R(x)\cdot x^{n-m+1}
\]
\[f_R(x)\equiv g_R(x)\cdot q_R(x) \pmod{x^{n-m+1}}
\]
\[q_R(x)\equiv \frac{f_R(x)}{g_R(x)} \pmod{x^{n-m+1}}
\]
从而我们可以求出 \(q_R(x) \bmod x^{n-m+1}\)
我们知道 \(q_R\) 次数为 \(n-m\),而模数次数为 \(n-m+1\),所以取模不会丢失任何系数
于是把 \(q_R\) 的系数翻转回来,就可以算出 \(q\) 啦!!!!!
然后我们可以用
\[r(x)=f(x)-g(x)\cdot q(x)
\]
然后 \(r\) 也算出来啦!
2.代码
#include<bits/stdc++.h>
// Inclusive ascending loop: i runs from j up to k. The bound k is evaluated
// once and cached in the helper variable i##E.
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
// Inclusive descending loop: i runs from j down to k, bound cached in i##E.
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
// Type / helper shorthands. Only `ll` is used by the code visible in this file.
#define ll long long
#define ull unsigned long long
#define db double
#define pii pair<int, int>
#define mkp make_pair
using namespace std;
// N   : capacity (2^19 entries) of every global work array declared below.
// mod : modulus for all arithmetic (998244353 = 119 * 2^23 + 1).
// G   : base whose powers NTT uses as roots of unity in the forward transform.
// iG  : inverse of G modulo mod, since G * iG = mod + 1; used by the inverse transform.
const int N = 524288, mod = 998244353, G = 3, iG = (mod + 1) / G;
/**
 * Binary exponentiation: returns x raised to the power y, reduced modulo `mod`.
 *
 * @param x base (intermediate products are widened to 64 bits before reduction)
 * @param y non-negative exponent; the default mod - 2 is what the callers in
 *          this file rely on to obtain a modular inverse of x
 * @return x^y mod `mod`
 */
int qpow(int x, int y = mod - 2) {
    int acc = 1;
    while (y != 0) {
        // Multiply the current square into the result when the low bit is set.
        if (y & 1) acc = (ll) acc * x % mod;
        x = (ll) x * x % mod;
        y >>= 1;
    }
    return acc;
}
int lim, pp[N];
void revlim() { L(i, 0, lim - 1) pp[i] = ((pp[i >> 1] >> 1) | ((i & 1) * (lim >> 1))); }
/**
 * Sets the global lim to the smallest power of two strictly greater than x
 * (lim becomes 1 when x < 1).
 */
void up(int x) {
    lim = 1;
    while (lim <= x) lim <<= 1;
}
/**
 * Zeroes the first lim entries of f, where lim is the current global
 * transform length.
 */
void cle(int *f) {
    for (int *cur = f, *stop = f + lim; cur < stop; ++cur) *cur = 0;
}
void NTT(int *f, int flag) {
L(i, 0, lim - 1) if(pp[i] < i) swap(f[pp[i]], f[i]);
for(int i = 2; i <= lim; i <<= 1)
for(int j = 0, l = (i >> 1), ch = qpow(flag == 1 ? G : iG, (mod - 1) / i); j < lim; j += i)
for(int k = j, now = 1; k < j + l; k ++) {
int pa = f[k], pb = (ll) f[k + l] * now % mod;
f[k] = (pa + pb) % mod, f[k + l] = (pa + mod - pb) % mod;
now = (ll) now * ch % mod;
}
if(flag == -1) {
int nylim = qpow(lim);
L(i, 0, lim - 1) f[i] = (ll) f[i] * nylim % mod;
}
}
int sav[N], sv[N];
void inv(int *f, int *g, int len) {
if(len == 1) return g[0] = qpow(f[0]), void();
inv(f, g, (len + 1) >> 1), up(len << 1), cle(sav), copy(f, f + len, sav), revlim(), NTT(sav, 1), NTT(g, 1);
L(i, 0, lim - 1) g[i] = (ll) g[i] * (2ll + mod - (ll) g[i] * sav[i] % mod) % mod;
NTT(g, -1), fill(g + len, g + lim, 0);
}
// Scratch buffers for the reversed operands and the inverse of the reversed divisor.
int sava[N], savb[N];
/**
 * Polynomial division with remainder modulo `mod`: finds q and r with
 * f = g * q + r, where r has fewer coefficients than g.
 *
 * The quotient is obtained from the reversed polynomials:
 * q_R = f_R * g_R^{-1} (mod x^(n-m+1)); the remainder is then f - g * q.
 *
 * @param f    dividend, n coefficients (lowest degree first); not modified
 * @param g    divisor, m coefficients; g[m - 1] is presumably non-zero
 *             (otherwise the reversed divisor has no inverse) - confirm with caller
 * @param ansa output quotient, n - m + 1 coefficients when n >= m. It is also
 *             used as a transform buffer, so it must have room for lim entries
 *             (the smallest power of two above 2 * n)
 * @param ansb output remainder, m - 1 coefficients
 * @param n    number of coefficients of f (degree + 1)
 * @param m    number of coefficients of g (degree + 1), expected >= 1
 *
 * Side effects: overwrites the globals lim, pp, sav, sava and savb.
 */
void div(int *f, int *g, int *ansa, int *ansb, int n, int m) {
    if (n < m) {
        // Dividend shorter than divisor: the quotient is empty and the
        // remainder is f itself, zero-padded to m - 1 coefficients. Without
        // this guard the pointer ansa + n - m + 1 used below would lie before
        // the buffer whenever n < m - 1.
        L(i, 0, m - 2) ansb[i] = i < n ? (f[i] + mod) % mod : 0;
        return;
    }

    const int qlen = n - m + 1;  // number of quotient coefficients

    // savb <- inverse of the reversed divisor. Only qlen terms are needed,
    // because the product below is truncated to x^qlen anyway.
    up(n << 1), cle(sava), cle(savb);
    copy(g, g + m, sava), reverse(sava, sava + m);
    inv(sava, savb, qlen);

    // inv() changed lim and pp, so restore them for the full-size product.
    up(n << 1), revlim();
    copy(f, f + n, sava), reverse(sava, sava + n);
    NTT(sava, 1), NTT(savb, 1);
    L(i, 0, lim - 1) ansa[i] = (ll) sava[i] * savb[i] % mod;
    NTT(ansa, -1);
    // Keep the low qlen coefficients of q_R and reverse them back into q.
    fill(ansa + qlen, ansa + lim, 0), reverse(ansa, ansa + qlen);

    // sav <- g * q, then r = f - g * q on the low m - 1 coefficients.
    cle(sav), cle(sava);
    copy(ansa, ansa + qlen, sav), copy(g, g + m, sava);
    NTT(sav, 1), NTT(sava, 1);
    L(i, 0, lim - 1) sav[i] = (ll) sav[i] * sava[i] % mod;
    NTT(sav, -1);
    L(i, 0, m - 2) ansb[i] = (f[i] + mod - sav[i]) % mod;
}
// Input sizes and polynomials, plus the quotient (ansa) and remainder (ansb).
int n, m, f[N], g[N], ansa[N], ansb[N];
/**
 * Reads the two degrees and the coefficient lists of f and g, divides f by g,
 * and prints the quotient on the first line and the remainder on the second.
 */
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    // Input gives degrees; convert them to coefficient counts.
    cin >> n >> m;
    ++n;
    ++m;
    for (int idx = 0; idx < n; ++idx) cin >> f[idx];
    for (int idx = 0; idx < m; ++idx) cin >> g[idx];

    div(f, g, ansa, ansb, n, m);

    for (int idx = 0; idx <= n - m; ++idx) cout << ansa[idx] << " ";
    cout << endl;
    for (int idx = 0; idx <= m - 2; ++idx) cout << ansb[idx] << " ";
    cout << endl;
    return 0;
}
祝大家学习愉快!