zoukankan      html  css  js  c++  java
  • LOJ2320「清华集训 2017」生成树计数

    由于菜鸡的我实在是没学会上升幂下降幂那一套理论,这里用的是完全普通多项式的做法。

    要是有大佬愿意给我讲讲上升幂下降幂那一套东西,不胜感激orz!

    首先可以想到prufer序列,如果不会的话可以左转百度。

    我们把答案写成prufer序列的形式,这样的话他的贡献是固定的,其中$d_i$表示$i$的出现次数。

    $ans = (n-2)! prod_{i=1}^n frac{a_i^{d_i+1}}{d_i!}$

    把原来的柿子带进来

    $ans = sum_{sum d_i = n-2} (n-2)! sum_{i=1}^n a_i^{d_i+1} d_i^mprod_{j=1}frac{d_j^m}{d_j!}$

    把相同的因数提出来

    $ans =(n-2)! prod_{i=1}^n a_i sum_{sum d_i = n-2} sum_{i=1}^n a_i^{d_i} d_i^mprod_{j=1}frac{d_j^m}{d_j!}$

    考虑后面的

    $ans' = sum_{sum d_i = n-2} sum_{i=1}^n a_i^{d_i} d_i^mprod_{j=1}frac{d_j^m}{d_j!}$

    为了方便写成生成函数形式,把柿子化成这样

    $ans' = sum_{sum d_i = n-2} sum_{i=1}^n frac{a_i^{d_i} d_i^{2m}}{d_i!}prod_{j=1,j eq i}frac{d_j^m}{d_j!} $

    $A(x)=sum_{i=0}^n frac{i^{2m}}{i!} x^i$

    $B(x)=sum_{i=0}^n frac{i^m}{i!} x^i$

    把它们带进去

    $ans '=sum_{i=1}^n A(a_i) prod_{j=1,j eq i}^n B(a_j) \ = sum_{i=1}^n frac{A}{B}(a_i) prod_{j=1}^n B(a_j)$

    对于$prod B(a_i)$ 根据常见套路,写成$exp(sum ln(B(a_i)))$ 当然$frac{A}{B}$也是这么处理。

    接下来我们就遇到了更常见的,求$sum_{i=1}^{n} a_i^k$ 也就是求数列幂和。

    这个东西可以戳这里:qwq (额 需要密码,如果你需要的话可以戳我QQ)

    这里简略的说一下,就是我们有$F(x)=sum(a_1^i+...+a_n^i)x^i$ 可以化成这样$F(x) = sum (a_ix+...+a_i^nx^n)$ 进而写成这样$F(x) = sum frac{a_i}{1-a_ix}$

    我们令$G(x) = prod (1-a_ix)$ 所以有$F = -ln'(G)$

    对于$G$来说,我们可以直接分治FFT,然后F就可以poly_ln来求

    其余的就是写一个多项式全家桶就可以了。

    啊,多项式真有趣。

    //Love and Freedom.
    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    #include<cmath>
    #define ll long long
    #define inf 20021225
    #define mdn 998244353
    #define G 3
    #define N 400100
    using namespace std;
    int read()
    {
        int s=0,f=1; char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();}
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return f*s;
    }
    int fac[N],inv[N];
    int Inv(int x){return 1ll*inv[x]*fac[x-1]%mdn;}
    void upd(int &x,int y){x+=x+y>=mdn?y-mdn:y;}
    int ksm(int bs,int mi)
    {
        int ans=1;
        while(mi)
        {
            if(mi&1)    ans=1ll*ans*bs%mdn;
            bs=1ll*bs*bs%mdn; mi>>=1;
        }
        return ans;
    }
    int r[N];
    int init(int n)
    {
        int l=0,lim=1;
        while(lim<n)    lim<<=1,l++;
        for(int i=0;i<lim;i++)
            r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
        return lim;
    }
    void ntt(int *a,int lim,int f)
    {
        for(int i=0;i<lim;i++)    if(r[i]>i)
            swap(a[r[i]],a[i]);
        for(int k=2,mid=1;k<=lim;k<<=1,mid<<=1)
        {
            int Wn=ksm(G,(mdn-1)/k); if(f)    Wn=ksm(Wn,mdn-2);
            for(int w=1,i=0;i<lim;i+=k,w=1)    for(int j=0;j<mid;j++,w=1ll*w*Wn%mdn)
            {
                int x=a[i+j],y=1ll*w*a[i+mid+j]%mdn;
                a[i+j]=(x+y)%mdn; a[i+mid+j]=(mdn+x-y)%mdn;
            }
        }
        if(f)    for(int kinv=ksm(lim,mdn-2),i=0;i<lim;i++)
            a[i]=1ll*a[i]*kinv%mdn;
    }
    int poly_mul(int *a,int *b,int *c,int n)
    {
        int lim=init(n<<1);
        ntt(a,lim,0); ntt(b,lim,0);
        for(int i=0;i<lim;i++)    c[i]=1ll*a[i]*b[i]%mdn;
        ntt(c,lim,1); return lim;
    }
    int tmp[N];
    void poly_inv(int *a,int *ans,int n)
    {
        if(n==1){ans[0]=ksm(a[0],mdn-2); return;}
        int mid=n+1>>1; poly_inv(a,ans,mid);
        int lim=init(n<<1);
        for(int i=0;i<n;i++)    tmp[i]=a[i];
        for(int i=n;i<lim;i++)    tmp[i]=0;
        ntt(tmp,lim,0); ntt(ans,lim,0);
        for(int i=0;i<lim;i++)
            ans[i]=(2ll-1ll*tmp[i]*ans[i]%mdn+mdn)*ans[i]%mdn;
        ntt(ans,lim,1);
        for(int i=n;i<lim;i++)    ans[i]=0;
    }
    int tmp2[N];
    void poly_ln(int *a,int *ans,int n)
    {
        poly_inv(a,tmp2,n);
        for(int i=0;i<n;i++)    ans[i]=1ll*a[i+1]*(i+1)%mdn;
        int lim=poly_mul(tmp2,ans,tmp2,n); tmp2[n-1]=0;
        for(int i=1;i<n;i++)    ans[i]=1ll*tmp2[i-1]*Inv(i)%mdn,tmp2[i-1]=0;
        for(int i=n;i<lim;i++)    ans[i]=tmp2[i]=0;
        ans[0]=0;
    }
    int tmp3[N],tmp4[N];
    void poly_exp(int *a,int *ans,int n)
    {
        if(n==1){ans[0]=1; return;}
        int mid=n+1>>1; poly_exp(a,ans,mid);
        for(int i=0;i<(n<<1);i++)    tmp2[i]=tmp3[i]=0;
        for(int i=0;i<n;i++)    tmp4[i]=a[i]; poly_ln(ans,tmp3,n);
        int lim=init(n<<1);
        ntt(tmp3,lim,0); ntt(tmp4,lim,0); ntt(ans,lim,0);
        for(int i=0;i<lim;i++)
            ans[i]=1ll*(tmp4[i]-tmp3[i]+1+mdn)%mdn*ans[i]%mdn;
        ntt(ans,lim,1);
        for(int i=n;i<lim;i++)    ans[i]=tmp4[i]=0;
    }
    int tmp5[N],tmp6[N];
    void solve(int *a,int l,int r)
    {
        if(l==r)    return;
        int mid=l+r>>1;
        solve(a,l,mid); solve(a,mid+1,r);
        int d=r-l+1;
        for(int i=l;i<=mid;i++)    tmp5[i-l+1]=a[i];
        for(int i=mid+1;i<=r;i++)    tmp6[i-mid]=a[i];
        tmp5[0]=tmp6[0]=1;
        int lim=poly_mul(tmp5,tmp6,tmp5,d);
        for(int i=0;i<d;i++)    a[l+i]=tmp5[i+1];
        for(int i=0;i<=lim;i++)    tmp5[i]=tmp6[i]=0;
    }
    void prework(int n)
    {
        n++; fac[0]=1;
        for(int i=1;i<=n;i++)    fac[i]=1ll*fac[i-1]*i%mdn;
        inv[n]=ksm(fac[n],mdn-2);
        for(int i=n;i;i--)    inv[i-1]=1ll*inv[i]*i%mdn;
    }
    int a[N],f[N],n,m,g[N],A[N],B[N],in[N],fr[N],ls[N];
    int main()
    {
        n=read(),m=read(); prework(n);
        for(int i=1;i<=n;i++)    a[i]=read(),f[i]=(mdn-a[i])%mdn;
        solve(f,1,n); f[0]=1; poly_ln(f,g,n+1);
        for(int i=0;i<=n;i++)    f[i]=1ll*g[i+1]*(i+1)%mdn;
        for(int i=n;i;i--)    f[i]=(-f[i-1]+mdn)%mdn; f[0]=n;
        for(int i=0;i<=n;i++)
        {
            A[i]=1ll*ksm(i+1,2*m)*inv[i]%mdn;
            B[i]=1ll*ksm(i+1,m)*inv[i]%mdn;
        }
        int lim=init(n<<1|1);
        poly_ln(B,g,n+1); poly_inv(B,in,n+1);
        for(int i=0;i<n;i++)    g[i]=1ll*g[i]*f[i]%mdn;
        for(int i=n;i<lim;i++)    g[i]=in[i]=0;
        poly_exp(g,ls,n);
        for(int i=n;i<lim;i++)    ls[i]=0;
        poly_mul(A,in,fr,n);
        for(int i=0;i<n;i++)    fr[i]=1ll*fr[i]*f[i]%mdn;
        for(int i=n;i<lim;i++)    fr[i]=0;
        poly_mul(fr,ls,f,n);
        int ans=fac[n-2];
        for(int i=1;i<=n;i++)    ans=1ll*ans*a[i]%mdn;
        ans=1ll*ans*f[n-2]%mdn;
        printf("%d
    ",ans);
        return 0;
    }
    View Code
  • 相关阅读:
    创建nodejs服务器
    研磨设计模式学习笔记2外观模式Facade
    研磨设计模式学习笔记4单例模式Signleton
    研磨设计模式学习笔记1简单工厂(SimpleFactory)
    getResourceAsStream小结
    研磨设计模式学习笔记3适配器模式Adapter
    oracle数据库代码块
    DecimalFormat
    .NET中常用的代码(转载)
    WebClient的研究笔记
  • 原文地址:https://www.cnblogs.com/hanyuweining/p/11594258.html
Copyright © 2011-2022 走看看