[题目链接]
https://atcoder.jp/contests/arc100/tasks/arc100_d
[题解]
首先考虑容斥 , 用存在一个 (K) - 连续段的答案减去不存在 (K) - 连续段的答案。
显然 , 总方案数为 ((N - M + 1) cdot K ^ {N - M})。
考虑不合法的方案数 , 不妨分三类情况讨论 :
(1.) (A) 中已经存在一个连续段了 , 此时答案为 (0)。
(2.) (A) 中有两个相同的元素 , 那么对于序列中一个与 (A) "重合" 的子串 , 其不可能 "跨过" 一个连续段。 不妨设 (f_{i , j}) 表示 (i) 个数 , 末尾一段极长且互不相等的序列的最大长度是 (j) 的方案数。转移的时候只需考虑是否加一个新的元素或从某一位置 "断开" 即可。注意每次转移都是对一个连续区间贡献 , 于是可以差分前缀和优化从而做到 (O(N ^ 2))。对于首尾分别做一遍这样的 (DP) , 答案显然是一个卷积的形式。
(3.) (A) 中没有相同的元素 , 那么必然满足 (K > M) , 考虑转化 , 原问题等价于长度为 (N) 的所有合法序列中有多少个 (K) - 连续段 , 注意这样是 “重新标号" 过的 , 因此准确的答案还需除以 (frac{K!}{(K - M)!})。 这个问题同样可以用 (DP) 解决 , 时间复杂度 (O(N ^ 2))
, 不赘述了。
[代码]
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MK = 410 , MN = 30010 , mod = 1e9 + 7;
int N , K , M , ans , X[MN] , Y[MN] , fac[MN] , ifac[MN] , occ[MN] , A[MN] , g[MN][MK] ,
f[MN][MK];
inline void inc(int &x , int y) {
x = x + y < mod ? x + y : x + y - mod;
}
inline void dec(int &x , int y) {
x = x - y >= 0 ? x - y : x - y + mod;
}
inline int qPow(int a , int b) {
int c = 1;
for (; b; b >>= 1 , a = 1LL * a * a % mod) if (b & 1) c = 1LL * c * a % mod;
return c;
}
inline void init(int n) {
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[n] = qPow(fac[n] , mod - 2);
for (int i = n - 1; i >= 0; --i) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
return;
}
inline bool check() {
int j = 1 , cur = 0;
for (int i = 1; i <= M; ++i) {
if (!occ[A[i]]) ++cur; ++occ[A[i]];
while (occ[A[i]] > 1) --occ[A[j++]];
if (i - j + 1 == K && cur == K) return 1;
}
return 0;
}
inline bool check2() {
memset(occ , 0 , sizeof(occ));
for (int i = 1; i <= M; ++i)
if (occ[A[i]]) return 1;
else ++occ[A[i]];
return 0;
}
inline int DP1() {
g[0][0] = 1; int ret = 0;
for (int i = 0; i < N; ++i) {
for (int p , q , j = 0; j < K; ++j) {
p = 1ll * f[i][j] * (K - j) % mod ,
q = 1ll * g[i][j] * (K - j) % mod;
if (j + 1 < K) {
inc(f[i + 1][j + 1] , p) , dec(f[i + 1][j + 2] , p);
inc(g[i + 1][j + 1] , q) , dec(g[i + 1][j + 2] , q);
}
inc(f[i + 1][1] , f[i][j]) , inc(g[i + 1][1] , g[i][j]);
dec(f[i + 1][j + 1] , f[i][j]) , dec(g[i + 1][j + 1] , g[i][j]);
}
for (int j = 1; j < K; ++j) {
inc(g[i + 1][j] , g[i + 1][j - 1]);
inc(f[i + 1][j] , f[i + 1][j - 1]);
}
for (int j = M; j < K; ++j)
inc(f[i + 1][j] , g[i + 1][j]);
}
for (int i = 1; i < K; ++i) inc(ret , f[N][i]);
return ret;
}
inline void DP2(int res[MN] , int s[MN][MK] , int type) {
memset(occ , 0 , sizeof(occ));
if (!type) {
for (int i = 1; i <= M; ++i)
if (!occ[A[i]]) ++occ[A[i]];
else { s[0][i - 1] = 1; break; }
} else {
for (int i = M; i >= 1; --i)
if (!occ[A[i]]) ++occ[A[i]];
else { s[0][M - i] = 1; break; }
}
for (int i = 0; i < N - M; ++i) {
for (int val , j = 1; j < K; ++j) {
val = 1ll * s[i][j] * (K - j) % mod;
if (j + 1 < K) {
inc(s[i + 1][j + 1] , val);
dec(s[i + 1][j + 2] , val);
}
inc(s[i + 1][1] , s[i][j]);
dec(s[i + 1][j + 1] , s[i][j]);
}
for (int j = 1; j < K; ++j) {
inc(s[i + 1][j] , s[i + 1][j - 1]);
inc(res[i + 1] , s[i + 1][j]);
}
}
}
int main() {
scanf("%d%d%d" , &N , &K , &M); int mx = max(max(N , K) , M) + 1; init(mx);
for (int i = 1; i <= M; ++i) scanf("%d" , &A[i]);
int ans = 1LL * (N - M + 1) * qPow(K , N - M) % mod;
if (check()) {
printf("%d
" , ans);
return 0;
}
if (check2()) {
X[0] = Y[0] = 1;
DP2(X , f , 0) , DP2(Y , g , 1);
for (int i = 0; i <= N - M; ++i)
dec(ans , 1ll * X[i] * Y[N - M - i] % mod);
} else dec(ans , 1ll * DP1() * ifac[K] % mod * fac[K - M] % mod);
assert(ans >= 0 && ans < mod);
printf("%d
" , ans);
return 0;
}