题意
维护一个集合,满足下面三种操作:
加字符串
删字符串
查询集合中的所有字符串在给出的模板串中出现的次数
强制在线
Analysis
如果不强制在线我们就可以用cdq来做,但这道糟糕的题目它强制在线。。。
首先我们发现对于加和删我们可以维护两个AC自动机,然后答案相减即可。
现在的问题在于怎么往一个AC自动机中插入一个串然后维护AC自动机。
二进制分组的思想就是把每(2^i)个修改放在一个数据结构里维护,然后再类似2048一样合并,最后暴力重构。
比如:(31=16+8+4+2+1),又多了一个变成(32=16+8+4+2+1+1=16+8+4+2+2=16+8+4+4=16+8+8=16+16=32)。
查询时就在这log个里面查,时间复杂度是(O(nlog{n}))级别的很好理解。
关键是修改的复杂度。
可以发现第(k)个修改的暴力重构的大小即为(lowbit(k)),设重构的复杂度是(O(f(size)))。
则修改复杂度为(sum_{i=1}^n O(f(lowbit(k)))=sum_{i=1}^{log n} O(frac{n}{2^{i+1}} * f(2^i)) le sum_{i=1}^{log n} O(f(i))=O(f(n)log{n}))。
这道题就迎刃而解了。注意为了合并要维护一个trie树再维护一个AC自动机。
能用二进制分组的前提是可以离线分治解,即修改之间互相独立。
// CF710F String Set Queries
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int LEN = 300010;
int n;
struct ACAM{
int ch[LEN][26], trie[LEN][26], fail[LEN], sum[LEN], end[LEN];
int sz[21], rt[21];
int tot, top;
ACAM() {
tot = top = 0;
}
void get_fail(int id) {
queue<int> q;
while (!q.empty()) q.pop();
int p = rt[id];
fail[p] = 0;
for (int i = 0; i < 26; i++) {
if (trie[p][i]) {
ch[p][i] = trie[p][i];
fail[ch[p][i]] = p;
q.push(ch[p][i]);
} else {
ch[p][i] = p;
}
}
while (!q.empty()) {
p = q.front();
q.pop();
sum[p] = end[p] + sum[fail[p]];
for (int i = 0; i < 26; i++) {
if (trie[p][i]) {
ch[p][i] = trie[p][i];
fail[ch[p][i]] = ch[fail[p]][i];
q.push(ch[p][i]);
} else {
ch[p][i] = ch[fail[p]][i];
}
}
}
}
int merge(int x, int y) {
if (!x || !y) return x | y;
end[x] += end[y];
for (int i = 0; i < 26; i++) {
trie[x][i] = merge(trie[x][i], trie[y][i]);
}
return x;
}
void insert(char *s) {
rt[++top] = ++tot;
sz[top] = 1;
int len = strlen(s + 1);
int p = rt[top];
for (int i = 1; i <= len; i++) {
p = trie[p][s[i] - 'a'] = ++tot;
}
end[p]++;
while (top > 1 && sz[top] == sz[top - 1]) {
rt[top - 1] = merge(rt[top - 1], rt[top]);
sz[top - 1] += sz[top];
sz[top] = rt[top] = 0;
top--;
}
get_fail(top);
}
int ask(char *s) {
int len = strlen(s + 1);
int ret = 0;
for (int i = 1; i <= top; i++) {
int p = rt[i];
for (int j = 1; j <= len; j++) {
p = ch[p][s[j] - 'a'];
ret += sum[p];
}
}
return ret;
}
}tr1, tr2;
int main() {
scanf("%d", &n);
while (n--) {
int opt;
char s[LEN];
scanf("%d%s", &opt, s + 1);
if (opt == 1) {
tr1.insert(s);
} else if (opt == 2) {
tr2.insert(s);
} else {
printf("%d
", tr1.ask(s) - tr2.ask(s));
fflush(stdout);
}
}
return 0;
}