题意:
给定一个集合S,里面的数都是小于m的非负整数。
求长度为n的数列个数,满足:
- 数列中所有数都属于S。
- 数列中所有数的乘积对m取模等于x。
称数列A和B不同当且仅当存在i使得$A_i eq B_i$。
答案对1004535809取模。
$nleq 10^{9},mleq 8000,m是质数$。
题解:
实际上就是要求类似于$F_i=sum limits_{jk=i}{f_j g_k}$的卷积。
这东西没法直接做,但如果我们令$j=g^{x},k=g^{y}$,那么就变成了我们熟悉的$F_z=sum limits_{x+y=z}{f_x g_y}$的形式。
这题中g就等于m的原根。找原根可以直接枚举g,如果$exists x eq m-1,g^{x}=1$,则g不合法。
求出原根后直接跑个卷积快速幂即可,注意此时最高次项应为$m-1$。
由于每次需要把$f_{i+m}$加到$f_i$上,然后令$f_{i+m}=0$,所以不能直接对点值快速幂,必须每乘一次就转成系数表示法。
复杂度$O(m log{m}log{n})$。
套路:
- 形如$F_i=sum limits_{jk=i}{f_j g_k}$的卷积$ ightarrow$变成普通卷积$ ightarrow$找到g使得$j=g^{x},k=g^{y}$。
代码:
#include<bits/stdc++.h> #define maxn 200005 #define maxm 500005 #define inf 0x7fffffff #define mod 1004535809 #define g 3 #define ll long long #define rint register ll #define debug(x) cerr<<#x<<": "<<x<<endl #define fgx cerr<<"--------------"<<endl #define dgx cerr<<"=============="<<endl using namespace std; ll m,vis[maxn],ind[maxn]; inline ll read(){ ll x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } inline ll pw(ll a,ll b,ll mo){ll r=1;while(b)r=(b&1)?r*a%mo:r,a=a*a%mo,b>>=1;return r;} struct poly{ ll a[maxn],n; inline void clear(){memset(a,0,sizeof(a)),n=0;} inline void ntt(ll op){ for(ll i=0;i<n;i++) if(ind[i]>i) swap(a[i],a[ind[i]]); for(ll l=1;l<=n;l<<=1){ ll p=pw(g,(mod-1)/l,mod); if(op==-1) p=pw(p,mod-2,mod); for(ll i=0;i<n;i+=l) for(ll j=i,w=1;j<i+l/2;j++,w=w*p%mod){ ll x=a[j],y=w*a[j+l/2]%mod; a[j]=(x+y)%mod,a[j+l/2]=(x-y+mod)%mod; } } if(op==-1){ ll inv=pw(n,mod-2,mod); for(ll i=0;i<n;i++) a[i]=a[i]*inv%mod; } } }; inline poly mul(poly A,poly B){ A.ntt(1),B.ntt(1); poly R; R.clear(),R.n=A.n; for(ll i=0;i<R.n;i++) R.a[i]=A.a[i]*B.a[i]%mod; R.ntt(-1); for(ll i=m;i<2*m;i++) R.a[i-m]=(R.a[i-m]+R.a[i])%mod,R.a[i]=0; return R; } inline poly solve(poly A,ll b){ for(ll i=0;i<A.n;i++) ind[i]=(i&1)?((ind[i>>1]>>1)|(A.n>>1)):(ind[i>>1]>>1); poly R; R.clear(),R.n=0; while(b){ if(b&1) R=(R.n==0)?A:mul(R,A); A=mul(A,A),b>>=1; } return R; } inline ll calc(ll x){ for(ll i=2;i<x;i++){ bool flag=1; for(ll j=2;j*j<x;j++) if(pw(i,(x-1)/j,x)==1) flag=0; if(flag) return i; } } int main(){ ll n=read(); m=read()-1; ll x=read(),S=read(),gen=calc(m+1); for(ll i=1;i<=S;i++) vis[read()]=1; ll N=1; while(N<2*m) N<<=1; poly A; A.clear(),A.n=N; for(ll i=0;i<m;i++) A.a[i]=vis[pw(gen,i,m+1)]; poly R=solve(A,n); for(ll i=0;i<m;i++) if(pw(gen,i,m+1)==x) printf("%lld ",R.a[i]); return 0; }