zoukankan      html  css  js  c++  java
  • Loj 2320.「清华集训 2017」生成树计数

    Loj 2320.「清华集训 2017」生成树计数

    题目描述

    在一个 (s) 个点的图中,存在 (s-n) 条边,使图中形成了 (n) 个连通块,第 (i) 个连通块中有 (a_i) 个点。

    现在我们需要再连接 (n-1) 条边,使该图变成一棵树。对一种连边方案,设原图中第 (i) 个连通块连出了 (d_i) 条边,那么这棵树 (T) 的价值为:

    [mathrm{val}(T) = left(prod_{i=1}^{n} {d_i}^m ight)left(sum_{i=1}^{n} {d_i}^m ight) ]

    你的任务是求出所有可能的生成树的价值之和,对 (998244353) 取模。

    输入格式

    输入的第一行包含两个整数 (n,m),意义见题目描述。

    接下来一行有 (n) 个整数,第 (i) 个整数表示 (a_i) ((1le a_i< 998244353))

    * 你可以由 (a_i) 计算出图的总点数 (s),所以在输入中不再给出 (s) 的值。

    输出格式

    输出包含一行一个整数,表示答案。

    数据范围与提示

    本题共有 (20) 个测试点,每个测试点 (5) 分。

    - (20\%) 的数据中,(nle500)

    - 另外 (20\%) 的数据中,(n le 3000)

    - 另外 (10\%) 的数据中,(n le 10010, m = 1)

    - 另外 (10\%)的数据中,(n le 10015,m = 2)

    -另外 (20\%) 的数据中,所有 (a_i) 相等。

    (\)

    好神的题啊!

    假设我们知道了每个点的度数,考虑计算此时的生成树的个数。这个用(prufer)序列非常好解决:

    假设第(i)个点在(prufer)序列中出现次数为(d_i),(则其度数为(d_i+1)

    [Ans=(n-2)!prod_{i=1}^nfrac{{a_i}^{d_i+1}}{d_i!} ]

    先考虑对式子进行变形

    [egin{align} Ans&= sum_{sum d_i==n-2} (n-2)! sum_{i=1}^nfrac{{{a_i}^{d_i+1}d_i}^{2m}}{d_i!} prod_{j=1,j eq i}^nfrac{{d_j}^m}{d_j!}\ &=(n-2)!prod_{i=1}^na_i sum_{sum_{d_i==n-2}}sum_{i=1}^nfrac{{{a_i}^{d_i}d_i}^{2m}}{d_i!} prod_{j=1,j eq i}^nfrac{{d_j}^m}{d_j!}\ end{align} ]

    [Ans'=sum_{sum_{d_i==n-2}}sum_{i=1}^nfrac{{{a_i}^{d_i}d_i}^{2m}}{d_i!}prod_{j=1,j eq i}^nfrac{{d_j}^m}{d_j!} ]

    考虑用生成函数解决:

    [A(x)=sum_{i=0}^nfrac{i^{2m}}{i!}x^i\ B(x)=sum_{i=0}^nfrac{i^m}{i!}x^i ]

    (Ans')的生成函数为

    [sum_{i=1}^nA(a_i)prod_{j=1,j eq i}^nB(a_j)\ =sum_{i=1}^nfrac{A(a_i)}{B(a_i)}prod_{j=1}^nB(a_j) ]

    对于(prod_{j=1}^nB(a_j)),我们的一般套路是将其写成

    [exp(ln(prod_{j=1}^nB(a_j)))\ =exp(sum_{j=1}^nln(B(a_j))) ]

    这样做的好处是我们只需要求出(ln(B(x))),然后对第(i)项系数乘上(displaystyle sum_{j=1}^n{a_j}^i)就可以得到(displaystyle sum_{j=1}^nln(B(a_j)))了。对于(displaystyle sum_{i=1}^nfrac{A(a_i)}{B(a_i)})我们也 用相同的处理方式。

    所以:

    [Ans'=sum_{i=1}^nfrac{A}{B}(a_i)exp(sum_{j=1}^nln(B(a_j))) ]

    现在的问题是如何求出

    [sum_{i=1}^n{a_i}^k ]

    考虑(ln(x))的取(x_0=1)时的泰勒展开形式

    [ln(x)=sum_{i=0}frac{ln^{[i](1)}}{i!}(x-1)^i\ =sum_{i=1}frac{(-1)^{i-1}}{i}(x-1)^i ]

    所以:

    [ln(1+a_jx)=sum_{i=1}frac{(-1)^{i-1}{a_j}^i}{i}x^i ]

    那么我们只需要求出

    [sum_{i=1}^nln(a_i) ]

    就行了。

    [sum_{i=1}^nln(a_i)=ln(prod_{i=1}^n(1+a_ix)) ]

    (prod_{i=1}^n(1+a_ix))可以用分治(NTT)求出。

    代码:

    #include<bits/stdc++.h>
    #define ll long long
    #define N 200005
    
    using namespace std;
    inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
    
    const ll mod=998244353;
    ll ksm(ll t,ll x) {
    	ll ans=1;
    	for(;x;x>>=1,t=t*t%mod)
    		if(x&1) ans=ans*t%mod;
    	return ans;
    }
    
    int n,m;
    int a[N];
    
    void NTT(ll *a,int d,int flag) {
    	int n=1<<d;
    	static int rev[N<<2];
    	static ll G=3;
    	for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
    	for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    	
    	for(int s=1;s<=d;s++) {
    		int len=1<<s,mid=len>>1;
    		ll w=flag==1?ksm(G,(mod-1)/len):ksm(G,mod-1-(mod-1)/len);
    		for(int i=0;i<n;i+=len) {
    			ll t=1;
    			for(int j=0;j<mid;j++,t=t*w%mod) {
    				ll u=a[i+j],v=a[i+j+mid]*t%mod;
    				a[i+j]=(u+v)%mod;
    				a[i+j+mid]=(u-v+mod)%mod;
    			}
    		}
    	}
    	
    	if(flag==-1) {
    		ll inv=ksm(n,mod-2);
    		for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
    	}
    }
    
    ll A[N<<2],B[N<<2];
    ll inv[N<<2];
    ll f[N<<2],g[N<<2];
    void Inv(ll *inv,ll *a,int d) {
    	static ll A[N<<3];
    	if(d==0) {
    		inv[0]=ksm(a[0],mod-2);
    		return ;
    	}
    	Inv(inv,a,d-1);
    	for(int i=0;i<1<<d;i++) A[i]=a[i];
    	for(int i=1<<d;i<1<<d+1;i++) inv[i]=A[i]=0;
    	NTT(A,d+1,1);
    	NTT(inv,d+1,1);
    	for(int i=0;i<1<<d+1;i++) {
    		inv[i]=(2*inv[i]-A[i]*inv[i]%mod*inv[i]%mod+mod)%mod;
    	}
    	NTT(inv,d+1,-1);
    	for(int i=1<<d;i<1<<d+1;i++) inv[i]=0;
    }
    
    void Der(ll *a,int d) {
    	int n=1<<d;
    	for(int i=0;i<n-1;i++) a[i]=(i+1)*a[i+1]%mod;
    	a[n-1]=0;
    }
    
    void Int(ll *a,int d) {
    	int n=1<<d;
    	for(int i=n-1;i>0;i--) a[i]=ksm(i,mod-2)*a[i-1]%mod;
    	a[0]=0;
    }
    
    ll ln[N<<2];
    void Ln(ll *ln,ll *a,int d) {
    	static ll der[N<<2];
    	for(int i=0;i<1<<d+1;i++) der[i]=0;
    	for(int i=0;i<1<<d;i++) der[i]=a[i];
    	Inv(inv,a,d);
    	Der(der,d);
    	NTT(inv,d+1,1),NTT(der,d+1,1);
    	for(int i=0;i<1<<d+1;i++) ln[i]=inv[i]*der[i]%mod;
    	NTT(ln,d+1,-1);
    	for(int i=1<<d;i<1<<d+1;i++) ln[i]=0;
    	Int(ln,d);
    	for(int i=1<<d;i<1<<d+1;i++) ln[i]=0;
    }
    
    ll ex[N<<2];
    
    void Exp(ll *exp,ll *a,int d) {
    	static ll A[N<<2],B[N<<2];
    	if(d==0) {
    		exp[0]=1;
    		return ;
    	}
    	Exp(exp,a,d-1);
    	for(int i=0;i<1<<d;i++) A[i]=a[i];
    	for(int i=1<<d;i<1<<d+1;i++) exp[i]=A[i]=0;
    	Ln(B,exp,d);
    	NTT(exp,d+1,1);
    	NTT(B,d+1,1);
    	NTT(A,d+1,1);
    	for(int i=0;i<1<<d+1;i++) {
    		exp[i]=exp[i]*(1-B[i]+A[i]+mod)%mod;
    	}
    	NTT(exp,d+1,-1);
    	for(int i=1<<d;i<1<<d+1;i++) exp[i]=0;
    }
    
    void solve(int l,int r,ll *a) {
    	static ll A[N<<2],B[N<<2];
    	if(l==r) return ;
    	int mid=l+r>>1;
    	solve(l,mid,a),solve(mid+1,r,a);
    	int d=ceil(log2(r-l+2));
    	for(int i=0;i<1<<d;i++) A[i]=B[i]=0;
    	for(int i=l;i<=mid;i++) A[i-l+1]=a[i];
    	for(int i=mid+1;i<=r;i++) B[i-mid]=a[i];
    	A[0]=B[0]=1;
    	NTT(A,d,1),NTT(B,d,1);
    	for(int i=0;i<1<<d;i++) A[i]=A[i]*B[i]%mod;
    	NTT(A,d,-1);
    	for(int i=l;i<=r;i++) a[i]=A[i-l+1];
    }
    
    ll summ[N];
    ll cal(int k) {
    	ll ans=0;
    	for(int i=1;i<=n;i++) (ans+=ksm(a[i],k))%=mod;
    	return ans;
    }
    
    ll tem[N<<2];
    ll fac[N],ifac[N];
    int main() {
    	n=Get(),m=Get();
    	fac[0]=1;
    	for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
    	ifac[n]=ksm(fac[n],mod-2);
    	for(int i=n-1;i>=0;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
    	for(int i=1;i<=n;i++) a[i]=Get();
    	for(int i=1;i<=n;i++) summ[i]=a[i];
    	int d=ceil(log2(2*n+1));
    	solve(1,n,summ);
    	summ[0]=1;
    	Ln(ln,summ,d);
    	memcpy(summ,ln,sizeof(summ));
    	
    	summ[0]=n;
    	for(int i=1;i<=n;i++) {
    		if(!(i&1)) summ[i]=summ[i]*(mod-1)%mod;
    		summ[i]=summ[i]*i%mod;
    	}
    	
    	for(int i=0;i<=n;i++) {
    		A[i]=ksm(i+1,2*m)*ifac[i]%mod;
    		B[i]=ksm(i+1,m)*ifac[i]%mod;
    	}
    	
    	
    	Ln(ln,B,d);
    	for(int i=0;i<1<<d;i++) ln[i]=ln[i]*summ[i]%mod;
    	for(int i=n;i<1<<d;i++) ln[i]=0;
    	Exp(g,ln,d);
    	for(int i=n;i<=1<<d;i++) g[i]=0;
    	
    	Inv(inv,B,d);
    	for(int i=n;i<1<<d;i++) inv[i]=0;
    
    	NTT(inv,d,1),NTT(A,d,1);
    	for(int i=0;i<1<<d;i++) f[i]=inv[i]*A[i]%mod;
    	NTT(f,d,-1);
    	for(int i=0;i<1<<d;i++) f[i]=f[i]*summ[i]%mod;
    	for(int i=n;i<1<<d;i++) f[i]=0;
    	
    	
    	NTT(f,d,1),NTT(g,d,1);
    	for(int i=0;i<1<<d;i++) f[i]=f[i]*g[i]%mod;
    	NTT(f,d,-1);
    	
    	ll ans=fac[n-2];
    	for(int i=1;i<=n;i++) ans=ans*a[i]%mod;
    	
    	ans=ans*f[n-2]%mod;
    	cout<<ans;
    	return 0;
    }
    
    
  • 相关阅读:
    [hackerrank]The Love-Letter Mystery
    *[hackerrank]Algorithmic Crush
    [hackerrank]Palindrome Index
    [hackerrank]Closest Number
    [hackerrank]Even Odd Query
    *[hackerrank]Consecutive Subsequences
    Permutation Sequence(超时,排列问题)
    Set Matrix Zeroes
    python requests的安装与简单运用(转)
    字符串指针数组,指向指针的指针
  • 原文地址:https://www.cnblogs.com/hchhch233/p/10705803.html
Copyright © 2011-2022 走看看