zoukankan      html  css  js  c++  java
  • Codeforces 438E. The Child and Binary Tree 多项式,FFT

    原文链接www.cnblogs.com/zhouzhendong/p/CF438E.html

    前言

    没做过多项式题,来一道入门题试试刀。

    题解

    设 $a_i$ 表示节点权值和为 $i$ 的二叉树个数,特别的,我们定义 $a_0 = 1$ ,即我们认为没有节点也算一种二叉树。

    $$g(x) = sum_{i=1}^n x^{c_i}\f(x) = sum_{i=0}^{infty} a_i x^i$$

    根据组合意义可得

    $$f^2(x) g(x) + 1 = f(x) $$

    于是

    $$f^2(x) g(x) - f(x) + 1 = 0$$

    注意到 $g(0) = 0$ ,所以当 $x = 0$ 时, $f(x) = 1$ 。

    直接用求根公式得到

    $$f(x) = frac{1 pm sqrt{1 - 4g(x)}}{2g(x)}$$

    由于 $g(0) = 0$ ,所以 $g(x)$ 不存在逆元,所以 $g(x)$ 在分母上会不舒服,于是我们对式子操作一波:

    $$f(x) = frac{1 - (1 - 4g(x))}{2g(x) (1pm sqrt{1-4g(x)})}\ = frac 2 {1 pm sqrt {1-4g(x)}}$$

    接下来我们看看这个 $pm$ 到底应该是正的还是负的。

    由于 $f(0) = 1$ ,所以 $frac{2}{1pmsqrt{1-4g(0)}} = 1$ ,于是得到这里为正号。

    于是

    $$frac 2 {1+sqrt {1-4g(x)}}$$

    于是用多项式求逆和多项式开根即可解决此问题。

    时间复杂度 $O(mlog m)$ 。

    代码

    #include <bits/stdc++.h>
    #define clr(x) memset(x,0,sizeof (x))
    #define clrint(x,n) memset(x,0,(n)<<2)
    #define cpyint(a,b,n) memcpy(a,b,(n)<<2)
    #define For(i,a,b) for (int i=a;i<=b;i++)
    #define Fod(i,b,a) for (int i=b;i>=a;i--)
    #define pb(x) push_back(x)
    #define mp(x,y) make_pair(x,y)
    #define fi first
    #define se second
    #define real __zzd001
    #define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I')
    #define outval(x) printf(#x" = %d
    ",x)
    #define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("")
    #define outtag(x) puts("----------"#x"----------")
    #define outarr(a,L,R) printf(#a"[%d...%d] = ",L,R);
    						For(_v2,L,R)printf("%d ",a[_v2]);puts("");
    using namespace std;
    typedef long long LL;
    typedef unsigned long long ULL;
    typedef vector <int> vi;
    LL read(){
    	LL x=0,f=0;
    	char ch=getchar();
    	while (!isdigit(ch))
    		f|=ch=='-',ch=getchar();
    	while (isdigit(ch))
    		x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    	return f?-x:x;
    }
    const int N=1<<19,mod=998244353,inv2=(mod+1)>>1;
    const int YG=3;
    int Pow(int x,int y){
    	int ans=1;
    	for (;y;y>>=1,x=(LL)x*x%mod)
    		if (y&1)
    			ans=(LL)ans*x%mod;
    	return ans;
    }
    void Add(int &x,int y){
    	if ((x+=y)>=mod)
    		x-=mod;
    }
    void Del(int &x,int y){
    	if ((x-=y)<0)
    		x+=mod;
    }
    int Add(int x){
    	return x>=mod?x-mod:x;
    }
    int Del(int x){
    	return x<0?x+mod:x;
    }
    namespace Math{
    	int Iv[N];
    	void prework(){
    		int n=N-1;
    		Iv[1]=1;
    		For(i,2,n)
    			Iv[i]=(LL)(mod-mod/i)*Iv[mod%i]%mod;
    	}
    	map <int,int> Map;
    	int ind(int x){
    		static int M,bas;
    		if (Map.empty()){
    			M=max((int)sqrt(mod),1);
    			bas=Pow(YG,M);
    			for (int i=1,v=YG;i<=M;i++,v=(LL)v*YG%mod)
    				Map[v]=i;
    		}
    		for (int i=M,v=(LL)bas*Pow(x,mod-2)%mod;i<=mod-1+M;i+=M,v=(LL)v*bas%mod)
    			if (Map[v])
    				return i-Map[v];
    		return -1;
    	}
    }
    namespace fft{
    	int w[N],R[N];
    	int Log[N+1];
    	void init(int n){
    		if (!Log[2]){
    			For(i,2,N)
    				Log[i]=Log[i>>1]+1;
    		}
    		int d=Log[n];
    		assert(n==(1<<d));
    		For(i,0,n-1)
    			R[i]=(R[i>>1]>>1)|((i&1)<<(d-1));
    		w[0]=1,w[1]=Pow(YG,(mod-1)/n);
    		For(i,2,n-1)
    			w[i]=(LL)w[i-1]*w[1]%mod;
    	}
    	void FFT(int *a,int n,int flag){
    		if (flag<0)
    			reverse(w+1,w+n);
    		For(i,0,n-1)
    			if (i<R[i])
    				swap(a[i],a[R[i]]);
    		for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
    			for (int i=0;i<n;i+=d<<1)
    				for (int j=0;j<d;j++){
    					int tmp=(LL)w[t*j]*a[i+j+d]%mod;
    					a[i+j+d]=Del(a[i+j]-tmp);
    					Add(a[i+j],tmp);
    				}
    		if (flag<0){
    			reverse(w+1,w+n);
    			int inv=Pow(n,mod-2);
    			For(i,0,n-1)
    				a[i]=(LL)a[i]*inv%mod;
    		}
    	}
    	void CirMul(int *a,int *b,int *c,int n){
    		init(n),FFT(a,n,1),FFT(b,n,1);
    		For(i,0,n-1)
    			c[i]=(LL)a[i]*b[i]%mod;
    		FFT(c,n,-1);
    	}
    }
    using fft::FFT;
    using fft::CirMul;
    int calc_up(int x){
    	int n=1;
    	while (n<=x)
    		n<<=1;
    	return n;
    }
    void Inv(int *a,int *b,int n){
    	static int f[N],g[N];
    	b[0]=Pow(a[0],mod-2);
    	int now=1;
    	while (now<n){
    		int len=now<<2;
    		For(i,0,len-1)
    			f[i]=g[i]=0;
    		cpyint(g,b,now),now<<=1,cpyint(f,a,min(n,now));
    		fft::init(len);
    		FFT(f,len,1),FFT(g,len,1);
    		For(i,0,len-1)
    			g[i]=(2LL*g[i]-(LL)f[i]*g[i]%mod*g[i]%mod+mod)%mod;
    		FFT(g,len,-1);
    		cpyint(b,g,min(n,now));
    	}
    }
    int Sqrt(int a){
    	int k=Math::ind(a);
    	assert(~k&1);
    	k=Pow(YG,k>>1);
    	return min(k,mod-k);
    }
    void Sqrt(int *a,int *b,int n){
    	static int f[N],g[N],h[N];
    	b[0]=Sqrt(a[0]);
    	int now=1;
    	while (now<n){
    		int len=now<<2;
    		For(i,0,len-1)
    			f[i]=g[i]=h[i]=0;
    		cpyint(f,b,now),now<<=1,Inv(f,h,now),cpyint(g,a,min(n,now));
    		CirMul(g,h,g,len);
    		For(i,0,len-1)
    			f[i]=((g[i]+f[i])&1)?Add(((LL)g[i]+f[i]+mod)>>1):((g[i]+f[i])>>1);
    		cpyint(b,f,min(n,now));
    	}
    }
    void Der(int *a,int n){
    	For(i,0,n-2)
    		a[i]=(LL)a[i+1]*(i+1)%mod;
    	a[n-1]=0;
    }
    void Int(int *a,int n){
    	if (!Math::Iv[1])
    		Math::prework();
    	Fod(i,n,1)
    		a[i]=(LL)a[i-1]*Math::Iv[i]%mod;
    	a[0]=0;
    }
    void Ln(int *a,int *b,int n){
    	static int f[N],g[N];
    	int len=calc_up(n*2);
    	For(i,0,len-1)
    		f[i]=g[i]=0;
    	cpyint(f,a,n),Inv(f,g,n),Der(f,n);
    	CirMul(f,g,f,len);
    	Int(f,n),cpyint(b,f,n);
    }
    void Exp(int *a,int *b,int n){
    	static int f[N],g[N],h[N];
    	b[0]=1;
    	int now=1;
    	while (now<n){
    		int len=now<<2;
    		For(i,0,len-1)
    			f[i]=g[i]=h[i]=0;
    		cpyint(f,b,now),now<<=1,Ln(f,g,now),cpyint(h,a,min(n,now));
    		For(i,0,now-1)
    			g[i]=Del(h[i]-g[i]);
    		Add(g[0],1);
    		CirMul(f,g,f,len),cpyint(b,f,min(n,now));
    	}
    }
    void Pow(int *a,int *b,int n,int k){
    	static int f[N];
    	clrint(b,n);
    	if (k==0)
    		return (void)(b[0]=1);
    	int fir=0;
    	for (;fir<n&&!a[fir];fir++);
    	if ((LL)fir*k>=n)
    		return;
    	int m=n-fir*k;
    	cpyint(f,a+fir,m);
    	int t=Pow(f[0],k),it=Pow(f[0],mod-2);
    	For(i,0,m-1)
    		f[i]=(LL)f[i]*it%mod;
    	Ln(f,f,m);
    	For(i,0,m-1)
    		f[i]=(LL)f[i]*k%mod;
    	Exp(f,b+fir*k,m);
    	For(i,fir*k,n-1)
    		b[i]=(LL)b[i]*t%mod;
    }
    int n,m;
    int f[N],g[N];
    int main(){
    	n=read(),m=read()+1;
    	while (n--)
    		g[read()]=mod-4;
    	Add(g[0],1);
    	Sqrt(g,f,m);
    	Add(f[0],1);
    	Inv(f,g,m);
    	For(i,1,m-1)
    		printf("%d
    ",g[i]=g[i]*2%mod);
    	return 0;
    }
    

      

  • 相关阅读:
    LeetCode 50: Pow(x, n)
    Codility 1: equilibrium
    LeetCode 49: Anagrams
    crackme160--(5)
    crackme160--(4)
    魔兽显血改键工具的脱壳和修改
    crackme160--(3)
    crackme160--(1)
    python-装饰器
    编程语言C++01-C++入门学习路线
  • 原文地址:https://www.cnblogs.com/zhouzhendong/p/CF438E.html
Copyright © 2011-2022 走看看