zoukankan      html  css  js  c++  java
  • [UOJ86]mx的组合数——NTT+数位DP+原根与指标+卢卡斯定理

    题目链接:

    [UOJ86]mx的组合数

    题目大意:给出四个数$p,n,l,r$,对于$forall 0le ale p-1$,求$lle xle r,C_{x}^{n}\%p=a$的$x$的数量。$p<=3000$且保证$p$是质数,$n,l,r<=10^30$。

    对于$10\%$的数据,可以直接杨辉三角推。
    对于$20\%$的数据,因为$n$是确定的,可以递推出$C_{x+1}^{n}=C_{x}^{n}*frac{x+1}{x+1-n}$。
    对于另外$20\%$的数据,可以枚举$x$然后用$lucas$定理求。
    对于另外$30\%$的数据,可以想到将问题转化成小于等于$r$的个数$-$小于等于$l-1$的个数。由$lucas$定理可知,$C_{x}^{n} mod p=prod C_{b_{i}}^{a_{i}} mod p$,其中$a_{i},b_{i}$分别为$n,x$在$p$进制下的第$i$位。那么我们就可以用数位$DP$求,$f[i][j]$代表从最低为开始的前$i$位,每一位的值都不大于$b_{i}$且$\%p=j$的方案数;$g[i][j]$代表从最低为开始的前$i$位,每一位的值任意且$\%p=j$的方案数。设枚举第$i+1$位为$x$,$C_{x}^{a_{i+1}}=k$。那么可以得到$DP$转移方程$g[i+1][jk mod p]+=g[i][j]$,若$x<b_{i+1}$,则$f[i+1][jk mod p]+=g[i][j]$,若$x=b_{i+1}$,则$f[i+1][jk mod p]+=f[i][j]$。时间复杂度为$O(p^2log_{p})$。
    对于$100\%$的数据,我们考虑优化上述$DP$,我们拿其中第一个转移方程来说(后两个同理),我们设$h[k]=sumlimits_{x=0}^{p-1}[C_{x}^{a_{i+1}}==k]$。可以发现转移可以看成是$G[j*k mod p]=sumlimits_{j=0}^{p-1}g[j]sumlimits_{k=0}^{p-1}h[k]$,这和卷积式子很像,但他是乘法卷积,我们想办法将它变成加法卷积:因为$p$是质数,那么$p$一定有原根(设为$g$),也就是说对于任意$j$,其中$1le jle p-1$都有指标。我们设它的指标为$ind(j)$,那么$j*k mod p$就能转化为$g^{(ind(j)+ind(k)) mod (p-1)} mod p$。这样我们就能用$FFT$或$NTT$来加速$DP$了,但注意到$0$没有指标,我们在转移时先忽略$0$,在最后输出答案时用总个数减掉其他答案就是$\%p=0$的个数了。注意原根从$1$开始枚举。至于$10^{30}$可以用$\_\_int128$存。时间复杂度为$O(plog_{p}^2)$。

    两种写法,读者自选。

    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<cmath>
    #include<cstdio>
    #include<vector>
    #include<bitset>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #define ll long long
    typedef __int128 int128;
    #define MOD 998244353
    using namespace std;
    int p;
    int128 l,r,n;
    int pr[10];
    int cnt;
    int G;
    int mx;
    ll sum;
    int ind[30010];
    ll f[100000];
    ll g[100000];
    ll h[100000];
    int a[200];
    int b[200];
    ll ans[30010];
    int c[200][30010];
    int mask=1;
    ll s[100000];
    char *p1,*p2,buf[100000];
    #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
    int read_()
    {
    	int x=0; 
    	char c=nc(); 
    	while(c<48) 
    	{
    		c=nc(); 
    	}
    	while(c>47) 
    	{
    		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
    	}
    	return x;
    }
    int128 read() 
    {
    	int128 x=0; 
    	char c=nc(); 
    	while(c<48) 
    	{
    		c=nc(); 
    	}
    	while(c>47) 
    	{
    		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
    	}
    	return x;
    }
    ll quick(int x,int y,int mod)
    {
    	ll res=1ll;
    	while(y)
    	{
    		if(y&1)
    		{
    			res=res*x%mod;
    		}
    		y>>=1;
    		x=1ll*x*x%mod;
    	}
    	return res;
    }
    void NTT(ll *a,int len,int miku)
    {
    	for(int k=0,i=0;i<len;i++)
    	{
    		if(i>k)
    		{
    			swap(a[i],a[k]);
    		}
    		for(int j=len>>1;(k^=j)<j;j>>=1);
    	}
    	for(int k=2;k<=len;k<<=1)
    	{
    		int t=k>>1;
    		int x=quick(3,(MOD-1)/k,MOD);
    		if(miku==-1)
    		{
    			x=quick(x,MOD-2,MOD);
    		}
    		for(int i=0;i<len;i+=k)
    		{
    			ll w=1;
    			for(int j=i;j<i+t;j++)
    			{
    				ll tmp=a[j+t]*w%MOD;
    				a[j+t]=(a[j]-tmp+MOD)%MOD;
    				a[j]=(a[j]+tmp)%MOD;
    				w=w*x%MOD;
    			}
    		}
    	}
    	if(miku==-1)
    	{
    		for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++)
    		{
    			a[i]=a[i]*t%MOD;
    		}
    	}
    }
    void solve(int128 num)
    {
    	memset(f,0,sizeof(f));
    	memset(g,0,sizeof(g));
    	memset(h,0,sizeof(h));
    	memset(a,0,sizeof(a));
    	int res=0;
    	for(int i=1;num;i++)
    	{
    		a[i]=num%p;
    		num/=p;
    		res=max(res,i);
    	}
    	mx=max(res,mx);
    	g[0]=f[0]=1ll;
    	for(int k=1;k<=mx;k++)
    	{
    		memset(h,0,sizeof(h));
    		memset(s,0,sizeof(s));
    		NTT(g,mask,1);
    		NTT(f,mask,1);
    		if(a[k]>=b[k])
    		{
    			h[ind[c[k][a[k]]]]++;
    			NTT(h,mask,1);
    			for(int i=0;i<mask;i++)
    			{
    				s[i]+=1ll*h[i]*f[i]%MOD;
    				s[i]%=MOD;
    			}
    			NTT(h,mask,-1);
    			h[ind[c[k][a[k]]]]--;
    		}
    		for(int i=b[k];i<a[k];i++)
    		{
    			h[ind[c[k][i]]]++;
    		}
    		NTT(h,mask,1);
    		for(int i=0;i<mask;i++)
    		{
    			s[i]+=1ll*h[i]*g[i]%MOD;
    			s[i]%=MOD;
    		}
    		NTT(h,mask,-1);
    		NTT(s,mask,-1);
    		memset(f,0,sizeof(f));
    		for(int i=0;i<mask;i++)
    		{
    			f[i%(p-1)]+=s[i];
    			f[i%(p-1)]%=MOD;
    		}
    		for(int i=max(b[k],a[k]);i<p;i++)
    		{
    			h[ind[c[k][i]]]++;
    		}
    		NTT(h,mask,1);
    		for(int i=0;i<mask;i++)
    		{
    			s[i]=1ll*h[i]*g[i]%MOD;
    		}
    		NTT(s,mask,-1);
    		memset(g,0,sizeof(g));
    		for(int i=0;i<mask;i++)
    		{
    			g[i%(p-1)]+=s[i];
    			g[i%(p-1)]%=MOD;
    		}
    	}
    }
    int main()
    {
    	p=read_(),n=read(),l=read(),r=read();
    	l--;
    	int s=p-1;
    	while(mask<(p<<1))
    	{
    		mask<<=1;
    	}
    	for(int i=2;i*i<=s;i++)
    	{
    		if(s%i==0)
    		{
    			pr[++cnt]=i;
    			while(s%i==0)
    			{
    				s/=i;
    			}
    		}
    	}
    	if(s!=1)
    	{
    		pr[++cnt]=s;
    	}
    	for(int i=1;i<p;i++)
    	{
    		bool flag=true;
    		for(int j=1;j<=cnt;j++)
    		{
    			if(quick(i,(p-1)/pr[j],p)==1)
    			{
    				flag=false;
    				break;
    			}
    		}
    		if(flag)
    		{
    			G=i;
    			break;
    		}
    	}
    	sum=1ll;
    	for(int i=0;i<p-1;i++)
    	{
    		ind[sum]=i;
    		sum*=G,sum%=p;
    	}
    	int128 N=n;
    	for(int i=1;N;i++)
    	{
    		b[i]=N%p;
    		N/=p;
    		mx=max(mx,i);
    	}
    	for(int i=1;i<=mx;i++)
    	{
    		for(int j=0;j<b[i];j++)
    		{
    			c[i][j]=0;
    		}
    		sum=1ll;
    		for(int j=b[i];j<p;j++)
    		{
    			c[i][j]=sum;
    			sum*=(j+1),sum%=p;
    			sum*=quick(j+1-b[i],p-2,p),sum%=p;
    		}
    	}
    	solve(l);
    	for(int i=0;i<p-1;i++)
    	{
    		ans[quick(G,i,p)]-=f[i];
    	}
    	for(int i=1;i<=p-1;i++)
    	{
    		ans[i]=(ans[i]%MOD+MOD)%MOD;
    	}
    	solve(r);
    	for(int i=0;i<p-1;i++)
    	{
    		ans[quick(G,i,p)]+=f[i];
    	}
    	for(int i=1;i<=p-1;i++)
    	{
    		ans[i]%=MOD;
    	}
    	ans[0]=(r-l)%MOD;
    	for(int i=1;i<p;i++)
    	{
    		ans[0]-=ans[i];
    		ans[0]=(ans[0]%MOD+MOD)%MOD;
    	}
    	for(int i=0;i<p;i++)
    	{
    		printf("%lld
    ",ans[i]);
    	}
    }
    
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<cmath>
    #include<cstdio>
    #include<vector>
    #include<bitset>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #define ll long long
    typedef __int128 int128;
    #define MOD 998244353
    using namespace std;
    int p;
    int128 l,r,n;
    int pr[10];
    int cnt;
    int G;
    int mx;
    ll sum;
    int ind[30010];
    ll f[100000];
    ll g[100000];
    ll A[100000];
    ll B[100000];
    ll C[100000];
    int a[200];
    int b[200];
    ll ans[30010];
    int c[200][30010];
    int mask=1;
    int s[100000];
    int pw[300010];
    int fac[300010];
    int inv[300010];
    char *p1,*p2,buf[100000];
    #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
    int read_()
    {
    	int x=0; 
    	char c=nc(); 
    	while(c<48) 
    	{
    		c=nc(); 
    	}
    	while(c>47) 
    	{
    		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
    	}
    	return x;
    }
    int128 read() 
    {
    	int128 x=0; 
    	char c=nc(); 
    	while(c<48) 
    	{
    		c=nc(); 
    	}
    	while(c>47) 
    	{
    		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
    	}
    	return x;
    }
    ll quick(int x,int y,int mod)
    {
    	ll res=1ll;
    	while(y)
    	{
    		if(y&1)
    		{
    			res=res*x%mod;
    		}
    		y>>=1;
    		x=1ll*x*x%mod;
    	}
    	return res;
    }
    void NTT(ll *a,int len,int miku)
    {
    	for(int k=0,i=0;i<len;i++)
    	{
    		if(i>k)
    		{
    			swap(a[i],a[k]);
    		}
    		for(int j=len>>1;(k^=j)<j;j>>=1);
    	}
    	for(int k=2;k<=len;k<<=1)
    	{
    		int t=k>>1;
    		int x=quick(3,(MOD-1)/k,MOD);
    		if(miku==-1)
    		{
    			x=quick(x,MOD-2,MOD);
    		}
    		for(int i=0;i<len;i+=k)
    		{
    			ll w=1;
    			for(int j=i;j<i+t;j++)
    			{
    				ll tmp=a[j+t]*w%MOD;
    				a[j+t]=(a[j]-tmp+MOD)%MOD;
    				a[j]=(a[j]+tmp)%MOD;
    				w=w*x%MOD;
    			}
    		}
    	}
    	if(miku==-1)
    	{
    		for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++)
    		{
    			a[i]=a[i]*t%MOD;
    		}
    	}
    }
    void solve(int128 num)
    {
    	memset(f,0,sizeof(f));
    	memset(g,0,sizeof(g));
    	memset(a,0,sizeof(a));
    	int res=0;
    	for(int i=1;num;i++)
    	{
    		a[i]=num%p;
    		num/=p;
    		res=max(res,i);
    	}
    	mx=max(res,mx);
    	g[1]=f[1]=1ll;
    	for(int k=1;k<=mx;k++)
    	{
    		memset(A,0,sizeof(A));
    		memset(B,0,sizeof(B));
    		for(int i=b[k];i<p;i++)
    		{
    			if(c[k][i])
    			{
    				A[ind[c[k][i]]]++;
    			}
    		}
    		for(int i=1;i<p;i++)
    		{	
    			B[ind[i]]+=g[i];
    			B[ind[i]]%=MOD;
    		}
    		NTT(A,mask,1);
    		NTT(B,mask,1);
    		for(int i=0;i<mask;i++)
    		{
    			C[i]=A[i]*B[i]%MOD;
    		}
    		NTT(C,mask,-1);
    		memset(g,0,sizeof(g));
    		for(int i=0;i<mask;i++)
    		{
    			(g[quick(G,i%(p-1),p)]+=C[i])%=MOD;
    		}
    		memset(A,0,sizeof(A));
    		for(int i=b[k];i<a[k];i++)
    		{
    			if(c[k][i])
    			{
    				A[ind[c[k][i]]]++;
    			}
    		}
    		NTT(A,mask,1);
    		for(int i=0;i<mask;i++)
    		{
    			C[i]=A[i]*B[i]%MOD;
    		}
    		NTT(C,mask,-1);
    		memset(s,0,sizeof(s));
    		for(int i=0;i<mask;i++)
    		{
    			(s[quick(G,i%(p-1),p)]+=C[i])%=MOD;
    		}
    		if(c[k][a[k]])
    		{
    			for(int i=1;i<p;i++)
    			{
    				(s[c[k][a[k]]*i%p]+=f[i])%=MOD;;
    			}
    		}
    		for(int i=1;i<p;i++)
    		{
    			f[i]=s[i];
    		}
    	}
    }
    int get_ori(int p)
    {
    	int s=p-1;
    	for(int i=2;i*i<=s;i++)
    	{
    		if(s%i==0)
    		{
    			pr[++cnt]=i;
    			while(s%i==0)
    			{
    				s/=i;
    			}
    		}
    	}
    	if(s!=1)
    	{
    		pr[++cnt]=s;
    	}
    	for(int i=1;i<p;i++)
    	{
    		bool flag=true;
    		for(int j=1;j<=cnt;j++)
    		{
    			if(quick(i,(p-1)/pr[j],p)==1)
    			{
    				flag=false;
    				break;
    			}
    		}
    		if(flag)
    		{
    			return i;
    			break;
    		}
    	}
    }
    int main()
    {
    	p=read_(),n=read(),l=read(),r=read();
    	while(mask<(p<<1))
    	{
    		mask<<=1;
    	}
    	G=get_ori(p);
    	pw[0]=1ll;
    	for(int i=1;i<p;i++)
    	{
    		pw[i]=pw[i-1]*G%p;
    	}
    	sum=1ll;
    	for(int i=0;i<p-1;i++)
    	{
    		ind[sum]=i;
    		sum*=G,sum%=p;
    	}
    	int128 N=n;
    	for(int i=1;N;i++)
    	{
    		b[i]=N%p;
    		N/=p;
    		mx=max(mx,i);
    	}
    	fac[0]=inv[0]=1ll;
    	for(int i=1;i<p;i++)
    	{
    		fac[i]=fac[i-1]*i%p;
    	}
    	inv[p-1]=quick(fac[p-1],p-2,p);
    	for(int i=p-2;i>=1;i--)
    	{
    		inv[i]=inv[i+1]*(i+1)%p;
    	}
    	for(int i=1;i<=120;i++)
    	{
    		for(int j=b[i];j<p;j++)
    		{
    			c[i][j]=fac[j]*inv[j-b[i]]%p*inv[b[i]]%p;
    		}
    	}
    	solve(r);
    	for(int i=1;i<p;i++)
    	{
    		ans[i]=f[i];
    	}
    	solve(l-1);
    	for(int i=1;i<p;i++)
    	{
    		ans[i]=((ans[i]-f[i])%MOD+MOD)%MOD;
    	}
    	ans[0]=(r-l+1)%MOD;
    	for(int i=1;i<p;i++)
    	{
    		ans[0]=((ans[0]-ans[i])%MOD+MOD)%MOD;
    	}
    	for(int i=0;i<p;i++)
    	{
    		printf("%lld
    ",ans[i]);
    	}
    }
  • 相关阅读:
    mybatis批量插入数据
    oracle的dmp数据文件的导出和导入以及创建用户
    maven安装第三方jar包到本地仓库
    IntelliJ IDEA 注册码,激活
    分布式事务实现-Spanner
    Redis Cluster原理
    twemproxy源码分析
    Paxos可容错的一致性协议
    UpdateServer事务实现机制
    Coroutine及其实现
  • 原文地址:https://www.cnblogs.com/Khada-Jhin/p/10460022.html
Copyright © 2011-2022 走看看