zoukankan      html  css  js  c++  java
  • WC2019 T1 数树 题解

    题面
    题意: 本题包含3个Task:

    1. Task0:给定两棵树的边集S,T,求(bas^{n-|Sigcap T|})
    2. Task1:给定一棵树的边集S,求(sum_{T}bas^{n-|Sigcap T|})
    3. Task2:求(sum_{S}sum_{T}bas^{n-|Sigcap T|})

    其中n为点数,(n,bas)都是给定的,(nleq 10^5),答案对(998244353)取模。


    Task0

    直接模拟即可。

    #include<bits/stdc++.h>
    using namespace std;
    #define N 200007
    #define ll long long
    const ll mod=998244353;
    struct str
    {
    	int x,y;
    };
    bool operator <(str a,str b)
    {
    	return a.x<b.x||a.x==b.x&&a.y<b.y;
    }
    set<str> S;
    ll p2(ll x){return x*x%mod;}
    ll pw(ll x,ll p)
    {
    	return p?p2(pw(x,p/2))*(p&1?x:1)%mod:1;
    }
    int main()
    {
    	int x,y,n,op;
    	ll p;
    	scanf("%d%lld%d",&n,&p,&op);
    	for(int i=1;i<n;i++)
    	{
    		scanf("%d%d",&x,&y);
    		if(x>y)swap(x,y);
    		S.insert({x,y});
    	}
    	int cnt=0;
    	for(int i=1;i<n;i++)
    	{
    		scanf("%d%d",&x,&y);
    		if(x>y)swap(x,y);
    		if(S.find({x,y})!=S.end())cnt++;
    	}
    	ll ans=pw(p,n-cnt);
    	printf("%lld
    ",ans);
    	return 0;
    }
    
    

    Task1

    要求(ans=sum_{T}bas^{n-|Sigcap T|})

    我们设(ans'=sum_{T}bas^{-|Sigcap T|}),那么(ans=bas^n*ans')

    然后我们考虑怎么处理(|Sigcap T|),可以考虑先枚举它,然后再枚举S和T。

    那么我们就需要算出一个(F(E)),表示交集恰好为E的(<S,T>)的数量。

    但是这东西不好算,那么我们再设一个(G(E)),表示交集至少为E的(<S,T>)的数量。

    那么根据容斥原理我们有:

    [F(E)=sum_{Esubseteq V}(-1)^{|V|-|E|}G(V) ]

    然后我们有

    [ans'=sum_{Esubseteq S}F(E)bas^{-|E|} ]

    [=sum_{Esubseteq S}sum_{Esubseteq Vsubseteq S}(-1)^{|V|-|E|}G(V)bas^{-|E|} ]

    [=sum_{Vsubseteq S}(-1)^{|V|}G(V)sum_{Esubseteq V}(-bas)^{-|E|} ]

    [=sum_{Vsubseteq S}(-1)^{|V|}G(V)sum_{i=0}^{|V|}C_{|V|}^i(-bas^{-1})^{i} ]

    [=sum_{Vsubseteq S}(-1)^{|V|}G(V)(1-bas^{-1})^{|V|} ]

    我们设(P=(bas^{-1}-1)),那么(ans'=sum_{Vsubseteq S}G(V)P^{|V|})

    然后我们考虑(G(V))怎么算

    至少包含(V),那么可以看成先将(V)中的边全部连上,形成(n-|V|)个连通块,然后把这些连通块连起来的方案数。

    我们设将(V)中的边连上后形成的连通块大小分别为(a_1,a_2...,a_{n-|V|}),那么我们有结论:

    [G(V)=n^{n-|V|-2}prod _{i=1}^{n-|V|}a_i ]

    证明:

    我们考虑如果我们将所有连通块连成一棵树,那么对应的边有多少种方案。

    一条连接连通块i和连通块j的边的方案数为(a_i*a_j)

    那么将整棵树连出来的方案数为(prod_{i=1}^{n-|V|}a_i^{deg_i}),其中(deg_i)为连通块i在这棵树上的度数。

    我们惊奇的发现(deg_i)就是这棵树对应的purfer序列中i出现的次数加1

    那么我们就枚举所有的purfer序列,设其为P,数字i在P中的出现次数为(times_i)

    [G(V)=sum_{P}prod_{i=1}^{n-|V|}a_i^{times_i+1} ]

    [=(prod_{i=1}^{n-|V|}a_i)sum_{P}prod_{i=1}^{n-|V|}a_i^{times_i} ]

    发现后面这部分它就等于((a_1+a_2...+a_{n-|V|})^{n-|V|-2}),即(n^{n-|V|-2})

    于是就有

    [G(V)=n^{n-|V|-2}prod _{i=1}^{n-|V|}a_i ]

    证毕。

    回到我们的问题,有

    [ans'=sum_{Vsubseteq S}n^{n-|V|-2}P^{|V|}prod _{i=1}^{n-|V|}a_i ]

    为了后续操作的方便,我们将(n,P)放到连乘里面

    [ans'=n^{-2}P^{n}sum_{Vsubseteq S}prod _{i=1}^{n-|V|}a_inP^{-1} ]

    然后我们设(K=nP^{-1},ans''=sum_{Vsubseteq S}prod _{i=1}^{n-|V|}a_iK),那么(ans'=ans''*n^{-2}P^n)

    我们考虑(ans'')怎么算

    容易想到用树上dp解决,设(dp[v][i])为在v子树内,与v相连的连通块大小为i的所有方案的贡献和(注意这里的贡献不包含与v相连的那个连通块)

    另外我们规定(dp[v][0])即为v子树内的答案。

    那么有转移方程

    [dp[v][i]=sum_{j=1}^idp[v][j]*dp[u][i-j] ]

    [dp[v][0]=sum_{i=1}^ndp[v][i]*i*K ]

    初始状态为(dp[v][1]=1)

    最后的答案就是(dp[1][1])

    但是这样复杂度是(O(n^2))的,考虑怎么优化

    我们其实没有必要将每一个(dp[v][i])都算出来,我们只需要知道它们整体的一个值就可以。

    于是我们设(f[v]=sum_{i=0}^{n}dp[v][i])(g[v]=dp[v][0]=sum_{i=1}^ndp[v][i]*i*K)

    那么在转移过程中有

    [g[v]=sum_{i=1}^nsum_{j=1}^idp[v][j]*dp[u][i-j]*i*K ]

    [=sum_{i=1}^nsum_{j=0}^n(i+j)K*dp[v][i]*dp[u][j] ]

    [=sum_{i=1}^nsum_{j=0}^niK*dp[v][i]*dp[u][j]+sum_{i=1}^nsum_{j=0}^njK*dp[v][i]*dp[u][j] ]

    [=g[v]*f[u]+g[u]*f[v] ]

    [f[v]=sum_{i=1}^nsum_{j=1}^idp[v][j]*dp[u][i-j] ]

    [=(sum_{i=1}^{n}dp[v][i])(sum_{j=0}^{n}dp[u][j]) ]

    [=f[v]*f[u] ]

    最后还要(f[v]+=g[v])

    初始状态为(f[v]=1,g[v]=K)

    这样就可以(O(n))转移了

    最后的答案为(bas^nn^{-2}P^{n}g[1])

    #include<bits/stdc++.h>
    using namespace std;
    #define N 200007
    #define M 400007
    #define ll long long
    const ll mod=998244353;
    int f[N],g[N],n,P,K,sz[N];
    int hd[N],pre[M],to[M],num;
    void adde(int x,int y)
    {
    	num++;pre[num]=hd[x];hd[x]=num;to[num]=y;
    }
    void dfs(int v,int fa)
    {
    	f[v]=1,g[v]=K;
    	for(int i=hd[v];i;i=pre[i])
    	{
    		int u=to[i];
    		if(u==fa)continue;
    		dfs(u,v);
    		g[v]=(1ll*g[v]*f[u]+1ll*g[u]*f[v])%mod;
    		f[v]=1ll*f[v]*f[u]%mod;
    	}
    	f[v]=(f[v]+g[v])%mod;
    }
    ll p2(ll x){return x*x%mod;}
    ll pw(ll x,ll p)
    {
    	return p?p2(pw(x,p/2))*(p&1?x:1)%mod:1;
    }
    int main()
    {
    	//freopen("data.in","r",stdin);
    	int x,y,op,p;
    	scanf("%d%d%d",&n,&p,&op);
    	for(int i=1;i<n;i++)
    	{
    		scanf("%d%d",&x,&y);
    		adde(x,y),adde(y,x);
    	}
    	if(p==1)
    	{
    		printf("%lld
    ",pw(n,n-2));
    		return 0;
    	}
    	P=pw(pw(p,mod-2)-1,mod-2),K=1ll*n*P%mod;
    	dfs(1,0);
    	ll ans = g[1] * p2(pw(n,mod-2)) % mod * pw(pw(P,mod-2),n) % mod;
    	printf("%lld
    ",ans*pw(p,n)%mod);
    	return 0;
    }
    

    Task2

    类似Task1,我们同样可以得到

    [ans'=sum_{V}G(V)P^{|V|} ]

    其中(G(V)=(n^{n-|V|-2}prod_{i=1}^{n-|V|}a_i)^2),V为任意一棵n个点的森林的边集

    带入化简得

    [ans'=n^{-4}P^nsum_{V}prod_{i=1}^{n-|V|}(a_i^2n^2P^{-1}) ]

    (K=n^2P^{-1},ans''=sum_{V}prod_{i=1}^{n-|V|}(a_i^2K))

    (ans'=n^{-4}P^nans'')

    考虑(ans''),它相当于将n个点划分为若干连通块,每一个连通块内部构成一棵树,则每一种划分方案的贡献为(prod_{i=1}^{n-|V|}a_i^{a_i-2}(a_i^2K)),其中每一个大小为i的连通块的贡献为(i^{i-2}(i^2K)=i^iK)

    根据生成函数的一些性质,如果我们设(G(x)=sum_{i=1}^infty frac{i^iK}{i!}x^i),那么(e^{G(x)})就是(ans'')的指数型生成函数,取它的第(x^n)项再乘以(n!)就可以得到(ans'')

    多项式(exp)即可。

    最后的答案为(bas^nn^{-4}P^nans'')

    完整的代码(注意要特判(bas=1)的情况):

    #include<bits/stdc++.h>
    using namespace std;
    #define N 600007
    #define ll long long
    const ll mod=998244353;
    const int lim=2e5;
    int tp;
    ll n,bas;
    ll p2(ll x){return x*x%mod;}
    ll pw(ll x,ll p)
    {
    	return p?p2(pw(x,p/2))*(p&1?x:1)%mod:1;
    }
    namespace tp0
    {
    	struct edge
    	{
    		int x,y;
    	};
    	bool operator <(edge a,edge b)
    	{
    		return a.x<b.x||a.x==b.x&&a.y<b.y;
    	}
    	set<edge> S;
    	void work()
    	{
    		if(bas==1)
    		{
    			printf("%d
    ",1);
    			return ;
    		}
    		for(int i=1;i<n;i++)
    		{
    			int x,y;
    			scanf("%d%d",&x,&y);
    			if(x>y)swap(x,y);
    			S.insert({x,y});
    		}
    		int cnt=0;
    		for(int i=1;i<n;i++)
    		{
    			int x,y;
    			scanf("%d%d",&x,&y);
    			if(x>y)swap(x,y);
    			if(S.find({x,y})!=S.end())cnt++;
    		}
    		printf("%lld
    ",pw(bas,n-cnt));
    	}
    }
    namespace tp1
    {
    	int hd[N],pre[N],to[N],num;
    	ll f[N],g[N],K,P,Q;
    	void adde(int x,int y)
    	{
    		num++;pre[num]=hd[x];hd[x]=num;to[num]=y;
    	}
    	void dfs(int v,int fa)
    	{
    		f[v]=1,g[v]=K;
    		for(int i=hd[v];i;i=pre[i])
    		{
    			int u=to[i];
    			if(u==fa)continue;
    			dfs(u,v);
    			g[v]=(g[v]*f[u]+f[v]*g[u])%mod;
    			f[v]=f[v]*f[u]%mod;
    		}
    		f[v]=(f[v]+g[v])%mod;
    	}
    	void work()
    	{
    		if(bas==1)
    		{
    			printf("%lld
    ",pw(n,n-2));
    			return ;
    		}
    		P=(pw(bas,mod-2)-1+mod)%mod;
    		K=n*pw(P,mod-2)%mod;
    		Q=pw(n,2*(mod-2))*pw(P,n)%mod;
    		int x,y;
    		for(int i=1;i<n;i++)
    		{
    			scanf("%d%d",&x,&y);
    			adde(x,y),adde(y,x);
    		}
    		dfs(1,0);
    		ll ans=g[1];
    		ans=ans*Q%mod;
    		ans=ans*pw(bas,n)%mod;
    		printf("%lld
    ",ans);
    	}
    }
    namespace tp2
    {
    	ll inv[N],fac[N],ifac[N];
    	int rev[N],len;
    	void getlen(int n)
    	{
    		for(len=1;len<=n;len<<=1);
    		for(int i=0;i<len;i++)
    			rev[i]=rev[i>>1]>>1|(i&1?len>>1:0);
    	}
    	void NTT(ll *a,int op)
    	{
    		for(int i=0;i<len;i++)
    			if(rev[i]<i)swap(a[rev[i]],a[i]);
    		for(int i=1;i<len;i<<=1)
    		{
    			ll nw=pw(3,(mod-1)/(i<<1));
    			for(int j=0;j<len;j+=i<<1)
    			{
    				ll w=1;
    				for(int k=j;k<j+i;k++)
    				{
    					ll x=a[k],y=a[k+i]*w%mod;
    					a[k]=(x+y)%mod,a[k+i]=(x-y+mod)%mod;
    					w=w*nw%mod;
    				}
    			}
    		}
    		if(op<0)
    		{
    			reverse(a+1,a+len);
    			ll Inv=pw(len,mod-2);
    			for(int i=0;i<len;i++)
    				a[i]=a[i]*Inv%mod;
    		}
    	}
    	void copy(ll *a,ll *b,int n=len)
    	{
    		for(int i=0;i<n;i++)a[i]=b[i];
    		for(int i=n;i<len;i++)a[i]=0;
    	}
    	ll mul_c[N],mul_d[N];
    	void mul(ll *t,ll *a,ll *b)
    	{
    		ll *c=mul_c,*d=mul_d;
    		copy(c,a),copy(d,b);
    		NTT(c,1),NTT(d,1);
    		for(int i=0;i<len;i++)c[i]=c[i]*d[i]%mod;
    		NTT(c,-1);
    		copy(t,c);
    	}
    	ll inv_c[N];
    	void getinv(int p,ll *a,ll *b)
    	{
    		if(p==1)return a[0]=pw(b[0],mod-2),(void)1;
    		getinv((p+1)/2,a,b);
    		getlen(2*p);
    		ll *c=inv_c;
    		copy(c,b,p);
    		NTT(a,1),NTT(c,1);
    		for(int i=0;i<len;i++)a[i]=a[i]*(2-a[i]*c[i]%mod+mod)%mod;
    		NTT(a,-1);
    		for(int i=p;i<len;i++)a[i]=0;
    	}
    	void devir(ll *a,int n)
    	{
    		for(int i=1;i<=n;i++)a[i-1]=a[i]*i%mod;
    		a[n]=0;
    	}
    	void inter(ll *a,int n)
    	{
    		for(int i=n;i>=0;i--)a[i+1]=a[i]*inv[i+1]%mod;
    		a[0]=0;
    	}
    	ll ln_c[N];
    	void getln(int n,ll *a,ll *b)
    	{
    		ll *c=ln_c;
    		getlen(2*n);
    		copy(c,b,n);
    		getinv(n,a,c);
    		devir(c,n);
    		getlen(2*n);
    		mul(a,a,c);
    		inter(a,n);
    		for(int i=n;i<len;i++)a[i]=0;
    	}
    	ll exp_c[N];
    	void getexp(int p,ll *a,ll *b)
    	{
    		if(p==1)return a[0]=1,(void)1;
    		getexp((p+1)/2,a,b);
    		getlen(2*p);
    		ll *c=exp_c;
    		copy(c,a,0);
    		getln(p,c,a);
    		getlen(2*p);
    		for(int i=0;i<p;i++)c[i]=(b[i]-c[i]+mod)%mod;
    		c[0]=(c[0]+1)%mod;
    		mul(a,a,c);
    		for(int i=p;i<len;i++)a[i]=0;
    	}
    	void Init()
    	{
    		fac[0]=1;
    		for(int i=1;i<=lim;i++)fac[i]=fac[i-1]*i%mod;
    		ifac[lim]=pw(fac[lim],mod-2);
    		for(int i=lim;i>=1;i--)ifac[i-1]=ifac[i]*i%mod;
    		inv[1]=1;
    		for(int i=2;i<=lim;i++)
    			inv[i]=mod-mod/i*inv[mod%i]%mod;
    	}
    	ll f[N],g[N];
    	void work()
    	{
    		Init();
    		ll P,K;
    		if(bas==1)
    		{
    			printf("%lld
    ",pw(n,2*(n-2)));
    			return ;
    		}
    		P=(pw(bas,mod-2)-1+mod)%mod;
    		K=n*n%mod*pw(P,mod-2)%mod;
    		for(int i=1;i<=n;i++)
    			f[i]=pw(i,i)*K%mod*ifac[i]%mod;
    		getexp(n+1,g,f);
    		ll ans=g[n]*fac[n]%mod;
    		ans=ans*pw(n,4*(mod-2))%mod*pw(P,n)%mod;
    		ans=ans*pw(bas,n)%mod;
    		printf("%lld
    ",ans);
    	}
    }
    
    int main()
    {
    	scanf("%lld%lld%d",&n,&bas,&tp);
    	if(tp==0)tp0::work();
    	else if(tp==1)tp1::work();
    	else tp2::work();
    	return 0;
    }
    
  • 相关阅读:
    数据源ObjectDataSource的数据访问类的编写
    ASP.NET网页文本编辑器的使用
    装饰模式
    策略模式
    代理模式
    备份、还原数据库
    简单工厂和工厂模式
    ASP.NET上传多个文件
    数据库访问类的编写
    UVA 11069 A Graph Problem
  • 原文地址:https://www.cnblogs.com/lishuyu2003/p/12146391.html
Copyright © 2011-2022 走看看