http://www.lydsy.com/JudgeOnline/problem.php?id=3992
这道题好难啊。
第一眼谁都能看出来是个dp,设(f(i,j))表示转移到第i位时前i位的乘积模m等于j的方案数。
转移很显然啊(f(i,j)=sum_{x,yin[0,m)}[xymod m=j]f(i-1,x)*f(i-1,y))。
这个下标是乘积取模的转移根本无法优化啊。
但注意到题目最下方说m是一个质数。。。
把x=0特判掉,剩下(xin[1,m-1))时把x转化为m的原根的幂次。
设m的原根为(g_m)。
那么(f(i,g_m^j)=sum_{x,yin[0,m-1)}[(x+y)mod m=j]f(i-1,g_m^x)*f(i-1,g_m^y))。
这样通过原根在[0,m-1)上的不重不漏的一一映射,乘积取模变成加法取模,化成了一个循环卷积的形式。
(话说看模数也知道是NTT啊qwq)循环卷积直接用NTT做就可以了。
但要做N次循环卷积,(Nleq 10^9)。。。
在外面套层快速幂就可以了O(∩_∩)O~~
快速幂套循环卷积的正确性?先不循环卷积然后再压成循环卷积就很好证明啊。不过也可以把快速幂看成一个倍增,每次合并两个dp数组之类的,正确性都显然啊qwq
注意数组不要开小!用于NTT的数组要开到2的幂次qwq
时间复杂度(O(m^2+mlog mlog n))。
(看了Menci大大的博客,“把原根的幂次看成多项式的幂次,dp数组记录在系数里”这个东西还叫生成函数?)
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int M = 8193;
const int g = 267198;
const int p = 1004535809;
int nin;
int ipow(int a, int b) {
int ret = 1, w = a;
while (b) {
if (b & 1) ret = 1ll * ret * w % p;
w = 1ll * w * w % p;
b >>= 1;
}
return ret;
}
int da[M << 1], db[M << 1], dc[M << 1], rev[M << 1], nWN[15], WN[15], n, m, x, s, C;
void DNT(int *a, int *A, int flag) {
int tmp = 1;
for (int i = 0; i < n; ++i) A[rev[i]] = a[i];
for (int len = 2; len <= n; len <<= 1, ++tmp) {
int mid = len >> 1, wn = flag == 1 ? WN[tmp] : nWN[tmp];
for (int i = 0; i < n; i += len) {
int w = 1;
for (int j = 0; j < mid; ++j) {
int t = A[i + j], u = 1ll * A[i + j + mid] * w % p;
A[i + j] = (t + u) % p;
A[i + j + mid] = (t - u + p) % p;
w = 1ll * w * wn % p;
}
}
}
if (flag == -1)
for (int i = 0; i < n; ++i)
A[i] = 1ll * A[i] * nin % p;
}
int top;
void NTTsqr(int *a) {
DNT(a, da, 1);
for (int i = 0; i < n; ++i)
da[i] = 1ll * da[i] * da[i] % p;
DNT(da, a, -1);
for (int i = 0; i < top; ++i) {
(a[i] += a[i + top]) %= p;
a[i + top] = 0;
}
}
void NTT(int *a, int *b) {
DNT(a, da, 1); DNT(b, db, 1);
for (int i = 0; i < n; ++i)
dc[i] = 1ll * da[i] * db[i] % p;
DNT(dc, a, -1);
for (int i = 0; i < top; ++i) {
(a[i] += a[i + top]) %= p;
a[i + top] = 0;
}
}
void init() {
int tot = 0, num = top << 1;
while (num) {num >>= 1; ++tot;}
n = 1 << tot;
nin = ipow(n, p - 2);
int res;
for (int i = 0; i < n; ++i) {
num = i; res = 0;
for (int j = 0; j < tot; ++j) {
res <<= 1;
if (num & 1) res |= 1;
num >>= 1;
}
rev[i] = res;
}
WN[14] = g, nWN[14] = ipow(g , p - 2);
for (int i = 13; i >= 1; --i) {
WN[i] = 1ll * WN[i + 1] * WN[i + 1] % p;
nWN[i] = 1ll * nWN[i + 1] * nWN[i + 1] % p;
}
}
bool shown[M];
int r[M << 1], ww[M << 1], c[M];
int main() {
scanf("%d%d%d%d", &C, &m, &x, &s); top = m - 1;
if (x == 0) {printf("%d
", (ipow(m, n) - ipow(m - 1, n) + p) % p); return 0;}
int num;
for (int i = 2; i < m; ++i) {
int ret = 1; bool flag = true;
for (int j = 0; j < top; ++j) {
ret = 1ll * ret * i % m;
if (shown[ret]) {flag = false; break;}
shown[ret] = true;
}
if (!flag || ret != 1) {
ret = 1;
for (int j = 0; j < top; ++j) {
ret = 1ll * ret * i % m;
if (shown[ret]) shown[ret] = false;
else break;
}
} else {
num = i;
break;
}
}
int ret = 1;
for (int i = 0; i < top; ++i) {
c[ret] = i;
ret = 1ll * ret * num % m;
}
init();
int tt;
for (int i = 1; i <= s; ++i) {
scanf("%d", &tt);
if (tt != 0) ww[c[tt]] = 1;
}
r[0] = 1;
while (C) {
if (C & 1) NTT(r, ww);
NTTsqr(ww);
C >>= 1;
}
printf("%d
", r[c[x]]);
return 0;
}