高级操作,感觉非常神仙。
题目中的字母太难懂了,重新定义一下。
$$A(x) = B(x) * C(x) + D(x)$$
其中,$A(x)$的次数是$n$,$B(x)$的次数是$m$,$A, B$都已知,要求$C$的次数是$n - m$,$D$的次数小于$m$。
定义一种操作:
如果$A$的次数为$n$,那么
$$A_R(x) = x^n A\left(\frac{1}{x}\right)$$
其实就是把一个多项式各项系数翻转过来。
比如$A(x) = x^3 + 2x^2 + 3x + 5$,有
$$A_R(x) = 5x^3 + 3x^2 + 2x + 1$$
有了这个操作之后就可以变魔术了。
发现这个$D(x)$很难受,想办法把它搞掉。
$$A(x) = B(x) * C(x) + D(x)$$
$$A\left(\frac{1}{x}\right) = B\left(\frac{1}{x}\right)C\left(\frac{1}{x}\right) + D\left(\frac{1}{x}\right)$$
两边同乘$x^n$,并把$x^n$拆成$x^m \cdot x^{n - m}$以及$x^{n - m + 1} \cdot x^{m - 1}$(这里把$D$看作$m - 1$次多项式,高位补$0$),
$$x^nA\left(\frac{1}{x}\right) = x^mB\left(\frac{1}{x}\right) \cdot x^{n - m}C\left(\frac{1}{x}\right) + x^{n - m + 1} \cdot x^{m - 1}D\left(\frac{1}{x}\right)$$
$$A_R(x) = B_R(x)C_R(x) + x^{n - m + 1} D_R(x)$$
两边同时对$x^{n - m + 1}$取模,含$D_R(x)$的那一项就消掉了:
$$A_R(x) \equiv B_R(x)C_R(x) \pmod{x^{n - m + 1}}$$
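其实就是多项式求逆了:$B_R(x)$的常数项就是$B(x)$的最高次项系数,不为$0$,所以$B_R(x)$在模$x^{n - m + 1}$意义下可逆,于是 $$C_R(x) \equiv A_R(x)B_R^{-1}(x) \pmod{x^{n - m + 1}}$$ 而$C_R(x)$的次数不超过$n - m$,所以它被这个同余式唯一确定,最后把系数翻转回来就得到$C(x)$。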
求出$C(x)$之后,再做一次多项式乘法,由$D(x) = A(x) - B(x)C(x)$即可算出$D(x)$。
时间复杂度$O(n \log n)$。
Code:
data:image/s3,"s3://crabby-images/6da44/6da44a3c422e49abcf1dae786223d28e774e2de6" alt=""
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;

typedef long long ll;
typedef vector <ll> poly;  // coefficients, constant term first

const int N = 1e5 + 5;

// Polynomial arithmetic over Z_P (P = 998244353) built on the number-theoretic
// transform. Every poly stores its coefficients from the constant term upward.
namespace Poly {
    const int L = 1 << 18;     // size of the bit-reversal table; transform lengths must not exceed it
    const ll P = 998244353LL;  // NTT modulus; ntt() uses 3 as its primitive root
    int lim, pos[L];           // current transform length and its bit-reversal permutation

    // Debug helper: prints the coefficients of c separated by spaces and ends the line.
    inline void deb(poly c) {
        for (int i = 0; i < (int)c.size(); i++)
            printf("%lld%c", c[i], " \n"[i == (int)c.size() - 1]);
    }

    // x = (x + y) mod P, assuming both operands already lie in [0, P).
    template <typename T> inline void inc(T &x, T y) { x += y; if (x >= P) x -= P; }
    // x = (x - y) mod P, assuming both operands already lie in [0, P).
    template <typename T> inline void sub(T &x, T y) { x -= y; if (x < 0) x += P; }

    // Returns x^y mod P by binary exponentiation (returns 1 when y <= 0).
    inline ll fpow(ll x, ll y) {
        ll res = 1;
        for (; y > 0; y >>= 1) {
            if (y & 1) res = res * x % P;
            x = x * x % P;
        }
        return res;
    }

    // Sets lim to the smallest power of two >= len and fills pos[] with the
    // matching bit-reversal permutation used by ntt().
    inline void prework(int len) {
        int l = 0;
        for (lim = 1; lim < len; lim <<= 1, ++l);
        for (int i = 0; i < lim; i++)
            // Guard l == 0 (lim == 1): shifting by l - 1 == -1 is undefined behaviour.
            pos[i] = (pos[i >> 1] >> 1) | (l > 0 ? (i & 1) << (l - 1) : 0);
    }

    // Iterative NTT of length lim (call prework() first); c is resized to lim.
    // opt == 1 is the forward transform, opt == -1 the inverse including the 1/lim scaling.
    inline void ntt(poly &c, int opt) {
        c.resize(lim, 0);
        for (int i = 0; i < lim; i++)
            if (i < pos[i]) swap(c[i], c[pos[i]]);
        for (int i = 1; i < lim; i <<= 1) {
            ll wn = fpow(3, (P - 1) / (i << 1));
            if (opt == -1) wn = fpow(wn, P - 2);  // inverse root for the inverse transform
            for (int len = i << 1, j = 0; j < lim; j += len) {
                ll w = 1;
                for (int k = 0; k < i; k++, w = w * wn % P) {
                    ll x = c[j + k], y = w * c[j + k + i] % P;
                    c[j + k] = (x + y) % P;
                    c[j + k + i] = (x - y + P) % P;
                }
            }
        }
        if (opt == -1) {
            ll inv = fpow(lim, P - 2);
            for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P;
        }
    }

    // Returns the product x * y (length x.size() + y.size() - 1); empty if either factor is empty.
    inline poly mul(const poly &x, const poly &y) {
        if (x.empty() || y.empty()) return poly();
        const int need = (int)(x.size() + y.size() - 1);
        poly u = x, v = y;
        prework(need);
        ntt(u, 1), ntt(v, 1);
        poly res(lim);
        for (int i = 0; i < lim; i++) res[i] = u[i] * v[i] % P;
        ntt(res, -1);
        res.resize(need);
        return res;
    }

    // Returns the first len coefficients of x^{-1} mod x^len using the Newton step
    // B = B0 * (2 - A * B0). Requires x[0] != 0 (mod P).
    poly getInv(poly x, int len) {
        if (len <= 0) return poly();
        x.resize(len, 0);
        if (len == 1) return poly(1, fpow(x[0], P - 2));
        poly y = getInv(x, (len + 1) >> 1);
        // deg(x * y * y) <= 2 * len - 2 < lim, so the cyclic convolution does not wrap.
        prework(len << 1);
        ntt(x, 1), ntt(y, 1);
        poly res(lim);
        for (int i = 0; i < lim; i++)
            res[i] = y[i] * ((2LL - x[i] * y[i] % P + P) % P) % P;
        ntt(res, -1);
        res.resize(len, 0);
        return res;
    }

    // Returns the quotient Q of x / y, with x.size() - y.size() + 1 coefficients.
    // Uses the reversal trick: Q_R = x_R * y_R^{-1} mod t^{len}.
    // Returns {0} when deg x < deg y. Requires y's leading coefficient to be non-zero mod P.
    inline poly getDiv(poly x, poly y) {
        if (x.size() < y.size()) return poly(1, 0);
        const int len = (int)(x.size() - y.size() + 1);
        reverse(x.begin(), x.end());
        reverse(y.begin(), y.end());
        x.resize(len, 0), y.resize(len, 0);
        poly res = mul(x, getInv(y, len));
        res.resize(len);
        reverse(res.begin(), res.end());
        return res;
    }

    // Returns the remainder R = x - y * Q with exactly y.size() - 1 coefficients
    // (deg R < deg y), zero-padded in the high positions. Coefficients of x are
    // expected to lie in [0, P).
    inline poly getRest(poly x, poly y) {
        poly bq = mul(y, getDiv(x, y));
        const size_t len = max(x.size(), bq.size());
        x.resize(len, 0), bq.resize(len, 0);
        for (size_t i = 0; i < len; i++) x[i] = (x[i] - bq[i] + P) % P;
        // All coefficients from degree deg(y) upward are exactly 0 mod P; keep the fixed width.
        x.resize(y.empty() ? 0 : y.size() - 1);
        return x;
    }
}
using Poly :: getDiv;
using Poly :: getRest;

// Reads one (optionally negative) decimal integer from stdin; leaves 0 at EOF.
template <typename T> inline void read(T &X) {
    X = 0;
    T op = 1;
    int ch = getchar();  // int, not char, so EOF is distinguishable
    for (; ch != EOF && (ch < '0' || ch > '9'); ch = getchar())
        if (ch == '-') op = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        X = X * 10 + (ch - '0');
    X *= op;
}

// Prints the coefficients of c separated by single spaces, followed by a newline.
inline void writePoly(const poly &c) {
    for (size_t i = 0; i < c.size(); i++) {
        if (i) putchar(' ');
        printf("%lld", c[i]);
    }
    putchar('\n');
}

int main() {
    int n, m;
    read(n), read(m);
    ++n, ++m;  // degrees -> coefficient counts
    poly x(n, 0), y(m, 0);
    for (int i = 0; i < n; i++) {
        read(x[i]);
        x[i] %= Poly::P;
        if (x[i] < 0) x[i] += Poly::P;
    }
    for (int i = 0; i < m; i++) {
        read(y[i]);
        y[i] %= Poly::P;
        if (y[i] < 0) y[i] += Poly::P;
    }
    writePoly(getDiv(x, y));
    writePoly(getRest(x, y));
    return 0;
}