题意:求一个字符串中有多少形如AABB的子串。
解:嗯...我首先极度SB的想了一个后缀自动机套线段树启发式合并的做法,想必会TLE。
然后跑去看题解,发现实在是妙不可言...
显然要对每个位置求出向左有多少个AA,向右有多少个BB。
我的想法是对于每个前缀,两两求lca,如果lca的len大于他们的位置之差,显然就有一组了。
这时候把贡献加到其中较长的前缀上。然后反着来一遍就行了。
怎么批量求lca和贡献呢?
考虑计算每个点作为lca时的贡献,显然线段树维护子树内有哪些前缀。合并的时候好像没啥好的办法...但是我们有启发式合并!
每次取出小的线段树中的所有元素,依次加入大的线段树中。对于大的线段树中比它小的一段区间内的元素,我们要给它自己加上贡献。对于比它大的一段区间中的元素,要给那些大的元素每个+1贡献。我们就在每次需要插入元素的时候往下推。推到底的时候加贡献即可。(应该支持吧...)
比较菜没写代码...感觉实现起来毒瘤的紧。
然后说正解。
考虑枚举AA串的长度。
对于一个长为2len的AA串,如果我们每隔len放一个点,那么这样的串将会且仅会覆盖两个连续的点。
对于每两个连续的点,我们求它们的最长公共前/后缀长度,分别设为x,y。
如果x + y >= len的话就是存在这样的AA串经过这两点。然后就是个线段树区间+1
最后遍历线段树统计答案即可。
求lcp不就是SAM的fail树上lca嘛,我会倍增!
Tnlog2n成功T飞...
然后就O(1)lca过了...果然O(1)lca还是有用的。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 5 typedef long long LL; 6 const int N = 80010; 7 8 char str[N]; 9 int pos[N], pos2[N], pw[N * 2], n; 10 11 struct SAM { 12 13 struct Edge { 14 int nex, v; 15 }edge[N]; int top; 16 17 int tr[N][26], fail[N], len[N], tot, last; 18 int ST[N * 2][20], pos[N * 2], num, e[N], d[N]; 19 20 SAM() { 21 tot = last = 1; 22 } 23 24 inline void add(int x, int y) { 25 top++; 26 edge[top].v = y; 27 edge[top].nex = e[x]; 28 e[x] = top; 29 return; 30 } 31 32 inline void insert(char c) { 33 int f = c - 'a'; 34 int p = last, np = ++tot; 35 last = np; 36 len[np] = len[p] + 1; 37 while(p && !tr[p][f]) { 38 tr[p][f] = np; 39 p = fail[p]; 40 } 41 if(!p) { 42 fail[np] = 1; 43 } 44 else { 45 int Q = tr[p][f]; 46 if(len[Q] == len[p] + 1) { 47 fail[np] = Q; 48 } 49 else { 50 int nQ = ++tot; 51 len[nQ] = len[p] + 1; 52 fail[nQ] = fail[Q]; 53 fail[Q] = fail[np] = nQ; 54 memcpy(tr[nQ], tr[Q], sizeof(tr[Q])); 55 while(tr[p][f] == Q) { 56 tr[p][f] = nQ; 57 p = fail[p]; 58 } 59 } 60 } 61 } 62 63 void DFS(int x) { 64 pos[x] = ++num; 65 ST[num][0] = x; 66 for(int i = e[x]; i; i = edge[i].nex) { 67 int y = edge[i].v; 68 d[y] = d[x] + 1; 69 DFS(y); 70 ST[++num][0] = x; 71 } 72 return; 73 } 74 75 inline void prework() { 76 for(int i = 2; i <= tot; i++) { 77 add(fail[i], i); 78 } 79 d[1] = 1; 80 DFS(1); 81 for(int j = 1; j <= pw[num]; j++) { 82 for(int i = 1; i + (1 << j) - 1 <= num; i++) { 83 if(d[ST[i][j - 1]] <= d[ST[i + (1 << (j - 1))][j - 1]]) { 84 ST[i][j] = ST[i][j - 1]; 85 } 86 else { 87 ST[i][j] = ST[i + (1 << (j - 1))][j - 1]; 88 } 89 } 90 } 91 return; 92 } 93 94 inline int lca(int x, int y) { 95 x = pos[x]; 96 y = pos[y]; 97 if(x > y) { 98 std::swap(x, y); 99 } 100 int t = pw[y - x + 1]; 101 if(d[ST[x][t]] <= d[ST[y - (1 << t) + 1][t]]) { 102 return ST[x][t]; 103 } 104 return ST[y - (1 << t) + 1][t]; 105 } 106 107 inline void clear() { 108 for(int i = 1; i <= tot; i++) { 109 d[i] = e[i] = 0; 110 for(int f = 0; f < 26; f++) { 111 tr[i][f] = 0; 112 } 113 } 114 tot = last = 1; 115 top = num = 0; 116 return; 117 } 118 119 inline int lcp(int x, int y) { 120 return std::min(std::min(len[x], len[y]), len[lca(x, y)]); 121 } 122 123 }sam, sam2; 124 125 struct SegmentTree { 126 int tag[N * 2]; 127 int f[N]; 128 inline void pushdown(int o) { 129 if(!tag[o]) { 130 return; 131 } 132 tag[o << 1] += tag[o]; 133 tag[o << 1 | 1] += tag[o]; 134 tag[o] = 0; 135 return; 136 } 137 138 void add(int L, int R, int l, int r, int o) { 139 if(L <= l && r <= R) { 140 tag[o]++; 141 return; 142 } 143 int mid = (l + r) >> 1; 144 pushdown(o); 145 if(L <= mid) { 146 add(L, R, l, mid, o << 1); 147 } 148 if(mid < R) { 149 add(L, R, mid + 1, r, o << 1 | 1); 150 } 151 return; 152 } 153 154 void solve(int l, int r, int o) { 155 if(l == r) { 156 f[r] = tag[o]; 157 return; 158 } 159 pushdown(o); 160 int mid = (l + r) >> 1; 161 solve(l, mid, o << 1); 162 solve(mid + 1, r, o << 1 | 1); 163 return; 164 } 165 void clear(int l, int r, int o) { 166 tag[o] = 0; 167 if(l == r) { 168 return; 169 } 170 int mid = (l + r) >> 1; 171 clear(l, mid, o << 1); 172 clear(mid + 1, r, o << 1 | 1); 173 return; 174 } 175 }seg, seg2; 176 177 inline void solve() { 178 scanf("%s", str); 179 LL ans = 0; 180 n = strlen(str); 181 for(int i = 0; i < n; i++) { 182 sam.insert(str[i]); 183 sam2.insert(str[n - i - 1]); 184 pos[i] = sam.last; 185 pos2[n - i - 1] = sam2.last; 186 } 187 sam.prework(); 188 sam2.prework(); 189 // 190 for(int len = 1; (len << 1) < n - 1; len++) { 191 //printf("len = %d ", len); 192 for(int i = len; i < n; i += len) { 193 // i i-len 194 //printf(" > %d %d ", i - len, i); 195 int x = std::min(len, sam.lcp(pos[i], pos[i - len])); 196 int y = std::min(len, sam2.lcp(pos2[i], pos2[i - len])); 197 // x + y - len 198 //printf(" > x = %d y = %d ", x, y); 199 if(x + y > len) { 200 seg.add(i - len - x + 2, i - len * 2 + y + 1, 1, n, 1); 201 //printf(" > > > 1 add %d %d ", i - len - x + 2, i - len * 2 + y + 1); 202 seg2.add(i + len - x + 1, i + y, 1, n, 1); 203 //printf(" > > > 2 add %d %d ", i + len - x + 1, i + y); 204 } 205 } 206 } 207 seg.solve(1, n, 1); 208 seg2.solve(1, n, 1); 209 for(int i = 2; i < n - 1; i++) { 210 ans += 1ll * seg2.f[i] * seg.f[i + 1]; 211 //printf("ans += %d * %d ", seg2.f[i], seg.f[i + 1]); 212 } 213 printf("%lld ", ans); 214 return; 215 } 216 217 int main() { 218 219 for(int i = 2; i < N * 2; i++) { 220 pw[i] = pw[i >> 1] + 1; 221 } 222 int T; 223 scanf("%d", &T); 224 while(T--) { 225 solve(); 226 if(T) { 227 sam.clear(); 228 sam2.clear(); 229 seg.clear(1, n, 1); 230 seg2.clear(1, n, 1); 231 } 232 } 233 return 0; 234 }