zoukankan      html  css  js  c++  java
  • Luogu4705 玩游戏 分治FFT

    传送门


    \(\begin{align*} Ans_k &= \sum\limits_{i=1}^n \sum\limits_{j=1}^m (a_i + b_j)^k \\ &= \sum\limits_{i=1}^n \sum\limits_{j=1}^m \sum\limits_{p=0}^k \binom{k}{p} a_i^p b_j^{k-p} \\ &= k! \sum\limits_{p=0}^k \frac{\sum\limits_{i=1}^n a_i^p}{p!} \cdot \frac{\sum\limits_{j=1}^m b_j^{k-p}}{(k-p)!} \end{align*}\)

    最后的式子是一个卷积的形式,所以我们只需要求出多项式 \(A(x) = \sum\limits_{p=0}^T \sum\limits_{i=1}^n a_i^p x^p\) 和 \(B(x) = \sum\limits_{p=0}^T \sum\limits_{i=1}^m b_i^p x^p\) 的各项系数,再把第 \(p\) 项系数分别乘上 \(\frac{1}{p!}\),就可以通过一次卷积求出答案。

    显然的一件事情是 \(A(x)\) 和 \(B(x)\) 的求法一样,所以我们只考虑 \(A(x)\)。

    不难发现,设 \(F_i(x) = \sum\limits_{j=0}^{\infty} a_i^j x^j = \frac{1}{1-a_ix}\),那么 \(A(x) = \sum\limits_{i=1}^n F_i(x)\)。直接暴力通分复杂度显然是 \(O(n^2)\),但是不难发现这和若干个一次多项式相乘一样,可以使用分治FFT优化通分,复杂度可以变为 \(O(n\log^2 n)\)。

    但是因为分治FFT优化通分还是太慢,所以极其容易被卡常,可以考虑当分治区间小到某个值以下时暴力通分。

    多项式Ln的做法学不来qwq

    #include<bits/stdc++.h>
    //this code is written by Itst
    using namespace std;
    
    // Reads the next non-negative decimal integer from stdin.
    // Every non-digit character before the number is skipped; the first
    // character after the digit run is consumed as well.
    // Returns 0 if end of input is reached before any digit is seen, instead
    // of spinning forever on EOF. No sign handling and no overflow check.
    int read(){
        // Keep getchar()'s result as int: EOF must stay distinguishable and
        // isdigit() is only defined for unsigned-char values and EOF.
        int c = getchar();
        while(c != EOF && !isdigit(c)) c = getchar();
        int value = 0;
        while(isdigit(c)){   // isdigit(EOF) is false, so EOF ends the number
            value = value * 10 + (c - '0');
            c = getchar();
        }
        return value;
    }
    
    // _   : capacity of the global work arrays (2^18 plus a little slack).
    // MOD : the prime modulus 998244353 used for all arithmetic.
    const int _ = (1 << 18) + 7 , MOD = 998244353;
    
    // Fast modular exponentiation: returns a^b mod MOD in the range [0, MOD).
    //   a : base, any long long value (it is reduced into [0, MOD) first so the
    //       squaring below cannot overflow, and negative bases work).
    //   b : exponent; values <= 0 yield 1 (the empty product).
    int poww(long long a , int b){
        a %= MOD;
        if(a < 0) a += MOD;
        long long result = 1;
        for( ; b > 0 ; b >>= 1){
            if(b & 1) result = result * a % MOD;
            a = a * a % MOD;
        }
        return (int)result;
    }
    
    #define VI vector < int > 
    
    // Polynomial arithmetic modulo MOD built on the number-theoretic transform.
    // All functions share the state (dir, need, invnd) set up by init(len), so
    // init() must be called before NTT() and the transform length must not be
    // changed between a forward transform and its matching inverse.
    namespace poly{
        // G is used as the primitive root for the transform, INV is its
        // modular inverse (3 * 332748118 = MOD + 1).
        const int G = 3 , INV = 332748118;
        // dir   : bit-reversal permutation for the current transform length
        // need  : current transform length (a power of two)
        // invnd : need^{-1} mod MOD, used to scale the inverse transform
        int dir[_] , need , invnd;
    
        // Selects the smallest power-of-two transform length >= len and
        // rebuilds the bit-reversal table for it.
        void init(int len){
            need = 1;
            while(need < len) need <<= 1;
            assert(need <= _);   // dir[] holds only _ entries
            invnd = poww(need , MOD - 2);
            dir[0] = 0;
            for(int i = 1 ; i < need ; ++i)
                dir[i] = (dir[i >> 1] >> 1) | (i & 1 ? need >> 1 : 0);
        }
    
        // In-place transform of arr, which is first resized to `need`.
        //   tp == 1 : forward transform.
        //   else    : inverse transform; the result is scaled by invnd and
        //             trailing zero coefficients are then stripped, so the
        //             vector may end up SHORTER than `need`.
        void NTT(VI &arr , int tp){
            arr.resize(need);
            for(int i = 1 ; i < need ; ++i)
                if(i < dir[i])
                    // std::swap instead of a chained XOR swap: the chain
                    // modifies arr[i] twice without sequencing before C++17.
                    std::swap(arr[i] , arr[dir[i]]);
            for(int i = 1 ; i < need ; i <<= 1){
                // Primitive (2i)-th root of unity (or its inverse).
                int wn = poww(tp == 1 ? G : INV , MOD / i / 2);
                for(int j = 0 ; j < need ; j += i << 1){
                    long long w = 1;
                    for(int k = 0 ; k < i ; ++k , w = w * wn % MOD){
                        int x = arr[j + k] , y = arr[i + j + k] * w % MOD;
                        arr[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                        arr[i + j + k] = x < y ? x + MOD - y : x - y;
                    }
                }
            }
            if(tp != 1){
                for(int i = 0 ; i < need ; ++i)
                    arr[i] = 1ll * arr[i] * invnd % MOD;
                int len = arr.size();
                while(len && arr[len - 1] == 0) --len;
                arr.resize(len);
            }
        }
    
        // c = a * b (full product, a.size() + b.size() - 1 coefficients).
        // a and b are taken by value, so c may alias either argument.
        // An empty operand yields an empty product.
        void mul(VI a , VI b , VI &c){
            if(a.empty() || b.empty()){ c.clear(); return; }
            int l = (int)(a.size() + b.size()) - 1;
            init(l); NTT(a , 1); NTT(b , 1); c.resize(need);
            for(int i = 0 ; i < need ; ++i)
                c[i] = 1ll * a[i] * b[i] % MOD;
            NTT(c , -1); c.resize(l);
        }
    
        // Writes into b the first `len` coefficients of the power-series
        // inverse of a, by Newton iteration b <- 2b - a*b^2 with the
        // precision doubling at each recursion level.
        // Preconditions: len >= 1, a is non-empty and a[0] is invertible mod MOD.
        // On return b has max(a.size(), len) entries; entries >= len are zero.
        void getInv(VI a , VI &b , int len){
            assert(len >= 1 && !a.empty());
            if(len == 1){
                // Start from a clean vector so stale caller data cannot leak
                // into the higher coefficients used by later iterations.
                b.assign(a.size() , 0);
                b[0] = poww(a[0] , MOD - 2);
                return;
            }
            getInv(a , b , (len + 1) >> 1);
            vector < int > A = a , B = b; A.resize(len); B.resize(len);
            init(A.size() + B.size() + 1); NTT(A , 1); NTT(B , 1);
            for(int i = 0 ; i < need ; ++i)
                A[i] = 1ll * A[i] * B[i] % MOD * B[i] % MOD;
            NTT(A , -1);
            // The inverse transform strips trailing zeros; restore exactly
            // `len` coefficients before indexing A and b below.
            A.resize(len);
            if((int)b.size() < len) b.resize(len);
            for(int i = 0 ; i < len ; ++i)
                b[i] = (2ll * b[i] - A[i] + MOD) % MOD;
        }
    }
    using namespace poly;
    
    // Helper macros for solve(): midpoint of the current range [l, r] and the
    // indices of the two child nodes of node x.
    #define mid ((l + r) >> 1)
    #define lch (x << 1)
    #define rch (x << 1 | 1)
    // Per-node results of solve(): arr[x][0] holds the numerator polynomial and
    // arr[x][1] the denominator polynomial built for node x.
    vector < int > arr[_ << 2][2];
    // a, b  : the two input sequences (1-based, filled in main)
    // jc    : factorials, inv : inverse factorials (both filled by init())
    // N, M  : lengths of a and b;  T : highest power that is printed
    int a[_] , b[_] , jc[_] , inv[_] , N , M , T;
    
    void solve(int x , int l , int r , int *num){
        arr[x][0].clear(); arr[x][1].clear();
        if(r - l <= 30){
            vector < int > &A = arr[x][0] , &B = arr[x][1] , tpa , tpb;
            init((r - l + 1) * 2); A.push_back(0); B.push_back(1);
            NTT(A , 1); NTT(B , 1);
            for(int i = l ; i <= r ; ++i){
                tpa.clear(); tpb.clear();
                tpa.push_back(1); tpb.push_back(1); tpb.push_back(MOD - num[i]);
                NTT(tpa , 1); NTT(tpb , 1);
                for(int i = 0 ; i < need ; ++i){
                    A[i] = (1ll * A[i] * tpb[i] + 1ll * B[i] * tpa[i]) % MOD;
                    B[i] = 1ll * B[i] * tpb[i] % MOD;
                }
            }
            NTT(A , -1); NTT(B , -1);
        }
        else{
            solve(lch , l , mid , num); solve(rch , mid + 1 , r , num);
            int l = arr[lch][1].size() + arr[rch][1].size();
            init(l); vector < int > a , b , c , d;
            arr[x][0].resize(need); arr[x][1].resize(need);
            a = arr[lch][0]; b = arr[lch][1]; c = arr[rch][0]; d = arr[rch][1];
            NTT(a , 1); NTT(b , 1); NTT(c , 1); NTT(d , 1);
            for(int i = 0 ; i < need ; ++i){
                arr[x][0][i] = (1ll * a[i] * d[i] + 1ll * b[i] * c[i]) % MOD;
                arr[x][1][i] = 1ll * b[i] * d[i] % MOD;
            }
            NTT(arr[x][0] , -1); NTT(arr[x][1] , -1);
        }
    }
    
    // Fills jc[i] = i! mod MOD and inv[i] = (i!)^{-1} mod MOD for 0 <= i <= T.
    // Only one modular inverse is computed (for T!); the rest follow from
    // inv[k - 1] = inv[k] * k.
    void init(){
        jc[0] = 1;
        int filled = 0;
        while(filled < T){
            ++filled;
            jc[filled] = 1ll * jc[filled - 1] * filled % MOD;
        }
    
        inv[T] = poww(jc[T] , MOD - 2);
        for(int k = T ; k > 0 ; --k)
            inv[k - 1] = 1ll * inv[k] * k % MOD;
    }
    
    int main(){
    #ifndef ONLINE_JUDGE
        freopen("in","r",stdin);
        //freopen("out","w",stdout);
    #endif
        N = read(); M = read();
        for(int i = 1 ; i <= N ; ++i) a[i] = read();
        for(int i = 1 ; i <= M ; ++i) b[i] = read();
        T = read(); init();
        vector < int > tp1 , tp2;
        solve(1 , 1 , N , a); arr[1][1].resize(T + 1); getInv(arr[1][1] , tp1 , T + 1);
        mul(tp1 , arr[1][0] , tp1); tp1.resize(T + 1);
        solve(1 , 1 , M , b); arr[1][1].resize(T + 1); getInv(arr[1][1] , tp2 , T + 1);
        mul(tp2 , arr[1][0] , tp2); tp2.resize(T + 1);
        for(int i = 0 ; i <= T ; ++i){
            tp1[i] = 1ll * tp1[i] * inv[i] % MOD;
            tp2[i] = 1ll * tp2[i] * inv[i] % MOD;
        }
        mul(tp1 , tp2 , tp1);
        int Inv = poww(1ll * N * M % MOD , MOD - 2);
        for(int i = 1 ; i <= T ; ++i)
            printf("%lld
    " , 1ll * tp1[i] * jc[i] % MOD * Inv % MOD);
        return 0;
    }
    
  • 相关阅读:
    堆排序
    conda 安装pytorch
    Dev GridControl GridView常用属性
    java 同步调用和异步调用
    spring Boot 整合 Memcached (含 windows 安装)
    spring Boot 整合 Elasticsearch
    windows 下安装 elasticsearch
    代理模式---Cglib动态代理
    代理模式---JDK动态代理
    代理模式---静态代理
  • 原文地址:https://www.cnblogs.com/Itst/p/10989995.html
Copyright © 2011-2022 走看看