和前几天做的AC自动机类似。
思路简单但是代码200余行。。
假设solve_sub(i)表示长度为i的不含危险单词的总数。
最终答案为用总数(26^1+26^2+...+26^n)减去(solve_sub(1)+solve_sub(2)+...+solve_sub(n))。前者构造f[i]=f[i-1]*26+26然后矩阵快速幂即可(当然也可以分治的方法)。后者即构造出dp矩阵p,然后计算(p^1+p^2+...+p^n),对其分治即可。
代码如下:
1 #include <stdio.h> 2 #include <algorithm> 3 #include <string.h> 4 #include <vector> 5 #include <queue> 6 #include <iostream> 7 using namespace std; 8 const int MAX_Tot = 30 + 5; 9 const int mod = 100000; 10 typedef unsigned long long ull; 11 12 int m,n; 13 14 struct matrix 15 { 16 ull e[MAX_Tot][MAX_Tot]; 17 int n,m; 18 matrix() {} 19 matrix(int _n,int _m): n(_n),m(_m) {memset(e,0,sizeof(e));} 20 matrix operator * (const matrix &temp)const 21 { 22 matrix ret = matrix(n,temp.m); 23 for(int i=1;i<=ret.n;i++) 24 { 25 for(int j=1;j<=ret.m;j++) 26 { 27 for(int k=1;k<=m;k++) 28 { 29 ret.e[i][j] += e[i][k]*temp.e[k][j]; 30 } 31 } 32 } 33 return ret; 34 } 35 matrix operator + (const matrix &temp)const 36 { 37 matrix ret = matrix(n,m); 38 for(int i=1;i<=n;i++) 39 { 40 for(int j=1;j<=m;j++) 41 { 42 ret.e[i][j] += e[i][j]+temp.e[i][j]; 43 } 44 } 45 return ret; 46 } 47 void getE() 48 { 49 for(int i=1;i<=n;i++) 50 { 51 for(int j=1;j<=m;j++) 52 { 53 e[i][j] = i==j?1:0; 54 } 55 } 56 } 57 }; 58 59 matrix qpow(matrix temp,int x) 60 { 61 int sz = temp.n; 62 matrix base = matrix(sz,sz); 63 base.getE(); 64 while(x) 65 { 66 if(x & 1) base = base * temp; 67 x >>= 1; 68 temp = temp * temp; 69 } 70 return base; 71 } 72 73 matrix solve(matrix a, int k) 74 { 75 if(k == 1) return a; 76 int n = a.n; 77 matrix temp = matrix(n,n); 78 temp.getE(); 79 if(k & 1) 80 { 81 matrix ex = qpow(a,k); 82 k--; 83 temp = temp + qpow(a,k/2); 84 return temp * solve(a,k/2) + ex; 85 } 86 else 87 { 88 temp = temp + qpow(a,k/2); 89 return temp * solve(a,k/2); 90 } 91 } 92 93 struct Aho 94 { 95 struct state 96 { 97 int nxt[26]; 98 int fail,cnt; 99 }stateTable[MAX_Tot]; 100 101 int size; 102 103 queue<int> que; 104 105 void init() 106 { 107 while(que.size()) que.pop(); 108 for(int i=0;i<MAX_Tot;i++) 109 { 110 memset(stateTable[i].nxt,0,sizeof(stateTable[i].nxt)); 111 stateTable[i].fail = stateTable[i].cnt = 0; 112 } 113 size = 1; 114 } 115 116 void insert(char *s) 117 { 118 int n = strlen(s); 119 int now = 0; 120 for(int i=0;i<n;i++) 121 { 122 char c = s[i]; 123 if(!stateTable[now].nxt[c-'a']) 124 stateTable[now].nxt[c-'a'] = size++; 125 now = stateTable[now].nxt[c-'a']; 126 } 127 stateTable[now].cnt = 1; 128 } 129 130 void build() 131 { 132 stateTable[0].fail = -1; 133 que.push(0); 134 135 while(que.size()) 136 { 137 int u = que.front();que.pop(); 138 for(int i=0;i<26;i++) 139 { 140 if(stateTable[u].nxt[i]) 141 { 142 if(u == 0) stateTable[stateTable[u].nxt[i]].fail = 0; 143 else 144 { 145 int v = stateTable[u].fail; 146 while(v != -1) 147 { 148 if(stateTable[v].nxt[i]) 149 { 150 stateTable[stateTable[u].nxt[i]].fail = stateTable[v].nxt[i]; 151 // 在匹配fail指针的时候顺便更新cnt 152 if(stateTable[stateTable[stateTable[u].nxt[i]].fail].cnt == 1) 153 stateTable[stateTable[u].nxt[i]].cnt = 1; 154 break; 155 } 156 v = stateTable[v].fail; 157 } 158 if(v == -1) stateTable[stateTable[u].nxt[i]].fail = 0; 159 } 160 que.push(stateTable[u].nxt[i]); 161 } 162 /*****建立自动机nxt指针*****/ 163 else 164 { 165 if(u == 0) stateTable[u].nxt[i] = 0; 166 else 167 { 168 int p = stateTable[u].fail; 169 while(p != -1 && stateTable[p].nxt[i] == 0) p = stateTable[p].fail; 170 if(p == -1) stateTable[u].nxt[i] = 0; 171 else stateTable[u].nxt[i] = stateTable[p].nxt[i]; 172 } 173 } 174 /*****建立自动机nxt指针*****/ 175 } 176 } 177 } 178 179 matrix build_matrix() 180 { 181 matrix ans = matrix(size,size); 182 for(int i=0;i<size;i++) 183 { 184 for(int j=0;j<26;j++) 185 { 186 if(!stateTable[i].cnt && !stateTable[stateTable[i].nxt[j]].cnt) 187 ans.e[i+1][stateTable[i].nxt[j]+1]++; 188 } 189 } 190 return ans; 191 } 192 }aho; 193 194 void print(matrix p) 195 { 196 int n = p.n; 197 int m = p.m; 198 for(int i=1;i<=n;i++) 199 { 200 for(int j=1;j<=m;j++) 201 { 202 printf("%d ",p.e[i][j]); 203 } 204 puts(""); 205 } 206 } 207 208 int main() 209 { 210 while(scanf("%d%d",&m,&n) == 2) 211 { 212 aho.init(); 213 char s[15]; 214 for(int i=1;i<=m;i++) 215 { 216 scanf("%s",s); 217 aho.insert(s); 218 } 219 aho.build(); 220 matrix p = aho.build_matrix(); 221 p = solve(p,n); 222 ull temp = 0; 223 for(int i=1;i<=aho.size;i++) temp += p.e[1][i]; 224 matrix t = matrix(1,2); 225 t.e[1][2] = 1; 226 matrix A = matrix(2,2); 227 A.e[1][1] = A.e[2][1] = 26; A.e[2][2] = 1; 228 t = t * qpow(A,n); 229 ull ans = t.e[1][1] - temp; 230 printf("%llu ",ans); 231 } 232 return 0; 233 }
最后觉得,,我之前矩阵模板里的print()真好用啊233= =。