[BZOJ3992][SDOI2015]序列统计
试题描述
小C有一个集合 (S),里面的元素都是小于 (M) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 (N) 的数列,数列中的每个数都属于集合 (S)。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:
给定整数 (x),求所有可以生成出的,且满足数列中所有数的乘积 (mod M) 的值等于 (x) 的不同的数列的有多少个。小C认为,两个数列 ({ A_i }) 和 ({ B_i })不同,当且仅当至少存在一个整数 (i),满足 (A_i e B_i)。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案 (mod 1004535809) 的值就可以了。
输入
第一行,四个整数,(N)、(M)、(x)、(|S|),其中 (|S|) 为集合 (S) 中元素个数。
第二行,(|S|) 个整数,表示集合 (S) 中的所有元素。
输出
一行,一个整数,表示你求出的种类数 (mod 1004535809) 的值。
输入示例
4 3 1 2
1 2
输出示例
8
数据规模及约定
(1 le N le 10^9),(3 le M le 8000),(M) 为质数
(0 le x le M-1),输入数据保证集合 (S) 中元素不重复
题解
一提到原根应该就会做了。会做的就不用再往下看了……
原根的定义中运算都在模 (P) 意义下,原根就是一个整数 (g),满足 ({ g^i | i in [0, P-2] }) 能与 ({ i | i in [1, P-1] }) 一一对应。找原根的方法是暴力枚举 (g),检查是否 (g^k equiv 1) 是否是 (k = P-1 或 k = 0) 的充要条件,是则表明 (g) 是原根,否则继续枚举。
这样,每个非 (0) 整数都能够用 (g) 的若干次方来表示了,并且乘积变成了指数的相加。于是就可以解决“转移是乘法”的问题了。
按照套路,还是讲一下暴力 dp 吧。设 (f(i, j)) 表示长度为 (i) 的数列,数列乘积的结果是 (g^j mod M),这样的数列的个数。那么我们找到 (S) 中的元素的指数,即找到 (P_i) 满足 (g^{P_i} equiv S_i(mod M)),那么转移就是
其中,( ightarrow) 表示累加。
然后搞生成函数 (F_i(x)) 为 (f(i, j)) 的生成函数,令 (G(x) = sum_{i=1}^{|S|} x^{P_i}),有
注意这里的多项式乘法是 (M-1) 位循环卷积。上面的式子可以用倍增 + NTT 来做,每次乘法完毕之后暴力把多出来的部分累加到前面去。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)
int read() {
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
#define maxn 16384
#define MOD 1004535809
#define Groot 3
#define LL long long
int Pow(int a, int b) {
int ans = 1, t = a;
while(b) {
if(b & 1) ans = (LL)ans * t % MOD;
t = (LL)t * t % MOD; b >>= 1;
}
return ans;
}
int brev[maxn];
void FFT(int *a, int len, int tp) {
int n = 1 << len;
rep(i, 0, n - 1) if(i < brev[i]) swap(a[i], a[brev[i]]);
rep(i, 1, len) {
int wn = Pow(Groot, MOD - 1 >> i);
if(tp < 0) wn = Pow(wn, MOD - 2);
for(int j = 0; j < n; j += 1 << i) {
int w = 1;
rep(k, 0, (1 << i >> 1) - 1) {
int la = a[j+k], ra = (LL)w * a[j+k+(1<<i>>1)] % MOD;
a[j+k] = (la + ra) % MOD;
a[j+k+(1<<i>>1)] = (la - ra + MOD) % MOD;
w = (LL)w * wn % MOD;
}
}
}
if(tp < 0) {
int invn = Pow(n, MOD - 2);
rep(i, 0, n - 1) a[i] = (LL)a[i] * invn % MOD;
}
return ;
}
void Mul(int *A, int *B, int n, int m, bool recover = 0) {
int N = 1, len = 0;
while(N <= n + m) N <<= 1, len++;
rep(i, 0, N - 1) brev[i] = (brev[i>>1] >> 1) | ((i & 1) << len >> 1);
rep(i, n + 1, N - 1) A[i] = 0;
rep(i, m + 1, N - 1) B[i] = 0;
FFT(A, len, 1); FFT(B, len, 1);
rep(i, 0, N - 1) A[i] = (LL)A[i] * B[i] % MOD;
FFT(A, len, -1); if(recover) FFT(B, len, -1);
rep(i, n + 1, n + m) (A[i%(n+1)] += A[i]) %= MOD;
return ;
}
int findg(int M) {
rep(i, 2, M - 1) {
int x = 1; bool ok = 1;
rep(j, 1, M - 2) {
x = (LL)x * i % M;
if(x == 1){ ok = 0; break; }
}
if(ok) return i;
}
return 0;
}
int s[maxn], p[maxn], Exp[maxn];
void getExp(int g, int M, int k) {
int x = 1;
rep(i, 0, M - 2) Exp[x] = i, x = (LL)x * g % M;
rep(i, 1, k) p[i] = s[i] ? Exp[s[i]] : -1;
return ;
}
int F[maxn], G[maxn], tmp[maxn];
void p_pow(int n, int M) {
while(n) {
if(n & 1) Mul(F, G, M, M, 1);
memcpy(tmp, G, sizeof(G));
Mul(G, tmp, M, M); n >>= 1;
}
return ;
}
int main() {
int n = read(), M = read(), q = read(), k = read(), g = findg(M);
rep(i, 1, k) s[i] = read();
getExp(g, M, k);
rep(i, 1, k) if(p[i] >= 0) G[p[i]] = 1;
F[0] = 1;
p_pow(n, M - 2);
printf("%d
", F[Exp[q]]);
return 0;
}