zoukankan      html  css  js  c++  java
  • 【XSY2666】排列问题 DP 容斥原理 分治FFT

    题目大意

      有(n)种颜色的球,第(i)种有(a_i)个。设(m=sum a_i)。你要把这(m)个小球排成一排。有(q)个询问,每次给你一个(x),问你有多少种方案使得相邻的小球同色的对数为(x)

      (nleq 10000,mleq 200000)

    题解

      我们考虑把这些小球分段,每段内所有小球颜色相同,但相邻两段的小球颜色可以相同。

      设第(i)种颜色有(b_i)段,那么分(j)段的方案数是(frac{(sum b_i)!}{sum(bi!)}=frac{j!}{sum(bi!)})

      那么先DP,设(f_{i,j})为前(i)种颜色,分了(j)段的方案数(div b_i!)显然枚举第(i)中颜色分(k)段得

    [f_{i,j}+=f_{i-1,j-k} imes inom{a_i-1}{k-1} imesfrac{1}{k!} ]

      那个组合数是插板法得到的。

      这个DP的时间复杂度是(O(m^2))(因为枚举第(i)种颜色时(k=1ldots a_i,j=1ldots s_i)(s)(a)的前缀和))

      然后这个东西可以分治FFT优化到(O(mlog mlog n))

      这样我们得到了分成(i)段的方案数(g_i=f_{n,i} imes i!),但相邻两段可能颜色相同。我们还要减掉这种情况。

      就是对于一种实际上分成 (j) 段的方案,它在分成 (i) 段的方案数中会被计算 (inom{m-j}{m-i}) 次(就是在 (m-j) 个间隔中取 (m-i) 个)。

      答案 (ans_i=g_i-sum_{j<i}ans_jinom{m-j}{i-j})

      可以简单暴力的通过分治FFT优化到(O(mlog^2 m))。但有更好的做法。

      考虑容斥。其实总的(g_j)(ans_i)的贡献就是({(-1)}^{i-j}inom{m-j}{i-j})。直接FFT一次就可以得到答案。

    [egin{align} ans_{k->i}&=sum_{j=k}^{i-1}{(-1)^{j-k}}inom{m-k}{j-k}inom{m-j}{i-j}\ &=sum_{j=k}^{i-1}{(-1)^{j-k}}frac{(m-k)!(m-j)!}{(j-k)!(m-j)!(i-j)!(m-i)!}\ &=sum_{j=k}^{i-1}{(-1)^{j-k}}frac{(m-k)!}{(j-k)!(i-j)!(m-i)!}\ &=frac{(m-k)!}{(m-i)!(i-k)!}sum_{j=k}^{i-1}{(-1)^{j-k}}frac{(i-k)!}{(i-j)!(j-k)!}\ &=inom{m-k}{i-k}sum_{j=k}^{i-1}{(-1)^{j-k}}inom{i-k}{j-k}\ &=inom{m-k}{i-k}{(-1)}^{i-k} end{align} ]

      那么相邻的小球同色的对数为(x)的答案就是(ans_{m-x})

      时间复杂度:(O(mlog mlog n+q))

    代码

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cstdlib>
    #include<ctime>
    #include<utility>
    #include<cmath>
    #include<functional>
    #include<vector>
    #include<queue>
    using namespace std;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int,int> pii;
    typedef pair<ll,ll> pll;
    void sort(int &a,int &b)
    {
    	if(a>b)
    		swap(a,b);
    }
    void open(const char *s)
    {
    #ifndef ONLINE_JUDGE
    	char str[100];
    	sprintf(str,"%s.in",s);
    	freopen(str,"r",stdin);
    	sprintf(str,"%s.out",s);
    	freopen(str,"w",stdout);
    #endif
    }
    int rd()
    {
    	int s=0,c;
    	while((c=getchar())<'0'||c>'9');
    	do
    	{
    		s=s*10+c-'0';
    	}
    	while((c=getchar())>='0'&&c<='9');
    	return s;
    }
    void put(int x)
    {
    	if(!x)
    	{
    		putchar('0');
    		return;
    	}
    	static int c[20];
    	int t=0;
    	while(x)
    	{
    		c[++t]=x%10;
    		x/=10;
    	}
    	while(t)
    		putchar(c[t--]+'0');
    }
    int upmin(int &a,int b)
    {
    	if(b<a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    int upmax(int &a,int b)
    {
    	if(b>a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    const int p=998244353;
    int fp(int a,int b)
    {
    	int s=1;
    	for(;b;b>>=1,a=1ll*a*a%p)
    		if(b&1)
    			s=1ll*s*a%p;
    	return s;
    }
    int inv[600010];
    int fac[600010];
    int ifac[600010];
    namespace ntt
    {
    	const int g=3;
    	int rev[600010];
    	int w1[600010];
    	int w2[600010];
    	int n;
    	void init(int m)
    	{
    		n=1;
    		while(n<=m)
    			n<<=1;
    		int i;
    		rev[0]=0;
    		for(i=1;i<n;i++)
    			rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
    		for(i=1;i<=n;i<<=1)
    		{
    			w1[i]=fp(g,(p-1)/i);
    			w2[i]=fp(w1[i],p-2);
    		}
    	}
    	void ntt(int *a,int t)
    	{
    		int i,j,k;
    		int u,v,w,wn;
    		for(i=0;i<n;i++)
    			if(rev[i]<i)
    				swap(a[i],a[rev[i]]);
    		for(i=2;i<=n;i<<=1)
    		{
    			wn=(t==1?w1[i]:w2[i]);
    			for(j=0;j<n;j+=i)
    			{
    				w=1;
    				for(k=j;k<j+i/2;k++)
    				{
    					u=a[k];
    					v=1ll*a[k+i/2]*w%p;
    					a[k]=(u+v)%p;
    					a[k+i/2]=(u-v)%p;
    					w=1ll*w*wn%p;
    				}
    			}
    		}
    		if(t==-1)
    		{
    			int inv=fp(n,p-2);
    			for(i=0;i<n;i++)
    				a[i]=1ll*a[i]*inv%p;
    		}
    	}
    };
    int g[600010];
    int h[600010];
    int ans[600010];
    int a[600010];
    int s[600010];
    int n,m;
    void add(int &a,int b)
    {
    	a=(a+b)%p;
    }
    typedef vector<int> vec;
    vec mul(vec &a,vec &b)
    {
    	static int c[600010],d[600010];
    	int n1=a.size()-1;
    	int n2=b.size()-1;
    	int m=n1+n2+1;
    	ntt::init(m);
    	int i;
    	for(i=0;i<=n1;i++)
    		c[i]=a[i];
    	for(i=n1+1;i<ntt::n;i++)
    		c[i]=0;
    	for(i=0;i<=n2;i++)
    		d[i]=b[i];
    	for(i=n2+1;i<ntt::n;i++)
    		d[i]=0;
    	ntt::ntt(c,1);
    	ntt::ntt(d,1);
    	for(i=0;i<ntt::n;i++)
    		c[i]=1ll*c[i]*d[i]%p;
    	ntt::ntt(c,-1);
    	vec s(n1+n2+1);
    	for(i=1;i<=n1+n2;i++)
    		s[i]=c[i];
    	return s;
    }
    vec solve(int l,int r)
    {
    	if(l==r)
    	{
    		vec s(a[l]+1);
    		int i;
    		for(i=1;i<=a[l];i++)
    			s[i]=1ll*ifac[i-1]*ifac[i]%p*ifac[a[l]-i]%p;
    		return s;
    	}
    	int mid=(l+r)>>1;
    	vec s1=solve(l,mid);
    	vec s2=solve(mid+1,r);
    	return mul(s1,s2);
    }
    int c[600010];
    int d[600010];
    priority_queue<pii,vector<pii>,greater<pii> > q;
    void gao()
    {
    	int i;
    	c[0]=0;
    	for(i=1;i<=m;i++)
    		c[i]=g[i];
    	for(i=0;i<=m;i++)
    	{
    		d[i]=ifac[i];
    		if(i&1)
    			d[i]=-d[i];
    	}
    	ntt::init(2*m);
    	for(i=m+1;i<ntt::n;i++)
    		c[i]=d[i]=0;
    	ntt::ntt(c,1);
    	ntt::ntt(d,1);
    	for(i=0;i<ntt::n;i++)
    		c[i]=1ll*c[i]*d[i]%p;
    	ntt::ntt(c,-1);
    	for(i=1;i<=m;i++)
    		g[i]=c[i];
    }
    int t=0;
    vec f[20010];
    int main()
    {
    	open("c");
    	scanf("%d",&n);
    	int i;
    	for(i=1;i<=n;i++)
    	{
    		scanf("%d",&a[i]);
    		s[i]=s[i-1]+a[i];
    	}
    	m=s[n];
    	inv[0]=inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
    	for(i=2;i<=m;i++)
    	{
    		inv[i]=-1ll*p/i*inv[p%i]%p;
    #ifndef ONLINE_JUDGE
    		inv[i]=(inv[i]+p)%p;
    #endif
    		fac[i]=1ll*fac[i-1]*i%p;
    		ifac[i]=1ll*ifac[i-1]*inv[i]%p;
    	}
    //	f[0][0]=1;
    	int times=1;
    	for(i=1;i<=n;i++)
    		times=1ll*times*fac[a[i]-1]%p;
    //	for(i=1;i<=n;i++)
    //	{
    //		times=times*fac[a[i]-1]%p;
    //		for(j=1;j<=s[i];j++)
    //		{
    //			for(k=1;k<=a[i]&&k<=j;k++)
    //				add(f[i][j],f[i-1][j-k]*ifac[k-1]%p*ifac[a[i]-k]%p*ifac[k]%p);
    ////				add(f[i][j],f[i-1][j-k]*c(a[i]-1,k-1)%p*ifac[k]%p);
    ////			f[i][j]=f[i][j]*fac[a[i]-1]%p;
    //		}
    //	}
    	int j;
    	for(i=1;i<=n;i++)
    	{
    		f[i].resize(a[i]+1);
    		for(j=1;j<=a[i];j++)
    			f[i][j]=1ll*ifac[j-1]*ifac[j]%p*ifac[a[i]-j]%p;
    		q.push(pii(a[i],i));
    	}
    	t=n;
    	for(i=1;i<n;i++)
    	{
    		int n1=q.top().first;
    		int x=q.top().second;
    		q.pop();
    		int n2=q.top().first;
    		int y=q.top().second;
    		q.pop();
    		f[++t]=mul(f[x],f[y]);
    		f[x].clear();
    		f[y].clear();
    		q.push(pii(n1+n2+1,t));
    	}
    	vec ss=f[t];
    //	vec ss=solve(1,n);
    	for(i=1;i<=m;i++)
    		g[i]=1ll*ss[i]*fac[i]%p*times%p;
    #ifndef ONLINE_JUDGE
    	for(i=1;i<=m;i++)
    		add(g[i],p);
    #endif
    //		g[i]=f[n][i]*fac[i]%p*times%p;	
    	for(i=1;i<=m;i++)
    		g[i]=1ll*g[i]*fac[m-i]%p;
    	gao();
    	for(i=1;i<=m;i++)
    	{
    		g[i]=1ll*g[i]*ifac[m-i]%p;
    		add(g[i],p);
    	}
    //	for(i=1;i<=m;i++)
    //	{
    //		for(j=1;j<i;j++)
    //			add(ans[i],h[j]%p*ifac[i-j]%p);
    //		ans[i]=-ans[i]*ifac[m-i]%p;
    //		ans[i]=(ans[i]+g[i])%p;
    //			add(ans[i],-ans[j]*c(m-j,i-j));
    //		add(ans[i],p);
    //		h[i]=ans[i]*fac[m-i]%p;
    //	}
    	int q;
    	int x;
    	scanf("%d",&q);
    	while(q--)
    	{
    		scanf("%d",&x);
    		printf("%lld
    ",g[m-x]);
    	}
    	return 0;
    }
    
  • 相关阅读:
    isinstance函数
    Django之ORM那些相关操作
    Django ~ 2
    Django ~ 1
    Django详解之models操作
    Django模板语言相关内容
    livevent的几个问题
    客户端,服务器发包走向
    关闭客户端连接的两种情况
    std::vector<Channel2*> m_allChannels;容器,以及如何根据channelid的意义
  • 原文地址:https://www.cnblogs.com/ywwyww/p/8513349.html
Copyright © 2011-2022 走看看