在做本题之前,你需要一个预备知识:任意模数NTT
如果不会,请看这里
(其实那个不是真正的任意模数NTT,而是一种奇技淫巧,但是...能用就行!)
然后我们来讨论本题
首先不难发现,后来的一个数$A$的二进制表示一定至少有一位上是$1$,且原来的数上这一位都是$0$
这是很显然的,否则无法满足$B$序列单调递增
那么我们设$dp[i][j]$表示前$i$个数中已经放下了$j$个1(暂时不考虑位置),那么答案即为$sum_{i=1}^{k}C_{k}^{i}dp[n][i]$
接下来我们考虑转移:
$dp[i][j]=sum_{k=0}^{j-1}dp[i-1][k]*2^{k}*C_{j}^{k}$
这个转移表示现在认为这$j$个1中有$k$个是原来已有的,原来已有的$1$可放可不放,剩下的$1$必须放
这个转移是$O(n^3)$,显然不够优秀
考虑优化:拆开组合数之后,可得:
$dp[i][j]=sum_{k=0}^{j-1}dp[i-1][k]*2^{k}*frac{j!}{k!(j-k)!}$
移项,整理:
$frac{dp[i][j]}{j!}=sum_{k=0}^{j-1}frac{dp[i-1][k]*2^{k}}{k!}*frac{1}{(j-k)!}$
这是一个卷积的形式,用上面的技术可以优化成$O(n^{2}log_{2}n)$
能不能再优化呢?
再推一下,可以发现:
$dp[x+y][j]=sum_{k=0}^{j-1}dp[x][k]*dp[y][j-k]*C_{j}^{k}*2^{y*k}$
这就可以快速幂了
时间复杂度$O(nlog_{2}^{2}n)$
贴代码:
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> #define ll long long #define ld long double using namespace std; const ld pi=acos(-1.0); const int siz=(1<<19)+5; ll n,m,p=1000000007; ll inv[1000005]; ll minv[1000005]; ll mul[1000005]; void init() { inv[0]=inv[1]=minv[0]=minv[1]=mul[0]=mul[1]=1; for(int i=2;i<=1000000;i++) { inv[i]=(p-p/i)*inv[p%i]%p; mul[i]=mul[i-1]*i%p; minv[i]=minv[i-1]*inv[i]%p; } } const ll M=32767; struct cp { ld x,y; }; int to[siz]; int lim=1,l; ll A[siz],B[siz]; cp a[siz],b[siz],c[siz],d[siz],e[siz],f[siz],g[siz],h[siz]; ll ret[siz]; ll Aa[siz],Bb[siz]; cp operator + (cp a,cp b) { return (cp){a.x+b.x,a.y+b.y}; } cp operator - (cp a,cp b) { return (cp){a.x-b.x,a.y-b.y}; } cp operator * (cp a,cp b) { return (cp){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x}; } void FFT(cp *a,int len,int k) { for(int i=0;i<len;i++)if(i<to[i])swap(a[i],a[to[i]]); for(int i=1;i<len;i<<=1) { cp w0=(cp){cos(pi/i),k*sin(pi/i)}; for(int j=0;j<len;j+=(i<<1)) { cp w=(cp){1,0}; for(int o=0;o<i;o++,w=w*w0) { cp w1=a[j+o],w2=a[j+o+i]*w; a[j+o]=w1+w2,a[j+o+i]=w1-w2; } } } } void MTT() { memset(a,0,sizeof(a)),memset(b,0,sizeof(b)),memset(c,0,sizeof(c)),memset(d,0,sizeof(d)); memset(e,0,sizeof(e)),memset(f,0,sizeof(f)),memset(g,0,sizeof(g)),memset(h,0,sizeof(h)); for(int i=0;i<=m;i++)a[i].x=A[i]/M,b[i].x=A[i]%M; for(int i=0;i<=m;i++)c[i].x=B[i]/M,d[i].x=B[i]%M; FFT(a,lim,1),FFT(b,lim,1),FFT(c,lim,1),FFT(d,lim,1); for(int i=0;i<lim;i++)e[i]=a[i]*c[i],f[i]=a[i]*d[i],g[i]=b[i]*c[i],h[i]=b[i]*d[i]; FFT(e,lim,-1),FFT(f,lim,-1),FFT(g,lim,-1),FFT(h,lim,-1); for(int i=0;i<lim;i++)ret[i]=((ll)(e[i].x/lim+0.1)%p*M%p*M%p+(ll)(f[i].x/lim+0.1)%p*M%p+(ll)(g[i].x/lim+0.1)%p*M%p+(ll)(h[i].x/lim+0.1)%p)%p; } void solve(ll *aa,ll *bb,ll x) { ll temp=1; for(int i=0;i<=m;i++)B[i]=bb[i]*temp%p,temp=temp*x%p; for(int i=0;i<=m;i++)A[i]=aa[i]; MTT(); for(int i=0;i<=m;i++)bb[i]=ret[i]; } ll get_C(ll x,ll y) { return mul[x]*minv[y]%p*minv[x-y]%p; } int main() { // freopen("bp.in","r",stdin); // freopen("bp.out","w",stdout); scanf("%lld%lld",&n,&m); if(n>m){printf("0 ");return 0;} init(); while(lim<=2*m)lim<<=1,l++; for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1))); Bb[0]=1; ll tt=2; for(int i=1;i<=m;i++)Aa[i]=minv[i]; while(n) { if(n&1)solve(Aa,Bb,tt); solve(Aa,Aa,tt),tt=tt*tt%p,n>>=1; } ll ans=0; for(int i=0;i<=m;i++)ans=(ans+Bb[i]*mul[i]%p*get_C(m,(ll)i)%p)%p; printf("%lld ",ans); return 0; }