[BZOJ2555]SubString
试题描述
懒得写背景了,给你一个字符串init,要求你支持两个操作
(1):在当前字符串的后面插入一个字符串
(2):询问字符串s在当前字符串中出现了几次?(作为连续子串)
你必须在线支持这些操作。
输入
第一行一个数Q表示操作个数
第二行一个字符串表示初始字符串init
接下来Q行,每行2个字符串Type,Str
Type是ADD的话表示在后面插入字符串。
Type是QUERY的话表示询问某字符串在当前字符串中出现了几次。
为了体现在线操作,你需要维护一个变量mask,初始值为0
读入串Str之后,使用这个过程将之解码成真正询问的串TrueStr。
询问的时候,对TrueStr询问后输出一行答案Result
然后mask = mask xor Result
插入的时候,将TrueStr插到当前字符串后面即可。
HINT:ADD和QUERY操作的字符串都需要解压
输出
对于每个询问输出字符串出现次数。
输入示例
2 A QUERY B ADD BBABBBBAAB
输出示例
0
数据规模及约定
40 % 的数据字符串最终长度 <= 20000,询问次数<= 1000,询问总长度<= 10000
100 % 的数据字符串最终长度 <= 600000,询问次数<= 10000,询问总长度<= 3000000
题解
用 splay 动态维护 dfs 序(括号序列),这样每次 extend 的时候就相当于插入一个节点或者是把一颗子树连到另一个节点上;对应 splay 操作就是每次将一个区间拎出来插到另一个缝隙中,或是对一个区间进行查询。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; const int BufferSize = 1 << 16; char buffer[BufferSize], *Head, *Tail; inline char Getchar() { if(Head == Tail) { int l = fread(buffer, 1, BufferSize, stdin); Tail = (Head = buffer) + l; } return *Head++; } int read() { int x = 0, f = 1; char c = Getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); } return x * f; } #define maxn 1200010 #define maxa 26 int Rt, tot, ch[maxn<<1][2], fa[maxn<<1], dl[maxn<<1], dr[maxn<<1]; struct Node { int val, sum; Node() {} Node(int _): val(_) {} } ns[maxn<<1]; inline void maintain(int o) { ns[o].sum = ns[o].val; for(int i = 0; i < 2; i++) if(ch[o][i]) ns[o].sum += ns[ch[o][i]].sum; return ; } inline void rotate(int u) { int y = fa[u], z = fa[y], l = 0, r = 1; if(z) ch[z][ch[z][1]==y] = u; if(ch[y][1] == u) swap(l, r); fa[u] = z; fa[y] = u; fa[ch[u][r]] = y; ch[y][l] = ch[u][r]; ch[u][r] = y; maintain(y); maintain(u); return ; } inline void splay(int u) { while(fa[u]) { int y = fa[u], z = fa[y]; if(z) { if(ch[y][0] == u ^ ch[z][0] == y) rotate(u); else rotate(y); } rotate(u); } return ; } inline void split(int& lrt, int& rrt, int id) { splay(id); lrt = id; rrt = ch[id][1]; ch[id][1] = fa[rrt] = 0; maintain(lrt); return ; } inline void splitl(int& lrt, int& rrt, int id) { splay(id); lrt = ch[id][0]; rrt = id; ch[id][0] = fa[lrt] = 0; maintain(rrt); return ; } inline int merge(int a, int b) { if(!a) return b; if(!b) return a; while(ch[a][1]) a = ch[a][1]; splay(a); ch[a][1] = b; fa[b] = a; return maintain(a), a; } inline void Insert(int pos, int mrt) { int lrt, rrt; split(lrt, rrt, pos); lrt = merge(lrt, mrt); merge(lrt, rrt); return ; } inline void Insert2(int pos, int mrt, int m2) { int lrt, rrt; split(lrt, rrt, pos); lrt = merge(lrt, mrt); lrt = merge(lrt, m2); merge(lrt, rrt); return ; } inline int Create(int o, int v) { dl[o] = ++tot; dr[o] = ++tot; ns[dl[o]] = Node(v); ns[dr[o]] = Node(0); ch[dl[o]][1] = dr[o]; fa[dr[o]] = dl[o]; maintain(dr[o]); maintain(dl[o]); return dl[o]; } int rt, last, ToT, to[maxn][maxa], par[maxn], Max[maxn]; void extend(int x) { int p = last, np = ++ToT; Max[np] = Max[p] + 1; last = np; int Np = Create(np, 1); while(p && !to[p][x]) to[p][x] = np, p = par[p]; if(!p){ par[np] = rt; Insert(dl[rt], Np); return ; } int q = to[p][x]; if(Max[q] == Max[p] + 1){ par[np] = q; Insert(dl[q], Np); return ; } int nq = ++ToT, Nq = Create(nq, 0); Max[nq] = Max[p] + 1; memcpy(to[nq], to[q], sizeof(to[q])); par[nq] = par[q]; Insert(dl[par[q]], Nq); par[q] = par[np] = nq; int lrt, Q, rrt; splitl(lrt, Q, dl[q]); split(Q, rrt, dr[q]); merge(lrt, rrt); Insert2(dl[nq], Q, Np); while(p && to[p][x] == q) to[p][x] = nq, p = par[p]; return ; } char cmd[10], S[maxn]; void decode(char* S, int mark) { int n = strlen(S); for(int i = 0; i < n; i++) { mark = (mark * 131 + i) % n; swap(S[mark], S[i]); } return ; } int main() { // freopen("data.in", "r", stdin); // freopen("data.out", "w", stdout); rt = last = ToT = 1; Rt = Create(rt, 0); int q = read(), mark = 0; char tc = Getchar(); while(!isalpha(tc)) tc = Getchar(); int n = 0; while(isalpha(tc)) S[n++] = tc, tc = Getchar(); for(int i = 0; i < n; i++) extend(S[i] - 'A'); while(q--) { while(!isalpha(tc)) tc = Getchar(); n = 0; while(isalpha(tc)) cmd[n++] = tc, tc = Getchar(); cmd[n] = 0; while(!isalpha(tc)) tc = Getchar(); n = 0; while(isalpha(tc)) S[n++] = tc, tc = Getchar(); S[n] = 0; decode(S, mark); if(!strcmp(cmd, "ADD")) for(int i = 0; i < n; i++) extend(S[i] - 'A'); else { int p = rt; n = strlen(S); for(int i = 0; i < n; i++) p = to[p][S[i]-'A']; if(!p){ puts("0"); continue; } int lrt, mrt, rrt; splitl(lrt, mrt, dl[p]); split(mrt, rrt, dr[p]); printf("%d ", ns[mrt].sum); mark ^= ns[mrt].sum; lrt = merge(lrt, mrt); merge(lrt, rrt); } } return 0; }
应该能看出来我卡常的痕迹。。。