zoukankan      html  css  js  c++  java
  • 模板

    #include<bits/stdc++.h>
    using namespace std;
    // Scalar type used for intermediate products modulo MOD.
    typedef long long ll;

    // Array capacity, the NTT-friendly prime 998244353 = 119 * 2^23 + 1,
    // its primitive root G = 3, and INVG = G^{-1} mod MOD.
    const int MAXN = 300000 + 51;
    const int MOD = 998244353;
    const int G = 3;
    const int INVG = 332748118;

    // Problem state: exponent = index m of the requested term, fd = number of
    // given terms, rres = final answer, ptr = index of the last recurrence
    // version in bmf; cnt / limit = global NTT length and log2(cnt) - 1.
    int exponent, fd, rres, ptr;
    int cnt = 1, limit = -1;

    // Work arrays shared by the polynomial routines and main.
    int rev[MAXN];                          // bit-reversal permutation for NTT
    int f[MAXN], g[MAXN];                   // input terms / characteristic polynomial
    int tmp[MAXN], tmp2[MAXN], tmp3[MAXN];  // scratch buffers
    int tbm[MAXN];                          // 1-indexed copy of the input for BerlekampMassey
    int res[MAXN], base[MAXN], fail[MAXN];
    ll delta[MAXN];                         // discrepancies recorded by BerlekampMassey
    
    // Reads the next optionally-signed decimal integer from stdin, skipping any
    // leading characters that are neither digits nor '-'.
    // Returns 0 if end of input is reached before a number starts (previously
    // the skip loop spun forever on EOF). The value is assumed to fit in int.
    inline int read() {
        int num = 0;
        bool neg = false;
        // int, not char: keeps EOF distinguishable and gives isdigit() a valid argument.
        int ch = getchar();
        while(ch != EOF && !isdigit(ch) && ch != '-')
            ch = getchar();
        if(ch == '-')
            neg = true, ch = getchar();
        while(isdigit(ch))  // isdigit(EOF) is false, so this also stops at end of input
            num = num * 10 + (ch - '0'), ch = getchar();
        return neg ? -num : num;
    }
    
    // Returns x^n mod MOD (n >= 0) by binary exponentiation, as a value in [0, MOD).
    // The base is reduced into [0, MOD) first: BerlekampMassey passes signed
    // discrepancies that may be negative, which previously produced a negative
    // (non-canonical) result, and a base >= MOD could overflow x * x.
    inline int qpow(ll x, int n) {
        x %= MOD;
        if(x < 0)
            x += MOD;
        ll res = 1;
        for(; n; x = x * x % MOD, n >>= 1)
            if(n & 1)
                res = res * x % MOD;
        return res;
    }
    
    inline void NTT(int *cp, int cnt, int inv) {
        int cur = 0, res = 0;
        for(int i = 0; i < cnt; i++)
            if(i < rev[i])
                swap(cp[i], cp[rev[i]]);
    
        for(int i = 2; i <= cnt; i <<= 1) {
            cur = i >> 1, res = qpow(inv == 1 ? G : INVG, (MOD - 1) / i);
            for(int *p = cp; p != cp + cnt; p += i) {
                ll w = 1;
                for(int j = 0; j < cur; j++) {
                    int t = w * p[j + cur] % MOD, t2 = p[j];
                    p[j + cur] = (t2 - t + MOD) % MOD, p[j] = (t2 + t) % MOD;
                    w = w * res % MOD;
                }
            }
        }
    
        if(inv == -1) {
            int invl = qpow(cnt, MOD - 2);
            for(int i = 0; i <= cnt; i++)
                cp[i] = (ll) cp[i] * invl % MOD;
        }
    }
    
    // Computes the power-series inverse of f modulo x^fd by Newton iteration,
    // res <- res * (2 - f * res), doubling the precision at each recursion level.
    // On return res[0..fd-1] holds the inverse and res[fd..] up to this level's
    // transform length is zeroed.
    // Requires f[0] != 0 (mod MOD). Assumes res[] starts zero-filled (main passes
    // the untouched global tmp3) so the higher entries transformed below are 0.
    // Side effect: overwrites the global rev[] for this level's transform length;
    // callers must rebuild rev[] afterwards.
    inline void inv(int fd, int *f, int *res) {
        static int tmp[MAXN];  // local scratch buffer; shadows the global tmp[]
        if(fd == 1) {
            // Base case: inverse of the constant term by Fermat's little theorem.
            res[0] = qpow(f[0], MOD - 2);
            return;
        }
        // First obtain the inverse to precision ceil(fd / 2).
        inv((fd + 1) >> 1, f, res);
        // Transform length cnt >= 2 * fd so the product below does not wrap around.
        int cnt = 1, limit = -1;
        while(cnt < (fd << 1))
            cnt <<= 1, limit++;
        for(int i = 0; i < cnt; i++) {
            tmp[i] = i < fd ? f[i] : 0;  // f truncated to fd terms
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << limit);
        }
        NTT(tmp, cnt, 1), NTT(res, cnt, 1);
        // Pointwise Newton step: res = res * (2 - f * res).
        for(int i = 0; i < cnt; i++)
            res[i] = 1ll * (2 - 1ll * tmp[i] * res[i] % MOD + MOD) % MOD * res[i] % MOD;
        NTT(res, cnt, -1);
        // Truncate the result to precision fd.
        for(int i = fd; i < cnt; i++)
            res[i] = 0;
    }
    
    // Reduces f in place modulo the global characteristic polynomial g (degree fd),
    // leaving f[0..fd-1] = f mod g and f[fd..cnt-1] = 0.
    // Uses the forward transforms (length cnt) of g (in g[]) and of the series
    // inverse of reversed g to precision 2*fd (in tmp3[]), computing the quotient
    // with the reversed-polynomial trick:
    //   rev(q) = rev(f) * rev(g)^{-1}  mod x^{deg - fd + 1},   f -= q * g.
    // Assumes deg(f) <= 3*fd - 1 so that tmp3's 2*fd-term precision suffices.
    inline void mod(int *f) {
        static int tmp[MAXN], q[MAXN];
        // Actual degree of f. Scan the whole transform length with a lower bound:
        // the old `while(!f[--deg])` started at 2*fd - 1 (missing the x^2 term of
        // base when fd == 1) and ran off the front of the array when f was
        // identically zero (e.g. for an all-zero input sequence).
        int deg = cnt - 1;
        while(deg >= 0 && !f[deg])
            deg--;
        if(deg < fd)
            return;  // already reduced (this includes f == 0)

        for(int i = 0; i < cnt; i++)
            tmp[i] = i <= deg ? f[i] : 0;
        // rev(f), truncated to its first deg - fd + 1 coefficients.
        reverse(tmp, tmp + 1 + deg);
        for(int i = deg + 1 - fd; i <= deg; i++)
            tmp[i] = 0;
        NTT(tmp, cnt, 1);
        for(int i = 0; i < cnt; i++)
            q[i] = (ll)tmp[i] * tmp3[i] % MOD;
        NTT(q, cnt, -1);
        // Keep the deg - fd + 1 quotient coefficients, then un-reverse to obtain q.
        for(int i = 0; i < cnt; i++) {
            tmp[i] = 0;
            q[i] = i <= deg - fd ? q[i] : 0;
        }
        reverse(q, q + 1 + deg - fd);
        NTT(q, cnt, 1);
        for(int i = 0; i < cnt; i++)
            tmp[i] = (ll)q[i] * g[i] % MOD;
        NTT(tmp, cnt, -1);
        // Remainder f - q * g; only the low fd coefficients are nonzero.
        for(int i = 0; i < fd; i++)
            f[i] = (f[i] - tmp[i] + MOD) % MOD;
        // Clear the scratch buffers and the high part of f for the next call.
        for(int i = 0; i < cnt; i++) {
            q[i] = tmp[i] = 0;
            if(i >= fd)
                f[i] = 0;
        }
    }
    
    // bmf[k] is the k-th successive version of the recurrence coefficients
    // (bmf[k][j] multiplies base[i - j - 1]). Every version is kept so that an
    // earlier failed version can be reused when building the next one.
    // NOTE(review): keeping all versions can cost O(n^2) memory in the worst case.
    vector<ll>bmf[MAXN];
    // Berlekamp-Massey over Z_MOD on the 1-indexed sequence base[1..length].
    // Finds the shortest recurrence base[i] = sum_{k=1..L} res[k] * base[i - k]
    // that holds for every i in (L, length].
    // Output: res[1..L] receive the coefficients normalized to [0, MOD), and the
    // global ptr is set to the final version index, so L == bmf[ptr].size().
    // Side effects: writes the globals bmf[], delta[], fail[] and ptr. Nothing is
    // cleared on entry, so this assumes a single call.
    inline void BerlekampMassey(int length, int *base, int *res) {
        int cur = 0;  // index of the current recurrence version
        for(int i = 1; i <= length; i++) {
            // Discrepancy between base[i] and the current version's prediction.
            // Signed %, so it may be negative; it is still correct modulo MOD.
            ll curr = base[i];
            for(int j = 0; j < bmf[cur].size(); j++) {
                curr = (curr - (ll)base[i - j - 1] * bmf[cur][j] % MOD) % MOD;
            }
            delta[i] = curr;
            if(!delta[i]) {
                continue;  // the current version already predicts base[i]
            }
            fail[cur] = i;  // position where version cur first failed
            if(!cur) {
                // First nonzero term: start with an all-zero recurrence of length i.
                // (delta[i] already equals base[i] here because bmf[0] is empty.)
                bmf[++cur].resize(i), delta[i] = base[i];
                continue;
            }
            // Pick the earlier version id whose correction gives the shortest result;
            // the corrected length is i - fail[id] + bmf[id].size().
            int id = cur - 1, x = bmf[id].size() - fail[id] + i;
            for(int j = 0; j < cur; j++) {
                if(i - fail[j] + bmf[j].size() < x) {
                    id = j, x = i - fail[j] + bmf[j].size();
                }
            }
            // New version = copy of the current one, grown to length x.
            bmf[cur + 1] = bmf[cur], cur++;
            while(bmf[cur].size() < x) {
                bmf[cur].push_back(0);
            }
            // With d = i - fail[id], add mul * (e_{d-1} - bmf[id] shifted by d),
            // where mul = delta[i] / delta[fail[id]] cancels the discrepancy at i.
            ll mul = (ll)delta[i] * qpow(delta[fail[id]], MOD - 2) % MOD;
            bmf[cur][i - fail[id] - 1] = (ll)(bmf[cur][i - fail[id] - 1] + mul) % MOD;
            for(int j = 0; j < bmf[id].size(); j++) {
                int t = (ll)mul * bmf[id][j] % MOD;
                bmf[cur][i - fail[id] + j] = (bmf[cur][i - fail[id] + j] - t + MOD) % MOD;
            }
        }
        ptr = cur;
        // Copy the final version out, normalizing possibly-negative entries.
        for(int i = 0; i < bmf[cur].size(); i++) {
            res[i + 1] = (bmf[cur][i] % MOD + MOD) % MOD;
        }
    }
    int main() {
    #ifdef Yinku
        freopen("Yinku.in", "r", stdin);
    #endif // Yinku
        // Input: n (stored in fd), m (stored in exponent), then the n terms
        // a_0 .. a_{n-1}, kept 0-indexed in f[] and 1-indexed in tbm[].
        fd = read(), exponent = read();
        for(int i = 0; i < fd; i++)
            tbm[i + 1] = f[i] = (read() + MOD) % MOD;

        // Shortest linear recurrence a_i = sum_{j=1..L} c_j * a_{i-j};
        // c_j lands in tmp[j] and L == bmf[ptr].size(). Print c_1 .. c_L.
        BerlekampMassey(fd, tbm, tmp);
        for(int i = 1, ci = bmf[ptr].size(); i <= ci; i++)
            printf("%d%c", tmp[i], " \n"[i == ci]);

        // Characteristic polynomial of degree fd:
        //   g(x) = x^fd - sum_{j=1..fd} c_j x^{fd-j}, with c_j = 0 for j > L.
        // "% MOD" stores a zero coefficient as 0 rather than MOD.
        for(int i = 1; i <= fd; i++)
            g[fd - i] = (MOD - tmp[i]) % MOD;
        g[fd] = 1;

        // tmp3 = (reversed g)^{-1} mod x^{2*fd}, used by mod() for division.
        for(int i = 0; i <= fd; i++)
            tmp2[i] = g[i];
        reverse(tmp2, tmp2 + 1 + fd), inv(fd << 1, tmp2, tmp3);
        for(int i = 0; i <= fd; i++)
            tmp2[i] = 0;

        // Global transform length cnt >= 4 * fd; inv() overwrote rev[], so rebuild it.
        while(cnt < (fd << 2))
            cnt <<= 1, limit++;
        for(int i = 0; i < cnt; i++)
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << limit);

        // Binary exponentiation: res = x^exponent mod g, with base = x^(2^k) mod g.
        // g and tmp3 stay in transformed form for mod().
        NTT(g, cnt, 1), NTT(tmp3, cnt, 1), base[1] = res[0] = 1;
        while(exponent) {
            if(exponent & 1) {
                NTT(res, cnt, 1), NTT(base, cnt, 1);
                for(int i = 0; i < cnt; i++)
                    res[i] = (ll)res[i] * base[i] % MOD;
                // base is transformed back so it can be squared below.
                NTT(res, cnt, -1), NTT(base, cnt, -1), mod(res);
            }
            NTT(base, cnt, 1);
            for(int i = 0; i < cnt; i++)
                base[i] = (ll)base[i] * base[i] % MOD;
            NTT(base, cnt, -1), mod(base), exponent >>= 1;
        }

        // a_m = sum_i [x^i](x^m mod g) * a_i.
        for(int i = 0; i < fd; i++)
            rres = (rres + (ll)res[i] * f[i] % MOD) % MOD;
        printf("%d\n", rres);
    }
    
  • 相关阅读:
    牛客代码测试栈深度
    "Coding Interview Guide" -- 在行列都排好序的矩阵中找数
    "Coding Interview Guide" -- 括号字符串的有效性和最长有效长度
    "Coding Interview Guide" -- 将正方形矩阵顺时针转动90°
    "Coding Interview Guide" -- 按照左右半区的方式重新组合单链表
    "Coding Interview Guide" -- 先序、中序和后序数组两两结合重构二叉树
    "Coding Interview Guide" -- 只用位运算不用算术运算实现整数的加减乘除运算
    "Coding Interview Guide" -- 从N个数中等概率打印M个数
    "Coding Interview Guide" -- 判断字符数组中是否所有的字符都只出现过一次
    "Coding Interview Guide" -- 字符串的统计字符串
  • 原文地址:https://www.cnblogs.com/Inko/p/11747894.html
Copyright © 2011-2022 走看看