zoukankan      html  css  js  c++  java
  • 【WC2019】数树 树形DP 多项式exp

    题目大意

      有两棵 (n) 个点的树 (T_1)(T_2)

      你要给每个点一个权值吗,要求每个点的权值为 ([1,y]) 内的整数。

      对于一条同时出现在两棵树上的边,这条边的两个端点的值相同。

      若 (op=0),则给你两棵树 (T_1,T_2),求方案数。

      若 (op=1),则给你一棵树 (T_1),求对于所有 (n^{n-2})(T_2),方案数之和。

      若 (op=2),则求对于所有的 (T_1,T_2),求方案数之和。

      (nleq 100000)

    题解

      新建一个图 (G),把两棵树的公共边加到 (G) 中。记 (m) 为两棵树的公共边数量。那么答案就是 (y^{n-m})

      令 (z=y^{-1}),那么答案就变成了 (y^nz^m)。也就是说,每有一条相同的边,方案的贡献就要 ( imes z)

    op=0

      这个大家都会。

    op=1

    [z^m=sum_{i=0}^minom{m}{i}(z-1)^i ]

      那么可以枚举一个边集 (E),计算有多少种生成树包含 (E),然后把答案加上方案数 ( imes{(z-1)}^{lvert E vert})

      记这 (E) 条边形成了 (m) 个连通块,这些连通块的大小为 (a_1,a_2,ldots,a_m),那么贡献就是

    [egin{align} &{(z-1)}^{n-m}sum_{sum_{i=1}^md_i=2m-2}(m-2)!prod_{i=1}^mfrac{a_i^{d_i}}{(d_i-1)!}\ =&{(z-1)}^{n-m}n^{m-2}prod_{i=1}^ma_i\ end{align} ]

      (prod_{i=1}^ma_i) 可以看成是每个连通块内选一个点的方案数。这样就可以DP了。

      时间复杂度:(O(n))

    op=2

      枚举两棵树的公共边个数:

    [egin{align} s_n&=sum_{i=1}^{n}{(z-1)}^{n-i}sum_{sum_{j=1}^ia_j=n}frac{n!}{i!}(prod_{j=1}^ifrac{a_j^{a_j-2}}{a_j!})(n^{i-2}prod_{j=1}^ia_j)^2\ &=sum_{i=1}^{n}{(z-1)}^{n-i}frac{n!n^{2i-4}}{i!}sum_{sum_{j=1}^ia_j=n}prod_{j=1}^ifrac{a_j^{a_j}}{a_j!}\ &=sum_{i=1}^{n}{(z-1)}^{n-i}n^{2i-4}sum_{sum_{j=1}^ia_j=n}prod_{j=1}^iinom{(sum_{k=1}^ja_k)-1}{a_j-1}{}a_j^{a_j}\ end{align} ]

      记 (f_l=sum_{i=1}^{l}{(z-1)}^{-i}n^{2i}sum_{sum_{j=1}^ia_j=l}prod_{j=1}^iinom{(sum_{k=1}^ja_k)-1}{a_j-1}{}a_j^{a_j})

      转移时枚举最后一块的大小,有:

    [f_i=egin{cases} 1&,i=0\ sum_{j=1}^ifrac{(i-1)!n^2j^jf_{i-j}}{(i-j)!(j-1)!(z-1)}&,i>0 end{cases} ]

      直接DP是 (O(n^2)) 的。

      记 (g_i=sum_{igeq 1}frac{n^2i^i}{(i-1)!(z-1)})(F(x))(f) 的 EGF,(G(x))(g) 的 OGF,那么

    [egin{align} xF'(x)&=F(x)G(x)\ frac{F'(x)}{F(x)}&=frac{G(x)}{x}\ ln F(x)&=int frac{G(x)}{x}\ F(x)&=e^{int frac{G(x)}{x}} end{align} ]

      直接多项式 exp 就好了。

      答案为 ((z-1)^nn^{-4}f_n)

      时间复杂度:(O(nlog n))

    代码

    const ll p=998244353;
    ll fp(ll a,ll b)
    {
    	ll s=1;
    	for(;b;b>>=1,a=a*a%p)
    		if(b&1)
    			s=s*a%p;
    	return s;
    }
    const int N=100010;
    int n,op;
    ll z,_z;
    ll ans;
    namespace solve0
    {
    	map<int,int> a[N];
    	void solve()
    	{
    		if(_z==1)
    		{
    			ans=1;
    			return;
    		}
    		int x,y;
    		for(int i=1;i<n;i++)
    		{
    			io::get(x);
    			io::get(y);
    			if(x>y)
    				swap(x,y);
    			a[x][y]++;
    		}
    		ans=1;
    		for(int i=1;i<n;i++)
    		{
    			io::get(x);
    			io::get(y);
    			if(x>y)
    				swap(x,y);
    			if(a[x].count(y))
    				ans=ans*z%p;
    		}
    	}
    }
    namespace solve1
    {
    	vector<int> g[N];
    	ll f[N][2];
    	void dfs(int x,int fa)
    	{
    		f[x][0]=f[x][1]=1;
    		for(auto v:g[x])
    			if(v!=fa)
    			{
    				dfs(v,x);
    				ll s0=(f[x][0]*f[v][0]%p*z+f[x][0]*f[v][1]%p*n)%p;
    				ll s1=(f[x][0]*f[v][1]%p*z+f[x][1]*f[v][0]%p*z+f[x][1]*f[v][1]%p*n)%p;
    				f[x][0]=s0;
    				f[x][1]=s1;
    			}
    	}
    	void solve()
    	{
    		if(_z==1)
    		{
    			ans=fp(n,n-2);
    			return;
    		}
    		int x,y;
    		for(int i=1;i<n;i++)
    		{
    			io::get(x);
    			io::get(y);
    			g[x].push_back(y);
    			g[y].push_back(x);
    		}
    		z--;
    		dfs(1,0);
    		ans=f[1][1]*fp(n,p-2)%p;
    	}
    }
    namespace solve2
    {
    	const int N=270000;
    	namespace ntt
    	{
    		const int W=262144;
    		ll w[N];
    		int rev[N];
    		void init()
    		{
    			w[0]=1;
    			ll s=fp(3,(p-1)/W);
    			for(int i=1;i<W/2;i++)
    				w[i]=w[i-1]*s%p;
    		}
    		void ntt(ll *a,int n,int t)
    		{
    			for(int i=1;i<n;i++)
    			{
    				rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
    				if(rev[i]>i)
    					swap(a[i],a[rev[i]]);
    			}
    			for(int i=2;i<=n;i<<=1)
    				for(int j=0;j<n;j+=i)
    					for(int k=0;k<i/2;k++)
    					{
    						ll u=a[j+k];
    						ll v=a[j+k+i/2]*w[W/i*k];
    						a[j+k]=(u+v)%p;
    						a[j+k+i/2]=(u-v)%p;
    					}
    			if(t==-1)
    			{
    				reverse(a+1,a+n);
    				ll inv=fp(n,p-2);
    				for(int i=0;i<n;i++)
    					a[i]=a[i]*inv%p;
    			}
    		}
    		void mul(ll *a,ll *b,ll *c,int n,int m,int l)
    		{
    			static ll a1[N],a2[N];
    			int k=1;
    			while(k<=n+m)
    				k<<=1;
    			memset(a1,0,sizeof(a1[0])*k);
    			memset(a2,0,sizeof(a2[0])*k);
    			memcpy(a1,a,sizeof(a1[0])*(n+1));
    			memcpy(a2,b,sizeof(a2[0])*(m+1));
    			ntt::ntt(a1,k,1);
    			ntt::ntt(a2,k,1);
    			for(int i=0;i<k;i++)
    				a1[i]=a1[i]*a2[i]%p;
    			ntt::ntt(a1,k,-1);
    			memcpy(c,a1,sizeof(a1[0])*(l+1));
    		}
    		void inv(ll *a,ll *b,int n)
    		{
    			if(n==1)
    			{
    				b[0]=fp(a[0],p-2);
    				return;
    			}
    			inv(a,b,n>>1);
    			static ll a1[N],a2[N];
    			memset(a1,0,sizeof(a1[0])*(n<<1));
    			memset(a2,0,sizeof(a2[0])*(n<<1));
    			memcpy(a1,a,sizeof(a1[0])*n);
    			memcpy(a2,b,sizeof(a2[0])*(n>>1));
    			ntt(a1,n<<1,1);
    			ntt(a2,n<<1,1);
    			for(int i=0;i<n<<1;i++)
    				a1[i]=a2[i]*(2-a1[i]*a2[i]%p)%p;
    			ntt(a1,n<<1,-1);
    			memcpy(b,a1,sizeof(a1[0])*n);
    		}
    		void ln(ll *a,ll *b,int n)
    		{
    			static ll a1[N],a2[N],a3[N];
    			for(int i=1;i<n;i++)
    				a1[i-1]=a[i]*i%p;
    			a1[n-1]=0;
    			inv(a,a2,n);
    			mul(a1,a2,a3,n-1,n-1,n-1);
    			for(int i=1;i<n;i++)
    				b[i]=a3[i-1]*fp(i,p-2)%p;
    			b[0]=0;
    		}
    		void exp(ll *a,ll *b,int n)
    		{
    			if(n==1)
    			{
    				b[0]=1;
    				return;
    			}
    			exp(a,b,n>>1);
    			static ll a1[N],a2[N],a3[N];
    			memset(b+(n>>1),0,sizeof(b[0])*(n>>1));
    			ln(b,a3,n);
    			memset(a1,0,sizeof(a1[0])*n);
    			memset(a2,0,sizeof(a2[0])*n);
    			memcpy(a1,b,sizeof(a1[0])*(n>>1));
    			for(int i=0;i<(n>>1);i++)
    				a2[i]=a[(n>>1)+i]-a3[(n>>1)+i];
    			ntt(a1,n,1);
    			ntt(a2,n,1);
    			for(int i=0;i<n;i++)
    				a1[i]=a1[i]*a2[i]%p;
    			ntt(a1,n,-1);
    			memcpy(b+(n>>1),a1,sizeof(a1[0])*(n>>1));
    		}
    	}
    	ll inv[N],fac[N],ifac[N];
    	ll f[N],g[N],w[N];
    	void solve()
    	{
    		if(_z==1)
    		{
    			ans=fp(n,n-2)*fp(n,n-2)%p;
    			return;
    		}
    		z--;
    		ntt::init();
    		fac[0]=fac[1]=ifac[0]=ifac[1]=inv[1]=1;
    		for(int i=2;i<=n;i++)
    		{
    			fac[i]=fac[i-1]*i%p;
    			inv[i]=-p/i*inv[p%i]%p;
    			ifac[i]=ifac[i-1]*inv[i]%p;
    		}
    		ll ifacz=fp(z,p-2);
    		
    //		f[0]=1;
    //		for(int i=1;i<=n;i++)
    //			w[i]=fp(i,i);
    //		for(int i=1;i<=n;i++)
    //			for(int j=1;j<=i;j++)
    //				f[i]=(f[i]+f[i-j]*fac[i-1]%p*ifac[i-j]%p*ifac[j-1]%p*n%p*n%p*w[j]%p*ifacz)%p;
    
    
    		for(int i=1;i<=n;i++)
    			g[i]=fp(i,i)*n%p*n%p*ifac[i-1]%p*ifacz%p*inv[i]%p;
    		int k=1;
    		while(k<=n)
    			k<<=1;
    		ntt::exp(g,f,k);
    		ans=f[n]*fac[n]%p*fp(z,n)%p*fp(n,p-1-4)%p;
    	}
    }
    int main()
    {
    	freopen("tree.in","r",stdin);
    	freopen("tree.out","w",stdout);
    	io::get(n);
    	io::get(_z);
    	io::get(op);
    	z=fp(_z,p-2);
    	if(op==0)
    		solve0::solve();
    	else if(op==1)
    		solve1::solve();
    	else
    		solve2::solve();
    	ans=ans*fp(_z,n)%p;
    	ans=(ans%p+p)%p;
    	io::put(ans);
    	return 0;
    }
    
  • 相关阅读:
    092、部署Graylog日志系统(2019-05-16 周四)
    091、万能的数据收集器 Fluentd (2019-05-15 周三)
    090、ELK完成部署和使用 (2019-05-13 周二)
    在CentOS7上无人值守安装Zabbix4.2
    089、初探ELK (2019-05-13 周一)
    34、Scrapy 知识总结
    33、豆瓣图书短评
    32、出任爬虫公司CEO(爬取职友网招聘信息)
    31、当当图书榜单爬虫
    30、吃什么不会胖
  • 原文地址:https://www.cnblogs.com/ywwyww/p/10351138.html
Copyright © 2011-2022 走看看