zoukankan      html  css  js  c++  java
  • [EOJ629] 两开花

    Description

    给定一棵以 (1) 为根 (n) 个节点的树。

    定义 (f(k)) :从树上等概率随机选出 (k) 个节点,这 (k) 个点的虚树大小的期望。

    一个点 (x) 在这些被选出的 (k) 个点的虚树上,当且仅当它满足下列条件至少一个:

    • (x) 被选出。
    • 存在两个被选出的节点 (a,b),使得 (operatorname{lca}(a,b)=x)

    给定 (m),求 (f(1),f(2),cdots,f(m))。 对 (998244353) 取模。(nleq 4cdot 10^5)

    Sol

    又是套着期望皮的计数题。

    对于每个点 (i) 求出有多少种方案对答案有贡献即可:

    • (i) 被选出,总方案数为 (C(n-1,k-1))
    • (i) 至少两个儿子的子树中存在被选出的点。

    第二种不太好算,考虑用总方案数减去不合法的方案数。

    总方案数就是 (C(n-1,k))

    如果点 (i) 的子树中没有被选中的,方案数为 (C(n-sze[i],k))

    只有一个儿子的子树中有被选中的,可以枚举儿子 (j),方案数就是 (sumlimits_{j} C(n-sze[i]+sze[j],k))

    注意到这样的话,(i) 子树中没有被选中的方案数被多算了 儿子个数次,所以还需要加上 (son[i] imes C(n-sze[i],k))

    所以

    [f(k)=sumlimits_{i=1}^n C_{n-1}^{k-1}+C_{n-1}^k+(son[i]-1) imes C_{n-sze[i]}^k-sum_j C_{n-sze[i]+sze[j]}^k ]

    [f(k)=sumlimits_{i=1}^n C_{n}^{k}+(son[i]-1) imes C_{n-sze[i]}^k-sum_j C_{n-sze[i]+sze[j]}^k ]

    如何对于每个 (k) 快速求呢?

    观察到式子中的每一项组合数的上标都是 (k),所以我们可以开个桶 (buc[i]),在形如 (buc[n-sze[i]]) 的地方加上 (son[i]+1),在 (buc[n-sze[i]+sze[j]])(-1)

    好处就是,再推一步式子:

    [f(k)=sum_{i=0}^n buc[i]cdot C_i^k ]

    这就是个卷积的形式,(mathbf{NTT})优化就吼了。

    Code

    #pragma GCC optimize(2)
    #include<bits/stdc++.h>
    using std::min;
    using std::max;
    using std::swap;
    using std::vector;
    typedef double db;
    typedef long long ll;
    #define pb(A) push_back(A)
    #define pii std::pair<int,int>
    #define all(A) A.begin(),A.end()
    #define mp(A,B) std::make_pair(A,B)
    const int N=2e6+5;
    const int mod=998244353;
    
    int son[N],sze[N],buc[N];
    int n,m,cnt,head[N],fac[N];
    int a[N],b[N],lim,rev[N],ifac[N];
    
    struct Edge{
        int to,nxt;
    }edge[N<<1];
    
    void add(int x,int y){
        edge[++cnt].to=y;
        edge[cnt].nxt=head[x];
        head[x]=cnt;
    }
    
    int ksm(int a,int b=mod-2,int ans=1){
        while(b){
            if(b&1) ans=1ll*ans*a%mod;
            a=1ll*a*a%mod;b>>=1;
        } return ans;
    }
    
    void ntt(int *f,int g){
        for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
        for(int mid=1;mid<lim;mid<<=1){
            int tmp=ksm(g,(mod-1)/(mid<<1));
            for(int R=mid<<1,j=0;j<lim;j+=R){
                for(int w=1,k=0;k<mid;k++,w=1ll*w*tmp%mod){
                    int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                    f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
                }
            }
        } if(g>3)
            for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
    }
    
    int getint(){
        int X=0,w=0;char ch=getchar();
        while(!isdigit(ch))w|=ch=='-',ch=getchar();
        while( isdigit(ch))X=X*10+ch-48,ch=getchar();
        if(w) return -X;return X;
    }
    
    void init(int n){
        fac[0]=ifac[0]=1;
        for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
        ifac[n]=ksm(fac[n]);
        for(int i=n-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
    }
    
    void dfs(int now,int fa=0){
        sze[now]=1; int tot=0; buc[n]++;
        for(int i=head[now];i;i=edge[i].nxt){
            int to=edge[i].to;
            if(sze[to]) continue;
            tot++; dfs(to,now);
            sze[now]+=sze[to];
        }
        for(int i=head[now];i;i=edge[i].nxt){
            int to=edge[i].to;
            if(to==fa) continue;
            (buc[n-sze[now]+sze[to]]+=mod-1)%=mod;
        } (buc[n-sze[now]]+=tot-1+mod)%=mod;
    }
    
    int C(int n,int m){
        if(n<m) return 0;
        return 1ll*ifac[n]*fac[m]%mod*fac[n-m]%mod;
    }
    
    signed main(){
        n=getint(),m=getint(),init(N-5);
        for(int i=1;i<n;i++){
            int x=getint(),y=getint();
            add(x,y),add(y,x);
        } dfs(1);
        lim=1;while(lim<=n+n) lim<<=1;
        for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
        for(int i=0;i<=n;i++)
            a[n-i]=1ll*buc[i]*fac[i]%mod,
            b[i]=ifac[i];
        ntt(a,3),ntt(b,3);
        for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
        ntt(a,(mod+1)/3);
        for(int i=1;i<=m;i++) 
            printf("%lld
    ",1ll*a[n-i]*ifac[i]%mod*C(n,i)%mod);
        return 0;
    }
    
    
  • 相关阅读:
    计算十位数以内的数的反数
    用Python做一个简单的小游戏
    Python的发展历史及其前景
    监控相关总结
    前端css学习_Day15
    常用命令总结
    mysql常用命令总结
    Python之Paramiko、前端之html学习_Day14
    Python操作redis、memcache和ORM框架_Day13
    Python连接msyql、redis学习_Day12
  • 原文地址:https://www.cnblogs.com/YoungNeal/p/10363101.html
Copyright © 2011-2022 走看看