zoukankan      html  css  js  c++  java
  • LOJ2527 HAOI2018 染色 生成函数、二项式反演、NTT

    传送门


    调了1h竟然是因为1004535809写成了998244353

    “恰好有(K)种颜色出现了(S)次”的限制似乎并不容易达到,考虑容斥计算。

    (c_j)表示强制(j)种颜色恰好出现(S)次,其他颜色随意染的方案数。可以通过生成函数知道

    (egin{align*} c_j &= inom{m}{j} n! [x^n] (frac{x^k}{k!})^j (sumlimits_{i=0}^infty frac{x^i}{i!})^{m-j} \ &= inom{m}{j} n! [x^n] (frac{x^k}{k!})^j e^{(m-j)x} \ &= inom{m}{j} n! [x^n] (frac{x^k}{k!})^j sumlimits_{i=0}^infty frac{x^i (m-j)^i}{i!} \ &= inom{m}{j}n! frac{(m-j)^{n-jk}}{(k!)^j (n-jk)!} end{align*})

    显然是可以预处理阶乘之后(O(1))计算的。注意不要漏掉了指数型生成函数前面要乘的(n!)

    又设(h_j)表示恰好有(j)种颜色出现了(S)次的方案总数。不难发现有一个反演:(h_j = c_j - sumlimits_{i=j+1}^{Max} h_{i}A(i,j))(A(i,j))是一个与(i,j)相关的系数,表示(h_i)(c_j)中的出现次数。

    既然(h_i)中恰好有(i)种颜色出现了(S)次,那么对于任意一个对(h_i)产生贡献的状态,只要枚举到当前状态中(i)种恰好出现了(S)次的颜色构成的集合的任意一个大小为(j)的子集时都会对(c_j)产生(1)的贡献。所以(A(i,j) = inom{i}{j})

    所以可以得到(h_j = c_j - sumlimits_{i=j+1}^{Max}h_i inom{i}{j}),两边同乘(j!)得到(h_jj! = c_jj! - sumlimits_{i=j+1}^{Max}frac{h_ii!}{(i-j)!})

    设多项式(H = sumlimits_{i=0}^{Max}h_ii!x^i , C = sumlimits_{i=0}^{Max}c_ii!x^i),记(rev(H))为多项式(H)所有系数翻转过来之后的多项式,那么不难得到(rev(H) = rev(C) - W * rev(H)),其中(W = sumlimits_{i=1}^{Max} frac{1}{i!}x^i)。多项式求逆即可。

    Update:不难发现(W+1 = e^x),所以求逆的结果是(e^{-x}),所以可以不必求逆直接把(e^{-x})的系数代替求逆;实际上这个反演的过程是二项式反演的一个变体,可以通过二项式反演的方式进行NTT,实质一样。

    #include<iostream>
    #include<cstdio>
    #include<random>
    #include<cstring>
    #include<algorithm>
    //This code is written by Itst
    using namespace std;
    
    const int mod = 998244353;
    inline int read(bool flg = 0){
        int a = 0;
        char c = getchar();
        bool f = 0;
        while(!isdigit(c) && c != EOF){
            if(c == '-')
                f = 1;
            c = getchar();
        }
        if(c == EOF)
            exit(0);
        while(isdigit(c)){
            if(flg)
                a = (a * 10ll + c - 48) % mod;
            else
                a = a * 10 + c - 48;
            c = getchar();
        }
        if(flg) a += mod;
        return f ? -a : a;
    }
    
    const int MAXN = (1 << 19) + 7 , MAXM = 1e7 + 7 , MOD = 1004535809;
    #define PII pair < int , int >
    #define st first
    #define nd second
    
    inline int poww(long long a , int b){
        int times = 1;
        while(b){
            if(b & 1)
                times = times * a % MOD;
            a = a * a % MOD;
            b >>= 1;
        }
        return times;
    }
    
    namespace poly{
        const int G = 3 , INV = (MOD + 1) / G;
        int A[MAXN] , B[MAXN] , C[MAXN] , D[MAXN] , E[MAXN];
        int a[MAXN] , b[MAXN] , c[MAXN] , d[MAXN];
        int need , inv , dir[MAXN] , _inv[MAXN];
    #define clear(x) memset(x , 0 , sizeof(int) * need)
    
        void init(int len){
            need = 1;
            while(need < len)
                need <<= 1;
            inv = poww(need , MOD - 2);
            for(int i = 1 ; i < need ; ++i)
                dir[i] = (dir[i >> 1] >> 1) | (i & 1 ? need >> 1 : 0);
        }
    
        void init_inv(){
            _inv[1] = 1;
            for(int i = 2 ; i < MAXN ; ++i)
                _inv[i] = MOD - 1ll * (MOD / i) * _inv[MOD % i] % MOD;
        }
    
        void NTT(int *arr , int type){
            for(int i = 1 ; i < need ; ++i)
                if(i < dir[i])
                    arr[i] ^= arr[dir[i]] ^= arr[i] ^= arr[dir[i]];
            for(int i = 1 ; i < need ; i <<= 1){
                int wn = poww(type == 1 ? G : INV , (MOD - 1) / i / 2);
                for(int j = 0 ; j < need ; j += i << 1){
                    long long w = 1;
                    for(int k = 0 ; k < i ; ++k , w = w * wn % MOD){
                        int x = arr[j + k] , y = arr[i + j + k] * w % MOD;
                        arr[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                        arr[i + j + k] = x < y ? x + MOD - y : x - y;
                    }
                }
            }
        }
    
        void mul(int *a , int *b){
            NTT(a , 1);NTT(b , 1);
            for(int i = 0 ; i < need ; ++i)
                a[i] = 1ll * a[i] * b[i] % MOD;
            NTT(a , -1);
        }
    
        void getInv(int *a , int *b , int len){
            if(len == 1){
                b[0] = poww(a[0] , MOD - 2);
                return;
            }
            getInv(a , b , (len + 1) >> 1);
            memcpy(A , a , sizeof(int) * len);
            memcpy(B , b , sizeof(int) * len);
            init(len * 3);
            NTT(A , 1);NTT(B , 1);
            for(int i = 0 ; i < need ; ++i)
                A[i] = 1ll * A[i] * B[i] % MOD * B[i] % MOD;
            NTT(A , -1);
            for(int i = 0 ; i < len ; ++i)
                b[i] = (2 * b[i] - 1ll * A[i] * inv % MOD + MOD) % MOD;
            clear(A);clear(B);
        }
    }
    using namespace poly;
    int F[MAXN] , H[MAXN] , jc[MAXM] , Inv[MAXM] , W[MAXN];
    int N , M , K , Len;
    
    void init(){
        jc[0] = 1;
        for(int i = 1 ; i <= N || i <= M ; ++i)
            jc[i] = 1ll * jc[i - 1] * i % MOD;
        Inv[max(N , M)] = poww(jc[max(N , M)] , MOD - 2);
        for(int i = max(N , M) - 1 ; i >= 0 ; --i)
            Inv[i] = Inv[i + 1] * (i + 1ll) % MOD;
    }
    
    int binom(int b , int a){
        return b < a ? 0 : 1ll * jc[b] * Inv[a] % MOD * Inv[b - a] % MOD;
    }
    
    int calc(int j){
        return 1ll * poww(Inv[K] , j) * Inv[N - j * K] % MOD * poww(M - j , N - j * K) % MOD * binom(M , j) % MOD * jc[N] % MOD;
    }
    
    int main(){
    #ifndef ONLINE_JUDGE
        freopen("in","r",stdin);
        //freopen("out","w",stdout);
    #endif
        init_inv();
        N = read(); M = read(); K = read();
        for(int i = 0 ; i <= M ; ++i) W[i] = read();
        Len = min(N / K , M);
        init();
        for(int i = 0 ; i <= Len ; ++i)
            F[i] = 1ll * calc(i) * jc[i] % MOD;
        reverse(F , F + Len + 1);
        for(int i = 0 ; i <= Len ; ++i)
            H[i] = Inv[i];
        getInv(H , a , Len + 1);
        init((Len + 1) * 2);
        mul(F , a);
        reverse(F , F + Len + 1);
        int ans = 0;
        for(int i = 0 ; i <= Len ; ++i)
            ans = (ans + 1ll * F[i] * inv % MOD * Inv[i] % MOD * W[i]) % MOD;
        cout << ans;
        return 0;
    }
    
  • 相关阅读:
    PHP 获取某年第几周的开始日期和结束日期的实例
    PHP科学计数法转换成数字
    laravel 辅助函数
    laravel5.3之后可以使用withCount()这个方法
    laravel 5.1 Model 属性详解
    laravel的启动过程解析
    转:按需加载html 图片 css js
    移动平台WEB前端开发技巧汇总(转)
    php重定向页面的三种方式
    zepto API参考(~~比较全面)
  • 原文地址:https://www.cnblogs.com/Itst/p/10548463.html
Copyright © 2011-2022 走看看