hdu6086
题意
字符串只由 (01) 组成,求长度为 (2L) 且包含给定的 (n) 个子串的字符串的个数(且要求字符串满足 (s[i] eq s[|s| - i + 1]))。
分析
没有想到可以暴力预处理中间那些字符。
官方题解:
如果没有反对称串的限制,直接求一个长度为 (L) 的 (01) 串满足所有给定串都出现过,那么是一个经典的 AC 自动机的问题,状态 (f[i][j][S]) 表示长度为 (i),目前在 AC 自动机的节点 (j) 上,已经出现的字符串集合为 (S) 的方案数,然后直接转移即可,时间复杂度 (O(2^nLsum |s|))。
然后如果不考虑有串跨越中轴线,那么可以预处理所有正串的 AC 自动机和所有反串(即原串左右翻转)的 AC 自动机,然后从中间向两边 DP,每一次枚举右侧下一个字符是 (0) 还是 (1),那么另一侧一定是另外一个字符。状态 (f[i][j][k][S]) 表示长度为 (2i),目前右半边在正串 AC 自动机的节点 (j) 上,左半边的反串在反串 AC 自动机的节点 (k) 上,已经出现的字符串集合为 (S) 的方案数,然后直接转移,时间复杂度 (O(2^nL(sum |s|)^2))。
现在考虑有串跨越中轴线,可以先爆枚从中间开始左右各 (max|s|-1) 个字符,统计出哪些串以及出现了。对于之后左右扩展出去的字符来说,肯定没有经过的它们的字符串跨越中轴线,因此可以以爆枚的结果为 DP 的初始值,从第 (max|s|) 个字符开始 DP。
时间复杂度 (O(2^nL(sum |s|)^2+max|s|2^{max|s|}))。
数组要开成滚动数组,然后爆搜的时候自动机上的状态也要跟着转移。
时限还是很宽松的。
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<iostream>
using namespace std;
typedef long long ll;
const int MAXN = 121;
const int MOD = 998244353;
struct Trie {
int root, L, nxt[MAXN][2], fail[MAXN], val[MAXN];
int newnode() {
memset(nxt[L], -1, sizeof nxt[L]);
return L++;
}
void init() {
L = 0;
root = newnode();
memset(val, 0, sizeof val);
memset(fail, 0, sizeof fail);
}
void insert(int id, char S[]) {
int len = strlen(S);
int now = root;
for(int i = 0; i < len; i++) {
int d = S[i] - '0';
if(nxt[now][d] == -1) nxt[now][d] = newnode();
now = nxt[now][d];
}
val[now] |= (1 << id);
}
void build() {
queue<int> Q;
for(int i = 0; i < 2; i++) {
if(nxt[root][i] == -1) nxt[root][i] = 0;
else { fail[nxt[root][i]] = root; Q.push(nxt[root][i]); }
}
while(!Q.empty()) {
int now = Q.front(); Q.pop();
val[now] |= val[fail[now]];
for(int i = 0; i < 2; i++) {
if(nxt[now][i] == -1) nxt[now][i] = nxt[fail[now]][i];
else { fail[nxt[now][i]] = nxt[fail[now]][i]; Q.push(nxt[now][i]); }
}
}
}
int query(char S[], int l, int r) {
int now = root;
int res = 0;
int flg = 0;
int mid = (r - l) / 2 + l;
for(int i = l; i <= r; i++) {
int d = S[i] - '0';
now = nxt[now][d];
res |= val[now];
}
return res;
}
}trie1, trie2;
int n, L, mx;
int dp[2][MAXN][MAXN][64];
void dfs(char s[], int l, int r, int nl, int nr) {
int len = r - l + 1;
if(len / 2 >= mx) {
int tmp = trie2.query(s, l, r);
dp[1][nl][nr][tmp]++;
return;
}
s[l - 1] = '0'; s[r + 1] = '1';
dfs(s, l - 1, r + 1, trie1.nxt[nl][0], trie2.nxt[nr][1]);
s[l - 1] = '1'; s[r + 1] = '0';
dfs(s, l - 1, r + 1, trie1.nxt[nl][1], trie2.nxt[nr][0]);
}
int cnt[64];
int main() {
cnt[0] = 0;
for(int i = 1; i < 64; i++) {
int j = 0;
while(!((i >> j) & 1)) j++;
cnt[i] = cnt[i - (1 << j)] + 1;
}
int T;
scanf("%d", &T);
while(T--) {
scanf("%d%d", &n, &L);
trie1.init();
trie2.init();
mx = 0;
for(int i = 0; i < n; i++) {
char s[22];
scanf("%s", s);
trie2.insert(i, s);
int len = strlen(s);
mx = max(mx, len);
reverse(s, s + len);
trie1.insert(i, s);
}
mx--;
trie1.build();
trie2.build();
memset(dp, 0, sizeof dp);
char s[65];
dfs(s, 23, 22, 0, 0);
int z = 1;
for(int i = mx; i < L; i++, z = !z) {
memset(dp[!z], 0, sizeof dp[!z]);
for(int j = 0; j < trie1.L; j++) {
for(int k = 0; k < trie2.L; k++) {
for(int p = 0; p < (1 << n); p++) {
if(!dp[z][j][k][p]) continue;
for(int q = 0; q < 2; q++) {
int tmp1 = trie1.nxt[j][q], tmp2 = trie2.nxt[k][!q];
(dp[!z][tmp1][tmp2][p | trie1.val[tmp1] | trie2.val[tmp2]] += dp[z][j][k][p]) %= MOD;
}
}
}
}
}
int sum = 0;
for(int i = 0; i < trie1.L; i++) {
for(int j = 0; j < trie2.L; j++) {
sum = (sum + dp[z][i][j][(1 << n) - 1]) % MOD;
}
}
printf("%d
", sum);
}
return 0;
}