zoukankan      html  css  js  c++  java
  • 大型大常数多项式模板(已卡常...)

    # include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;      // 64-bit type used for products before they are reduced modulo mod
    typedef vector <int> Poly; // polynomial as a coefficient vector: element i is the coefficient of x^i
    
    // Modulus for every coefficient operation in this file; inverses are taken as Pow(v, mod - 2).
    const int mod(998244353);
    // Inverse of 2 modulo mod (2 * 499122177 == mod + 1).
    const int inv2(499122177);
    // Capacity of the fixed-size work arrays and lookup tables declared with this bound.
    const int maxn(1 << 18);
    
    // In-place modular addition: x becomes x + y, with mod subtracted once if the sum reaches mod.
    // Both operands are expected to already lie in [0, mod).
    inline void Inc(int &x, const int y) {
        x += y;
        if (x >= mod) x -= mod;
    }
    
    // In-place modular subtraction: x becomes x - y, with mod added once if the difference is negative.
    // Both operands are expected to already lie in [0, mod).
    inline void Dec(int &x, const int y) {
        x -= y;
        if (x < 0) x += mod;
    }
    
    // Returns x + y, with mod subtracted once if the sum reaches mod.
    // Both operands are expected to already lie in [0, mod).
    inline int Add(int x, const int y) {
        x += y;
        return x < mod ? x : x - mod;
    }
    
    // Returns x - y, with mod added once if the difference is negative.
    // Both operands are expected to already lie in [0, mod).
    inline int Sub(int x, const int y) {
        x -= y;
        return x >= 0 ? x : x + mod;
    }
    
    // Binary exponentiation: returns x^y modulo mod.
    // y is consumed bit by bit from the least significant end; x is squared once per bit.
    inline int Pow(ll x, int y) {
        ll acc = 1;
        while (y) {
            if (y & 1) acc = acc * x % mod;
            x = x * x % mod;
            y >>= 1;
        }
        return acc;
    }
    
    // Number-theoretic transform modulo mod, using 3 as the base for the roots of unity.
    namespace NTT {
        // w[0][i] / w[1][i]: i-th power of the deg-th root of unity and of its inverse.
        // r: bit-reversal permutation for the current size.
        // deg: current transform size (a power of two), l: its base-2 logarithm.
        int w[2][maxn], r[maxn], l, deg;

        // Prepares the tables for transforms of the smallest power-of-two size >= len.
        // The tables have maxn entries, so the rounded-up size must not exceed maxn.
        inline void Init(int len) {
            int i, x, y;
            for (l = 0, deg = 1; deg < len; deg <<= 1) ++l;
            // Value of the highest bit of an index. Using deg >> 1 instead of
            // 1 << (l - 1) avoids a negative shift count when deg == 1 (l == 0).
            const int top = deg >> 1;
            for (i = 0; i < deg; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? top : 0);
            x = Pow(3, (mod - 1) / deg), y = Pow(x, mod - 2), w[0][0] = w[1][0] = 1;
            for (i = 1; i < deg; ++i) {
                w[0][i] = (ll)w[0][i - 1] * x % mod;
                w[1][i] = (ll)w[1][i - 1] * y % mod;
            }
        }

        // In-place transform of p[0..deg). opt == -1 selects the inverse transform
        // (inverse roots plus a final scaling by 1/deg); any other value runs the forward one.
        // Init must have been called for the desired size beforehand.
        inline void DFT(int *p, int opt) {
            int i, j, k, t, wn, x, y;
            const int *root = w[opt == -1];
            for (i = 0; i < deg; ++i) if (r[i] < i) swap(p[r[i]], p[i]);
            for (i = 1; i < deg; i <<= 1) {
                t = i << 1;
                const int step = deg / t; // stride into the root table, constant for this layer
                for (j = 0; j < deg; j += t)
                    for (k = 0; k < i; ++k) {
                        wn = root[step * k];
                        x = p[j + k], y = (ll)p[j + k + i] * wn % mod;
                        p[j + k] = Add(x, y), p[j + k + i] = Sub(x, y);
                    }
            }
            if (opt == -1) for (i = 0, wn = Pow(deg, mod - 2); i < deg; ++i) p[i] = (ll)p[i] * wn % mod;
        }
    }
    
    // Make the two NTT entry points callable without qualification in the routines below.
    using NTT :: Init;
    using NTT :: DFT;
    
    void Inv(int *p, int *q, int len) {
        if (len == 1) {
            q[0] = Pow(p[0], mod - 2);
            return;
        }
        Inv(p, q, len >> 1);
        static int a[maxn], b[maxn];
        int tmp = len << 1, i;
        Init(tmp);
        for (i = 0; i < tmp; ++i) a[i] = b[i] = 0;
        for (i = 0; i < len; ++i) a[i] = p[i], b[i] = q[i];
        DFT(a, 1), DFT(b, 1);
        for (i = 0; i < tmp; ++i) a[i] = (ll)a[i] * b[i] % mod * b[i] % mod;
        DFT(a, -1);
        for (i = 0; i < len; ++i) q[i] = Sub(Add(q[i], q[i]), a[i]);
    }
    
    // Formal integral: q[k] = p[k - 1] / k for k = len - 1 down to 1, then q[0] = 0.
    // Coefficients are written from the highest index downwards.
    inline void Calc(int *p, int *q, int len) {
        for (int k = len - 1; k != 0; --k) {
            q[k] = (ll)p[k - 1] * Pow(k, mod - 2) % mod;
        }
        q[0] = 0;
    }
    
    // Formal derivative: q[k - 1] = k * p[k] for k = len - 1 down to 1, then q[len - 1] = 0.
    // Coefficients are written from the highest index downwards.
    inline void ICalc(int *p, int *q, int len) {
        for (int k = len - 1; k != 0; --k) {
            q[k - 1] = (ll)p[k] * k % mod;
        }
        q[len - 1] = 0;
    }
    
    inline void Ln(int *p, int *q, int len) {
        static int a[maxn], b[maxn];
        int tmp = len << 1, i;
        for (i = 0; i < tmp; ++i) a[i] = b[i] = 0;
        ICalc(p, a, len), Inv(p, b, len);
        DFT(a, 1), DFT(b, 1);
        for (i = 0; i < tmp; ++i) a[i] = (ll)a[i] * b[i] % mod;
        DFT(a, -1), Calc(a, q, len);
    }
    
    void Exp(int *p, int *q, int len) {
        if (len == 1) {
            q[0] = 1;
            return;
        }
        Exp(p, q, len >> 1);
        static int a[maxn], b[maxn];
        int tmp = len << 1, i;
        Init(tmp);
        for (i = 0; i < tmp; ++i) a[i] = b[i] = 0;
        Ln(q, a, len);
        for (i = 0; i < len; ++i) a[i] = Sub(p[i], a[i]), b[i] = q[i];
        Inc(a[0], 1), DFT(a, 1), DFT(b, 1);
        for (i = 0; i < tmp; ++i) a[i] = (ll)a[i] * b[i] % mod;
        DFT(a, -1);
        for (i = 0; i < len; ++i) q[i] = a[i];
    }
    
    void Sqrt(int *p, int *q, int len) {
        if (len == 1) {
            q[0] = sqrt(p[0]);
            return;
        }
        Sqrt(p, q, len >> 1);
        int i, tmp = len << 1;
    	static int a[maxn], b[maxn];
    	for (i = 0; i < tmp; ++i) a[i] = b[i] = 0;
    	Inv(q, b, len);
        for (i = 0; i < len; ++i) a[i] = p[i];
        Init(tmp), DFT(a, 1), DFT(b, 1);
        for (i = 0; i < tmp; ++i) a[i] = (ll)a[i] * b[i] % mod;
        DFT(a, -1);
        for (i = 0; i < len; ++i) q[i] = (ll)Add(q[i], a[i]) % mod * inv2 % mod;
    }
    
    // Coefficient-wise sum modulo mod; the result is as long as the longer operand.
    inline Poly operator +(const Poly &a, const Poly &b) {
        Poly sum(a);
        sum.resize(max(a.size(), b.size()));
        const int m = b.size();
        for (int k = 0; k < m; ++k) Inc(sum[k], b[k]);
        return sum;
    }
    
    // Coefficient-wise difference a - b modulo mod; the result is as long as the longer operand.
    inline Poly operator -(const Poly &a, const Poly &b) {
        Poly diff(a);
        diff.resize(max(a.size(), b.size()));
        const int m = b.size();
        for (int k = 0; k < m; ++k) Dec(diff[k], b[k]);
        return diff;
    }
    
    // Multiplies every coefficient of a by the scalar b, modulo mod.
    inline Poly operator *(const Poly &a, const int b) {
        Poly scaled(a.size());
        for (size_t k = 0; k < a.size(); ++k) {
            scaled[k] = (ll)a[k] * b % mod;
        }
        return scaled;
    }
    
    // Polynomial product via NTT; the result has a.size() + b.size() - 1 coefficients.
    // The transform size chosen by Init must fit the maxn-sized work buffers.
    inline Poly operator *(const Poly &a, const Poly &b) {
        const int n = a.size(), m = b.size(), total = n + m - 1;
        Poly prod(total);
        static int fa[maxn], fb[maxn];
        Init(total);
        const int size = NTT :: deg;
        // Clear the transform window first, then lay the operands over it.
        fill(fa, fa + size, 0);
        fill(fb, fb + size, 0);
        for (int k = 0; k < n; ++k) fa[k] = a[k];
        for (int k = 0; k < m; ++k) fb[k] = b[k];
        DFT(fa, 1);
        DFT(fb, 1);
        for (int k = 0; k < size; ++k) fa[k] = (ll)fa[k] * fb[k] % mod;
        DFT(fa, -1);
        for (int k = 0; k < total; ++k) prod[k] = fa[k];
        return prod;
    }
    
    // Polynomial remainder of a divided by b.
    // If a is shorter than b it is returned unchanged. Otherwise the quotient is obtained
    // by reversing both operands, multiplying reversed a by the series inverse of reversed b
    // and keeping a.size() - b.size() + 1 terms; the remainder a - quotient * b is then cut
    // to b.size() - 1 coefficients.
    // The inverse starts from b's last coefficient, so that coefficient must be invertible.
    inline Poly operator %(const Poly &a, const Poly &b) {
        if (a.size() < b.size()) return a;
        const int n = a.size(), m = b.size(), res = n - m + 1;
        // Build the reversed operands directly instead of copying and then reversing.
        Poly x(a.rbegin(), a.rend()), y(b.rbegin(), b.rend()), z;
        int len = 1;
        while (len < res) len <<= 1;
        x.resize(len), y.resize(len), z.resize(len);
        Inv(y.data(), z.data(), len);
        x = x * z;
        x.resize(res), reverse(x.begin(), x.end());
        y = a - x * b, y.resize(m - 1);
        return y;
    }
    
    // f[o]: polynomial stored for tree node o by Build; a, b: working and result polynomials of Lagrange.
    Poly f[maxn], a, b;
    // n: number of input points; x[i], y[i]: the points (1-based); ans[i]: per-point value filled by Solve_val.
    // m is not referenced by the code shown here.
    int n, m, x[maxn], y[maxn], ans[maxn];
    
    inline int Calc(const Poly v, const int x) {
    	int i, n = v.size(), t = 1, ret = 0;
    	for (i = 0; i < n; ++i) Inc(ret, (ll)t * v[i] % mod), t = (ll)t * x % mod;
    	return ret;
    }
    
    void Build(int o, int l, int r) {
        if (l == r) {
            f[o].resize(2), f[o][0] = mod - x[l], f[o][1] = 1;
            return;
        }
        int mid = (l + r) >> 1;
        Build(o << 1, l, mid), Build(o << 1 | 1, mid + 1, r);
        f[o] = f[o << 1] * f[o << 1 | 1];
    }
    
    // Fills ans[k] = y[k] / cur(x[k]) for every point k in l..r.
    // Ranges of more than 2000 points are split in half, reducing cur modulo the child's
    // polynomial f[2o] or f[2o + 1] first; smaller ranges are evaluated point by point.
    void Solve_val(Poly cur, int o, int l, int r) {
        if (r - l + 1 > 2000) {
            const int mid = (l + r) >> 1;
            Solve_val(cur % f[o << 1], o << 1, l, mid);
            Solve_val(cur % f[o << 1 | 1], o << 1 | 1, mid + 1, r);
            return;
        }
        for (int k = l; k <= r; ++k) {
            ans[k] = 1LL * y[k] * Pow(Calc(cur, x[k]), mod - 2) % mod;
        }
    }
    
    // Combines the per-point values in ans[l..r] into one polynomial stored in cur.
    // A leaf writes ans[l] into cur[0]; an inner node sets
    // cur = left * f[right child] + right * f[left child].
    void Solve(Poly &cur, int o, int l, int r) {
        if (l == r) {
            cur[0] = ans[l];
            return;
        }
        const int mid = (l + r) >> 1;
        const int lc = o << 1, rc = o << 1 | 1;
        Poly lp(mid - l + 1);
        Poly rp(r - mid);
        Solve(lp, lc, l, mid);
        Solve(rp, rc, mid + 1, r);
        cur = lp * f[rc] + rp * f[lc];
    }
    
    inline void Lagrange() {
        int i, len;
        scanf("%d", &n);
        for (i = 1; i <= n; ++i) scanf("%d%d", &x[i], &y[i]);
        Build(1, 1, n), a = f[1], len = a.size();
        for (i = 0; i < len - 1; ++i) a[i] = (ll)a[i + 1] * (i + 1) % mod;
        if (a.size() > 1) a.pop_back();
        else a[0] = 0;
        b.resize(n), Solve_val(a, 1, 1, n), Solve(b, 1, 1, n);
        for (i = 0; i < n; ++i) printf("%d ", b[i]);
        puts("");
    }
    
    // Entry point: performs no work and returns 0; none of the routines above are called from here.
    int main() {
    	return 0;
    }
    
  • 相关阅读:
    HDU 5313 bitset优化背包
    bzoj 2595 斯坦纳树
    COJ 1287 求匹配串在模式串中出现的次数
    HDU 5381 The sum of gcd
    POJ 1739
    HDU 3377 插头dp
    HDU 1693 二进制表示的简单插头dp
    HDU 5353
    URAL 1519 基础插头DP
    UVA 10294 等价类计数
  • 原文地址:https://www.cnblogs.com/cjoieryl/p/10158721.html
Copyright © 2011-2022 走看看