zoukankan      html  css  js  c++  java
  • LOJ#2983. 「WC2019」数树 排列组合,生成函数,多项式,FFT

    原文链接www.cnblogs.com/zhouzhendong/p/LOJ2983.html

    前言

    我怎么什么都不会?贺忙指导博客才会做。

    题解

    我们分三个子问题考虑。

    子问题0

    将红蓝共有的边连接,每一个连通块的颜色相同,不同连通块独立。

    答案是 (y ^ {连通块数})

    子问题1

    对于红树的一种连接方案,假设将在蓝树上也有的边连接起来,假设连了 (i) 条边,那么对答案的贡献就是:

    [y ^ n / y ^ i ]

    [z = frac 1 y ]

    根据二项式定理

    [z ^ a = sum_{i=0}^a inom{a}{i} (z-1)^i ]

    于是得到贡献是

    [sum_{j=0}^{n-i} inom{n-i}{j} (z -1) ^ j ]

    组合意义就是枚举所有边的子集算答案。

    所以答案是

    [y ^ n sum_{i = 0} ^ {n-1} (z-1) ^ j sum n ^ {n-i-2} prod_k a_k ]

    其中 (a_k) 表示第 (k) 个连通块的大小。

    考虑进一步展开组合意义:

    (prod _k a_k) 的含义就是每一个连通块取一个点的方案数,所以对蓝树进行树形DP,用 dp[x][0] 表示当前连通块没有选点的方案数,dp[x][1] 表示当前连通块已经选了一个的方案数。大力转移即可。

    时间复杂度 (O(n))

    子问题2

    考虑写出答案的式子

    [ans = y ^ n sum_{i=1}^ n (z- 1 ) ^ {n-i} frac{n!}{i!prod a_j!}left (prod a_j^{a_j} ight)(n ^ {i-2}) ^ 2 \ = y ^ n n ^ {-4} (z-1) ^ n sum_{i=1}^ n frac{n!}{i!prod a_j!}prod (a_j^{a_j} (z-1)^ {-1}n ^ 2) ]

    注意到

    [sum_{i=1}^ n frac{n!}{i!prod a_j!}prod (a_j^{a_j} (z-1)^ {-1}n ^ 2) = [n] exp(sum_{igeq 1 } a_j^{a_j} (z-1)^ {-1}n ^ 2frac{x^i}{i!}) ]

    于是运用多项式 exp 即可在 (O(nlog n)) 的时间复杂度内解决这个问题。

    代码

    #include <bits/stdc++.h>
    #define clr(x) memset(x,0,sizeof x)
    #define For(i,a,b) for (int i=(a);i<=(b);i++)
    #define Fod(i,b,a) for (int i=(b);i>=(a);i--)
    #define fi first
    #define se second
    #define pb(x) push_back(x)
    #define mp(x,y) make_pair(x,y)
    #define outval(x) cerr<<#x" = "<<x<<endl
    #define outtag(x) cerr<<"---------------"#x"---------------"<<endl
    #define outarr(a,L,R) cerr<<#a"["<<L<<".."<<R<<"] = ";
    						For(_x,L,R)cerr<<a[_x]<<" ";cerr<<endl;
    using namespace std;
    typedef long long LL;
    LL read(){
    	LL x=0,f=0;
    	char ch=getchar();
    	while (!isdigit(ch))
    		f|=ch=='-',ch=getchar();
    	while (isdigit(ch))
    		x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    	return f?-x:x;
    }
    const int mod=998244353;
    int Pow(int x,int y){
    	if (y<0)
    		x=Pow(x,mod-2),y=-y;
    	int ans=1;
    	for (;y;y>>=1,x=(LL)x*x%mod)
    		if (y&1)
    			ans=(LL)ans*x%mod;
    	return ans;
    }
    void Add(int &x,int y){
    	if ((x+=y)>=mod)
    		x-=mod;
    }
    void Del(int &x,int y){
    	if ((x-=y)<0)
    		x+=mod;
    }
    int Add(int x){
    	return x>=mod?x-mod:x;
    }
    int Del(int x){
    	return x<0?x+mod:x;
    }
    const int N=(1<<19)+1;
    int Fac[N],Inv[N],Iv[N];
    void getFI(){
    	int n=N-1;
    	for (int i=Fac[0]=1;i<=n;i++)
    		Fac[i]=(LL)Fac[i-1]*i%mod;
    	Inv[n]=Pow(Fac[n],mod-2);
    	Fod(i,n,1)
    		Inv[i-1]=(LL)Inv[i]*i%mod;
    	For(i,1,n)
    		Iv[i]=(LL)Inv[i]*Fac[i-1]%mod;
    }
    namespace fft{
    	 int w[N],R[N];
    	 void init(int n){
    	 	int d=0;
    	 	while ((1<<d)<n)
    	 		d++;
    	 	For(i,0,n-1)
    	 		R[i]=(R[i>>1]>>1)|((i&1)<<(d-1));
    	 	w[0]=1,w[1]=Pow(3,(mod-1)/n);
    	 	For(i,2,n-1)
    	 		w[i]=(LL)w[i-1]*w[1]%mod;
    	 }
    	 void FFT(int *a,int n,int flag){
    	 	if (flag<0)
    	 		reverse(w+1,w+n);
    	 	For(i,0,n-1)
    	 		if (i<R[i])
    	 			swap(a[i],a[R[i]]);
    	 	for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
    	 		for (int i=0;i<n;i+=d<<1)
    	 			for (int j=0;j<d;j++){
    	 				int tmp=(LL)w[t*j]*a[i+j+d]%mod;
    	 				a[i+j+d]=Del(a[i+j]-tmp);
    	 				Add(a[i+j],tmp);
    	 			}
    	 	if (flag<0){
    	 		reverse(w+1,w+n);
    	 		int inv=Pow(n,mod-2);
    	 		For(i,0,n-1)
    	 			a[i]=(LL)a[i]*inv%mod;
    	 	}
    	 }
    }
    using fft::FFT;
    typedef vector <int> vi;
    vi Fix(vi a,int n){
    	while (a.size()>n)
    		a.pop_back();
    	while (a.size()<n)
    		a.pb(0);
    	return a;
    }
    vi operator * (vi a,vi b){
    	int s=(int)a.size()+b.size()-1,n=1;
    	while (n<s)
    		n<<=1;
    	a=Fix(a,n),b=Fix(b,n);
    	fft::init(n);
    	FFT(&a[0],n,1),FFT(&b[0],n,1);
    	For(i,0,n-1)
    		a[i]=(LL)a[i]*b[i]%mod;
    	FFT(&a[0],n,-1);
    	return Fix(a,s);
    }
    vi operator + (vi a,vi b){
    	int s=max(a.size(),b.size());
    	a=Fix(a,s),b=Fix(b,s);
    	For(i,0,s-1)
    		Add(a[i],b[i]);
    	return a;
    }
    vi operator - (vi a,vi b){
    	int s=max(a.size(),b.size());
    	a=Fix(a,s),b=Fix(b,s);
    	For(i,0,s-1)
    		Del(a[i],b[i]);
    	return a;
    }
    vi pInv(vi a){
    	if (a.size()==1)
    		return (vi){Pow(a[0],mod-2)};
    	int n=a.size();
    	vi b=pInv(Fix(a,(n+1)>>1));
    	return Fix(b+b-b*b*a,n);
    }
    vi Der(vi a){
    	int n=a.size();
    	For(i,0,n-2)
    		a[i]=(LL)a[i+1]*(i+1)%mod;
    	return Fix(a,n-1);
    }
    vi Int(vi a){
    	int n=a.size();
    	a.pb(0);
    	Fod(i,n,1)
    		a[i]=(LL)a[i-1]*Iv[i]%mod;
    	a[0]=0;
    	return a;
    }
    vi Ln(vi a){
    	return Int(Fix(Der(a)*pInv(a),a.size()-1));
    }
    vi Exp(vi a){
    	if (a.size()==1)
    		return (vi){1};
    	int n=a.size();
    	vi b=Fix(Exp(Fix(a,(n+1)>>1)),n);
    	return Fix(b*((vi){1}-Ln(b)+a),n);
    }
    int n,z,op;
    namespace so0{
    	map <pair <int,int>,int> Map;
    	int main(){
    		Map.clear();
    		For(i,1,n-1){
    			int x=read(),y=read();
    			if (x>y)
    				swap(x,y);
    			Map[mp(x,y)]=1;
    		}
    		int c=n;
    		For(i,1,n-1){
    			int x=read(),y=read();
    			if (x>y)
    				swap(x,y);
    			c-=Map[mp(x,y)];
    		}
    		cout<<Pow(z,c)<<endl;
    		return 0;
    	}
    }
    namespace so1{
    	int inv_n,izn;
    	vector <int> e[N];
    	int size[N];
    	int dp[N][2];
    	void dfs(int x,int pre){
    		dp[x][0]=dp[x][1]=1;
    		for (auto y : e[x])
    			if (y!=pre){
    				dfs(y,x);
    				int t0=dp[x][0],t1=dp[x][1];
    				dp[x][0]=(LL)t0*dp[y][1]%mod;
    				dp[x][1]=(LL)t1*dp[y][1]%mod;
    				Add(dp[x][0],(LL)t0*dp[y][0]%mod*izn%mod);
    				Add(dp[x][1],(LL)t0*dp[y][1]%mod*izn%mod);
    				Add(dp[x][1],(LL)t1*dp[y][0]%mod*izn%mod);
    			}
    	}
    	int main(){
    		if (z==1){
    			cout<<Pow(n,n-2)<<endl;
    			return 0;
    		}
    		inv_n=Pow(n,mod-2);
    		izn=Del(Pow(z,mod-2)-1);
    		izn=(LL)izn*inv_n%mod;
    		For(i,1,n-1){
    			int x=read(),y=read();
    			e[x].pb(y),e[y].pb(x);
    		}
    		dfs(1,0);
    		int ans=(LL)dp[1][1]*Pow(n,n-2)%mod*Pow(z,n)%mod;
    		cout<<ans<<endl;
    		return 0;
    	}
    }
    namespace so2{
    	int main(){
    		if (z==1){
    			cout<<Pow(n,(n-2)*2)<<endl;
    			return 0;
    		}
    		getFI();
    		int iz=Del(Pow(z,mod-2)-1),tmp=(LL)Pow(iz,-1)*n%mod*n%mod;
    		vi a;
    		a.pb(0);
    		For(i,1,n)
    			a.pb((LL)Pow(i,i)*tmp%mod*Inv[i]%mod);
    		a=Exp(a);
    		int ans=(LL)a[n]*Fac[n]%mod;
    		ans=(LL)ans*Pow(z,n)%mod*Pow(iz,n)%mod*Pow(n,-4)%mod;
    		cout<<ans<<endl;
    		return 0;
    	}
    }
    int main(){
    	n=read(),z=read(),op=read();
    	if (op==0)
    		return so0::main();
    	else if (op==1)
    		return so1::main();
    	else if (op==2)
    		return so2::main();
    	return 0;
    }
    
  • 相关阅读:
    Android SDK 在线更新镜像服务器
    Android Studio (Gradle)编译错误
    java ZIP压缩文件
    java文件操作(输出目录、查看磁盘符)
    JXL读取写入excel表格数据
    Linux命令zip和unzip
    Linux查看系统基本信息
    Ubuntu C++环境支持
    Linux开机执行bash脚本
    ubuntu中磁盘挂载与卸载
  • 原文地址:https://www.cnblogs.com/zhouzhendong/p/LOJ2983.html
Copyright © 2011-2022 走看看