zoukankan      html  css  js  c++  java
  • [LibreOJ #2983]【WC2019】数树【计数】【DP】【多项式】

    Description

    此题含有三个子问题
    问题1:
    给出n个点的两棵树,记m为只保留同时在两棵树中的边时连通块的个数,求(y^m)
    问题2:
    给出n个点的一棵树,另外一棵树任意生成,求所有方案总的(y^m)的和
    问题3:
    两棵树均任意生成,求所有方案总的(y^m)的和
    n<=100000,答案对998244353取模

    Solution

    问题1:

    求出同时在两棵树中的边数v,容易得到m=n-v,直接算即可。

    问题2:

    (z=y^{-1})即y在模意义下的逆元,我们就是要求(z^v)的总和
    直接计算比较困难
    考虑这样一个转化(z^v=sumlimits_{i=0}^{v}(z-1)^i{vchoose i})
    考虑它的组合意义,实际上就是我们任意枚举一个重合边集E的子集E0,贡献就是((z-1)^{|E_0|})

    我们不妨枚举原树边集的子集(E_0),再计算有多少个生成树包含(E_0)
    假设(E_0)构成了m个连通块,它们的点数分别是(a_1,a_2,...,a_m)

    那么包含(E_0)的生成树个数就是(n^{m-2}prodlimits_{i=1}^{m} a_i),总的贡献就是((z-1)^{n-m}n^{m-2}prodlimits_{i=1}^{m} a_i)
    这东西考虑用prufer序来理解,我们将一个连通块看做一个大的点,那么m-2个连通块在prufer序上都有一个“父亲”,但是这个父亲是一个单点,因此是(n^{m-2}),此外后面的大小之积相当于枚举每个连通块prufer序连出去的边是哪一条。

    这显然可以用一个n^2的树形DP来完成,考虑优化
    考虑上面的式子的组合意义,我们把常数因子提出来(((z-1)^nn^{-2})),相当于每个连通块都有一个固定的贡献((z-1)^{-1}n),此外再在每个连通块中选择一个点,求总的贡献和。

    这就可以用树形DP来做了,记(F[i][0/1])表示当前做完了i的子树,i所在的连通块是否已经贡献过。
    直接讨论相邻边是否出现即可转移,时间复杂度(O(N)),非常的巧妙。

    问题3:

    有了问题2的基础,我们来考虑问题3
    不妨同样枚举重合边集的一个子集(E_0),m,a的定义同上
    那么贡献变成了$$(z-1)^{n-m}left(n^{m-2}prod a_i ight)^2$$
    平方乘到指数上,依然是将常数因子提出来,把贡献分配到每个连通块上
    此外由于两棵树都是生成的,因此连通块内部的边也不确定,还要乘上连通块内部的生成树个数

    那么一个连通块i的贡献就是
    ((z-1)^{-1}n^2a_i^2a_i^{a_i-2}=(z-1)^{-1}n^2a_i^{a_i})

    那么上面的式子相当于连通块的带标号重复拼接
    一个连通块的EGF就是$$F(x)=sumlimits_{i>0}{(z-1)^{-1}n^2i^iover i!}$$
    拼接以后就是([x^n]e^{F(x)}),乘上前面提出来的贡献就是答案了。

    Code

    #include <bits/stdc++.h>
    #define fo(i,a,b) for(int i=a;i<=b;++i)
    #define fod(i,a,b) for(int i=a;i>=b;--i)
    #define N 100005
    #define LL long long
    #define mo 998244353
    using namespace std;
    int n,tp;
    LL m;
    LL ksm(LL k,LL n)
    {
        LL s=1;
        for(;n;n>>=1,k=k*k%mo) if(n&1) s=s*k%mo;
        return s;
    }
    namespace subtask1
    {
        map<int,bool> mp[N];
        void solve1()
        {
            fo(i,1,n-1)
            {
                int x,y;
                scanf("%d%d",&x,&y);
                mp[x][y]=mp[y][x]=1;
            }
            int c=0;
            fo(i,1,n-1)
            {
                int x,y;
                scanf("%d%d",&x,&y);
                if(mp[x][y]) c++;
            }
            printf("%lld
    ",ksm(m,n-c));
        }
    }
    using namespace subtask1;
    
    namespace subtask2
    {
        LL f[N][2],rm,pm;
        int fs[N],m1,nt[2*N],dt[2*N];
        void link(int x,int y)
        {
            nt[++m1]=fs[x];
            dt[fs[x]=m1]=y;
        }
        void dp(int k,int fa)
        {
            f[k][0]=1,f[k][1]=0;
            for(int i=fs[k];i;i=nt[i])
            {
                int p=dt[i];
                if(p!=fa)
                {
                    dp(p,k);
                    f[k][1]=(f[k][1]*(f[p][0]+f[p][1])%mo+f[k][0]*f[p][1])%mo;
                    f[k][0]=f[k][0]*(f[p][0]+f[p][1])%mo;
                }
            }
            f[k][1]=(f[k][1]+f[k][0]*pm%mo*(LL)n)%mo;
        }
        void solve2()
        {
        	if(m==1) 
        	{
        		printf("%lld
    ",ksm(n,n-2));
        		return;
    		}
            fo(i,1,n-1)
            {
                int x,y;
                scanf("%d%d",&x,&y);
                link(x,y),link(y,x);
            }	
            rm=ksm(m,mo-2);
            pm=ksm(rm-1,mo-2);
            dp(1,0);
            printf("%lld
    ",f[1][1]*ksm(ksm(n,mo-2),2)%mo*ksm(rm-1,n)%mo*ksm(m,n)%mo);
        }
    }
    using namespace subtask2;
    
    namespace subtask3
    {
    	#define M 262144
    	LL wg[M+1],wi[M+1],a[M+1],b[M+1],js[M+1],ny[M+1],ns[M+1];
    	int bit[M+1],cf[19],l2[M+1];
    	void prp(int num)
    	{
    		fo(i,0,num)
    		{
    			wi[i]=wg[i*(M/num)];
    			bit[i]=(bit[i>>1]>>1)|((i&1)<<(l2[num]-1));
    		}
    	}
    	void NTT(LL *a,bool pd,int num)
    	{
    		LL v;
    		fo(i,0,num-1) if(i<bit[i]) swap(a[i],a[bit[i]]);
    		for(int h=1,m=2,l=num>>1;m<=num;h=m,m<<=1,l>>=1)
    		{
    			int c=(!pd)?l:-l;
    			for(int j=0;j<num;j+=m)
    			{
    				LL *x=a+j,*y=a+h+j,*w=(!pd)?wi:wi+num;
    				fo(i,0,h-1)
    				{
    					v=*y * *w%mo;
    					*y=(*x-v+mo)%mo,*x=(*x+v)%mo;
    					x++,y++,w+=c;
    				}	
    			}
    		}
    		if(pd) fo(i,0,num-1) a[i]=a[i]*ny[num]%mo;
    	}
    	void getinv(int n,LL *a,LL *b)
        {
        	static LL u1[M+1],u2[M+1];
        	fo(i,0,cf[l2[n]]-1) b[i]=0;
        	b[0]=ksm(a[0],mo-2);
        	for(int m=1,t=2,num=4;m<n;m=t,t=num,num<<=1)
        	{
        		prp(num);
        		fo(i,0,num-1) u1[i]=u2[i]=0;
    			fo(i,0,m-1) u1[i]=b[i];
        		fo(i,0,t-1) u2[i]=a[i];
        		NTT(u1,0,num),NTT(u2,0,num);
        		fo(i,0,num-1) u1[i]=u1[i]*u1[i]%mo*u2[i]%mo;
        		NTT(u1,1,num);
        		fo(i,0,t-1) b[i]=((LL)2*b[i]-u1[i]+mo)%mo;
    		}
    	}
    	void getln(int n,LL *a,LL *b)
    	{
    		static LL u1[M+1],u2[M+1];
    		int num=cf[l2[2*n+1]];
    		getinv(n+1,a,u1);
    		fo(i,n+1,num-1) u1[i]=0;
    		prp(num);
    		fo(i,0,n-1) u2[i]=a[i+1]*(LL)(i+1)%mo;
    		fo(i,n,num-1) u2[i]=0;
    		NTT(u1,0,num),NTT(u2,0,num);
    		fo(i,0,num-1) u1[i]=u1[i]*u2[i]%mo;
    		NTT(u1,1,num);
    		b[0]=0;
    		fo(i,1,n) b[i]=u1[i-1]*ny[i]%mo;
    	}
    	void getexp(int n,LL *a,LL *b)
    	{
    		static LL u1[M+1],u2[M+1];
    		b[0]=1;
    		fo(i,1,cf[l2[n]]) b[i]=0;
    		for(int m=1,t=2,num=4;m<=n;m=t,t=num,num<<=1)
    		{
    			getln(t-1,b,u1);
    			fo(i,0,t-1) u1[i]=(-u1[i]+a[i]+mo+mo)%mo;
    			fo(i,t,num-1) u1[i]=0;
     			u1[0]++;
    			fo(i,0,m-1) u2[i]=b[i];
    			fo(i,m,num-1) u2[i]=0;
    			prp(num); 
    			NTT(u1,0,num),NTT(u2,0,num);
    			fo(i,0,num-1) u1[i]=u1[i]*u2[i]%mo;
    			NTT(u1,1,num);
    			fo(i,0,t-1) b[i]=u1[i];
    		}
    	}
    	void solve3()
        {
        	if(m==1) {printf("%lld
    ",ksm(n,2*n-4));return;}
            wg[0]=1,wg[1]=ksm(3,(mo-1)/M);
            fo(i,2,M) wg[i]=wg[i-1]*wg[1]%mo;
            fo(i,0,18) l2[cf[i]=1<<i]=i;
            fod(i,M-1,2) if(!l2[i]) l2[i]=l2[i+1];
            
            js[0]=ns[0]=js[1]=ns[1]=ny[1]=1;
            fo(i,2,M) js[i]=js[i-1]*(LL)i%mo,ny[i]=(-ny[mo%i]*(LL)(mo/i)%mo+mo)%mo;
            fo(i,2,M) ns[i]=ns[i-1]*ny[i]%mo;
            
            LL rm=ksm(m,mo-2),pm=ksm(rm-1,mo-2);
    		fo(i,1,n) a[i]=pm*(LL)n%mo*(LL)n%mo*ksm(i,i)%mo*ns[i]%mo;
           	getexp(n,a,b);
    		printf("%lld
    ",b[n]*ksm(ksm(n,mo-2),4)%mo*ksm(rm-1,n)%mo*ksm(m,n)%mo*js[n]%mo);
    	}
    }
    using namespace subtask3;
    
    int main()
    {
        cin>>n>>m>>tp;
        if(tp==0) solve1();
        else if(tp==1) solve2();
        else solve3();
    }
    
  • 相关阅读:
    商标查询网
    java: jsp:param中文乱码
    java:maven中webapp下的jsp不能访问web-inf下面的bean
    java:类集回顾
    java:类集操作,多对多的关系
    java:类集操作总结
    java:练习学校学生
    php发邮件:swiftmailer, php邮件库——swiftmailer
    java:练习超市卖场
    phalcon: 按年分表的model怎么建?table2017,table2018...相同名的分表模型怎么建
  • 原文地址:https://www.cnblogs.com/BAJimH/p/10793299.html
Copyright © 2011-2022 走看看