zoukankan      html  css  js  c++  java
  • 题解 洛谷 P4705 【玩游戏】

    先化简答案的式子:

    [ans_k=frac{1}{nm}sum_{i=1}^nsum_{j=1}^m(a_i+b_j)^k =frac{1}{nm}sum_{i=1}^nsum_{j=1}^msum_{l=0}^kinom{k}{l}a_i^lb_j^{k-l} =frac{k!}{nm}sum_{l=0}^kleft (sum_{i=1}^nfrac{a_i^l}{l!} ight )left (sum_{j=1}^mfrac{b_j^{k-l}}{(k-l)!} ight ) ]

    发现是卷积的形式,那么只要能快速计算 (s_k=sumlimits_{i=1}^n a_i^k),然后进行卷积计算答案即可。

    构造生成函数 (f(x)=sumlimits_{i geqslant 0} s_ix^i),进行化简:

    [f(x)=sum_{i geqslant 0} s_ix^i =sum_{i geqslant 0} sum_{j=0}^{n-1} a_j^ix^i =sum_{j=0}^{n-1}frac{1}{1-a_jx} ]

    注意到:

    [{(ln (1-a_ix))}'= frac{-a_i}{1-a_ix} ]

    代入得:

    [f(x)=sum_{i=0}^{n-1}frac{1}{1-a_ix} =sum_{i=0}^{n-1}1-x{(ln (1-a_ix))}' =n-x{left (sum_{i=0}^{n-1}ln (1-a_ix) ight )}' =n-xleft (lnprod_{i=0}^{n-1} (1-a_ix) ight )' ]

    对于 (prodlimits_{i=0}^{n-1} (1-a_ix)),可以通过分治来求解,即先算出左右两半后乘起来。

    最后通过多项式对数函数即可求解答案。

    (code:)

    #include<bits/stdc++.h>
    #define maxn 800010
    #define P 998244353
    #define G 3
    using namespace std;
    typedef long long ll;
    template<typename T> inline void read(T &x)
    {
        x=0;char c=getchar();bool flag=false;
        while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
        while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
        if(flag)x=-x;
    }
    ll n,m,k,all;
    ll rev[maxn],a[maxn],b[maxn],f[maxn],g[maxn],fac[maxn],ifac[maxn];
    ll qp(ll x,ll y)
    {
        ll v=1;
        while(y)
        {
            if(y&1) v=v*x%P;
            x=x*x%P,y>>=1;
        }
        return v;
    }
    int calc(int n)
    {
        int lim=1;
        while(lim<=n) lim<<=1;
        for(int i=0;i<lim;++i)
            rev[i]=(rev[i>>1]>>1)|((i&1)?lim>>1:0);
        return lim;
    }
    void NTT(ll *a,int lim,int type)
    {
        for(int i=0;i<lim;++i)
            if(i<rev[i])
                swap(a[i],a[rev[i]]);
        for(int len=1;len<lim;len<<=1)
        {
            ll wn=qp(G,(P-1)/(len<<1));
            for(int i=0;i<lim;i+=len<<1)
            {
                ll w=1;
                for(int j=i;j<i+len;++j,w=w*wn%P)
                {
                    ll x=a[j],y=w*a[j+len]%P;
                    a[j]=(x+y)%P,a[j+len]=(x-y+P)%P;
                }
            }
        }
        if(type==1) return;
        ll inv=qp(lim,P-2);
        for(int i=0;i<lim;++i) a[i]=a[i]*inv%P;
        reverse(a+1,a+lim);
    }
    void Inv(int deg,ll *a,ll *b)
    {
        static ll t[maxn];
        if(deg==1)
        {
            b[0]=qp(a[0],P-2);
            return;
        }
        Inv((deg+1)>>1,a,b);
        int lim=calc(deg<<1);
        for(int i=0;i<deg;++i) t[i]=a[i];
        for(int i=deg;i<lim;++i) t[i]=b[i]=0;
        NTT(t,lim,1),NTT(b,lim,1);
        for(int i=0;i<lim;++i)
            b[i]=b[i]*(2-t[i]*b[i]%P+P)%P;
        NTT(b,lim,-1);
        for(int i=deg;i<lim;++i) b[i]=0;
    }
    void Ln(int deg,ll *a,ll *b)
    {
        static ll inva[maxn],dera[maxn];
        Inv(deg,a,inva);
        for(int i=0;i<deg-1;++i) dera[i]=a[i+1]*(i+1)%P;
        dera[deg-1]=0;
        int lim=calc(deg<<1);
        for(int i=deg;i<lim;++i) dera[i]=inva[i]=0;
        NTT(dera,lim,1),NTT(inva,lim,1);
        for(int i=0;i<lim;++i) b[i]=dera[i]*inva[i]%P;
        NTT(b,lim,-1);
        for(int i=deg-1;i>=1;--i) b[i]=b[i-1]*qp(i,P-2)%P;
        b[0]=0;
        for(int i=deg;i<lim;++i) b[i]=0;
    }
    void solve(int l,int r,ll *a,ll *b)
    {
        if(l==r)
        {
            b[0]=1,b[1]=P-a[l];
            return;
        }
        int mid=(l+r)>>1;
        ll a1[maxn],a2[maxn];
        solve(l,mid,a,a1),solve(mid+1,r,a,a2);
        int lim=calc(r-l+1);
        for(int i=mid-l+2;i<lim;++i) a1[i]=0;
        for(int i=r-mid+1;i<lim;++i) a2[i]=0;
        NTT(a1,lim,1),NTT(a2,lim,1);
        for(int i=0;i<lim;++i) b[i]=a1[i]*a2[i]%P;
        NTT(b,lim,-1);
        for(int i=r-l+2;i<lim;++i) b[i]=0;
    }
    void get(int n,ll *a,ll *b)
    {
        static ll t[maxn];
        solve(1,calc(all),a,t),Ln(all+1,t,b);
        for(int i=0;i<=all;++i) b[i]=b[i+1]*(i+1)%P;
        b[all+1]=0;
        for(int i=all;i;--i) b[i]=P-b[i-1];
        b[0]=n;
        for(int i=0;i<=all;++i) b[i]=b[i]*ifac[i]%P;
    }
    void init()
    {
        fac[0]=ifac[0]=1;
        for(int i=1;i<=all;++i) fac[i]=fac[i-1]*i%P;
        ifac[all]=qp(fac[all],P-2);
        for(int i=all-1;i;--i) ifac[i]=ifac[i+1]*(i+1)%P;
    }
    int main()
    {
        read(n),read(m);
        for(int i=1;i<=n;++i) read(a[i]);
        for(int i=1;i<=m;++i) read(b[i]);
        read(k),all=max(k,max(n,m)),init();
        get(n,a,f),get(m,b,g);
        int lim=calc(all<<1);
        NTT(f,lim,1),NTT(g,lim,1);
        for(int i=0;i<lim;++i) f[i]=f[i]*g[i]%P;
        NTT(f,lim,-1);
        ll inv=qp(n*m%P,P-2);
        for(int i=1;i<=k;++i) printf("%lld
    ",fac[i]*inv%P*f[i]%P);
        return 0;
    }
    
  • 相关阅读:
    msp430入门编程41
    msp430入门编程40
    msp430入门编程37
    msp430入门编程36
    msp430入门编程35
    msp430入门编程34
    msp430入门编程33
    msp430入门编程31
    msp430入门编程32
    msp430入门编程30
  • 原文地址:https://www.cnblogs.com/lhm-/p/13442295.html
Copyright © 2011-2022 走看看