zoukankan      html  css  js  c++  java
  • P4723 【模板】常系数齐次线性递推 题解

    祭奠我逝去的一下午加一晚上=-=

    Description

    Luogu传送门

    顺便 sto \(Zhang\_RQ\)学长 orz

    Solution

    这名字听起来就很高大上的样子(事实上确实如此

    好吧其实我并没有打算推式子,因为 \(BJpers2\) 巨佬在他的题解中已经把式子推的很明白了,只是他的代码着实有些毒瘤,因此我这里只是想放下我自己的代码罢了。

    简单说两句,推出特征多项式 \(p(x)\) 之后,就是要求 \(x^n \ \ mod \ \ p(x)\),然而我们这个 \(n\)\(10^9\),没办法直接放到多项式里算,所以采用快速幂的思想。

    快速幂过程

    设最开始的多项式 \(t(x) = 1\),倍增往上跳,最大情况下 \(t\) 会是一个 \(k - 1\) 次的多项式乘一下就会变成 \(2 \times k - 2\) 次,然后去对 \(k\) 次的多项式 \(p(x)\) 取模。

    次数体现在代码里的话,就是 \(Mod\) 函数中传的实参是 \(n << 1\)

    坑点

    1. 数组最好都开成局部变量,不然就各种错乱(我一开始用的全局变量数组就一直都是 0).
    2. 边界!边界!边界!好吧,说实话这玩意就算知道有坑点也没啥用,就算让我再写一遍可能也得调半天。

    废话不多说,上代码吧,希望对您有帮助。

    Code(大常数警告)

    #include <bits/stdc++.h>
    #define ll long long
    
    using namespace std;
    
    namespace IO{
        inline ll read(){
            ll x = 0, f = 1;
            char ch = getchar();
            while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
            while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
            return x * f;
        }
    
        template <typename T> inline void write(T x){
            if(x < 0) putchar('-'), x = -x;
            if(x > 9) write(x / 10);
            putchar(x % 10 + '0');
        }
    
        inline void print(ll a[], ll n){
            for(int i = 0; i <= n; ++i) printf("%lld ", a[i]);
            puts("");
        }
    }
    using namespace IO;
    
    const ll N = 5e5 + 10;
    const ll mod = 998244353;
    const ll G = 3, Gi = 332748118;
    ll n, m, k;
    ll a[N], b[N], c[N], d[N], e[N], p[N], res[N], t[N];
    ll f[N], g[N], ig[N], q[N], r[N];
    
    namespace NTT{
        ll lim, len;
    
        inline ll qpow(ll a, ll b){
            ll res = 1;
            while(b){
                if(b & 1) res = res * a % mod;
                a = a * a % mod, b >>= 1;
            }
            return res;
        }
    
        inline void get_rev(ll n){
            lim = 1, len = 0;
            while(lim <= n) lim <<= 1, ++len;
            for(int i = 0; i <= lim; ++i) p[i] = (p[i >> 1] >> 1) | ((i & 1) << (len - 1));
        }
    
        inline void ntt(ll A[], ll lim, ll type){
            for(int i = 0; i <= lim; ++i)
                if(i < p[i]) swap(A[i], A[p[i]]);
            for(int mid = 1; mid < lim; mid <<= 1){
                ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
                for(int i = 0; i < lim; i += (mid << 1)){
                    ll w = 1;
                    for(int j = 0; j < mid; ++j, w = w * Wn % mod){
                        ll x = A[i + j], y = w * A[i + j + mid] % mod;
                        A[i + j] = (x + y) % mod;
                        A[i + j + mid] = (x - y + mod) % mod;
                    }
                }
            }
            if(type == 1) return;
            ll inv = qpow(lim, mod - 2);
            for(int i = 0; i <= lim; ++i) A[i] = A[i] * inv % mod;
        }
    
        inline void Mul(ll n, ll m, ll a[], ll b[], bool flag = 1){
            static ll d[N], e[N];
            for(int i = 0; i < (n << 2); ++i) d[i] = e[i] = 0;
            for(int i = 0; i < n; ++i) d[i] = a[i], e[i] = b[i];
            get_rev(n + m);
            ntt(d, lim, 1), ntt(e, lim, 1);
            for(int i = 0; i < lim; ++i) d[i] = d[i] * e[i] % mod;
            ntt(d, lim, -1);
            for(int i = 0; i < (n << 1); ++i) a[i] = d[i];
            for(int i = (n << 1); i <= lim; ++i) a[i] = 0;
            if(flag) for(int i = n; i < (n << 1); ++i) a[i] = 0;
    
        }
    
        inline void Inv(ll n, ll a[], ll b[]){
            if(!n) return b[0] = qpow(a[0], mod - 2), void();
            Inv(n >> 1, a, b);
            get_rev(n << 1);
            for(int i = 0; i <= n; ++i) c[i] = a[i];
            for(int i = n + 1; i <= lim; ++i) c[i] = 0;
            ntt(c, lim, 1), ntt(b, lim, 1);
            for(int i = 0; i <  lim; ++i) b[i] = (2ll - c[i] * b[i] % mod + mod) * b[i] % mod;
            ntt(b, lim, -1);
            for(int i = n + 1; i <= lim; ++i) b[i] = 0;
        }
    }
    using namespace NTT;
    
    inline void Mod(ll n, ll m, ll f[], ll g[], ll r[]){
        static ll a[N], b[N];
        for(int i = 0; i < (n << 2); ++i) a[i] = b[i] = d[i] = 0;
        for(int i = 0; i < n - m + 1; ++i) a[i] = f[n - i - 1];
        for(int i = 0; i < n - m + 1; ++i) b[i] = g[m - i - 1];
    
        Inv(n - m + 1, b, d);
        Mul(n - m + 1, n - m + 1, a, d);
        for(int i = 0; i <= n - m; ++i) q[i] = a[n - m - i];
    
        for(int i = 0; i < (n << 2); ++i) a[i] = b[i] = 0;
        for(int i = 0; i < n; ++i) a[i] = f[i];
        for(int i = 0; i < m; ++i) b[i] = g[i];
        Mul(n, n, b, q);
    
        for(int i = 0; i < m - 1; ++i) r[i] = (a[i] - b[i] + mod) % mod;
        for(int i = m - 1; i < lim; ++i) r[i] = 0;
    }
    
    inline void solve(ll p, ll n){
        t[1] = res[0] = 1;
        while(p){
            if(p & 1) Mul(n, n, res, t, 0), Mod(n << 1, n, res, g, res);// b % g --> b
            Mul(n, n, t, t, 0), Mod(n << 1, n, t, g, t);
            p >>= 1;
        }
    }
    
    signed main(){
        // freopen("P4723.in", "r", stdin);
        // freopen("P4723.out", "w", stdout);
        n = read(), m = read();
        g[0] = 1;
        for(int i = 1; i <= m; ++i) g[i] = (mod - (read() % mod + mod) % mod);
        reverse(g, g + 1 + m);
        for(int i = 0; i < m; ++i) f[i] = read();
        solve(n, m + 1);
        ll ans = 0;
        for(int i = 0; i < m; ++i) ans = (ans + res[i] * f[i] % mod + mod) % mod;
        write(ans), puts("");
        return 0;
    }
    

    \[\_EOF\_ \]

  • 相关阅读:
    三个习题
    20 python--celery
    19 python --队列
    18 python --多线程
    17 python --多进程
    16 python --memcached
    15 python --redis
    14 python --mysql
    13 python --正则
    12 python --json
  • 原文地址:https://www.cnblogs.com/xixike/p/15626626.html
Copyright © 2011-2022 走看看