SAM+LCT模板题。
题目相当于求询问串在SAM上走到的状态的right集合(endpos集合)大小;若询问串在SAM上走不完,答案为0。right集合大小等于parent树上该状态子树中前缀节点的个数——前缀节点就是每次extend新建的np节点(非克隆节点),每个前缀节点恰好贡献一个结束位置。所以令前缀节点权值为1、克隆节点权值为0,答案就是子树权值和。
因为强制在线(每个操作串都要用mask解码,mask依赖上一次询问的答案),而且克隆节点时parent树的结构会改变(需要cut/link),所以parent树用LCT维护。注意这里LCT维护的是子树和,需要为每个节点额外记录虚儿子的子树和。代码还是非常好写的,也非常好想。
#include <bits/stdc++.h>
using namespace std;
// Capacity bound: size of the token buffer below, and (doubled) of the
// SAM state array and the LCT node array.
constexpr int N = 601000;
// Shared buffer holding the current input token (the initial string,
// then each decoded operand).
char str[N];
// Read the next whitespace-delimited token into `str` and unscramble it in
// place. For each position i the running key becomes (key * 131 + i) mod
// length, and str[i] is swapped with str[key].
// `p` is the starting key (the caller passes its mask). It is taken by value,
// so the caller's mask is never modified.
// NOTE(review): scanf("%s") has no width limit, so `str` must hold the longest
// operand; p * 131 is assumed to fit in int for the initial key. Confirm both
// against the input limits.
void getstr(int p){
    scanf("%s", str);
    const int n = strlen(str);
    for (int i = 0; i < n; ++i){
        p = (p * 131 + i) % n;
        const char tmp = str[i];
        str[i] = str[p];
        str[p] = tmp;
    }
}
// One link-cut-tree vertex.
//   son[0], son[1] : left / right child in the splay tree (0 = none)
//   fa             : splay parent, or the path-parent pointer when this node
//                    is the root of its splay tree (0 = none)
//   sum1           : this vertex's own weight
//   sum2           : total weight hanging below this vertex through virtual
//                    (non-preferred) children
//   sum            : sum1 + sum2 + sum of both splay children (kept by Lct::update)
// Every member has a zero default, so a default-constructed Node is an
// isolated weight-0 vertex even outside static storage.
struct Node{
    int son[2]{};
    int fa{};
    int sum{};
    int sum1{};
    int sum2{};
    Node() = default;
    // Isolated vertex with splay/path parent `_fa` and own weight `_sum`.
    Node(int _fa, int _sum) : fa(_fa), sum(_sum), sum1(_sum) {}
};
// Link-cut tree over the SAM parent tree that maintains subtree weight sums,
// including weight reachable through virtual children (Node::sum2).
// There is no reversal / make-root operation and no lazy tag, so the
// represented tree keeps its root and splay needs no push-down.
// Index 0 is the null vertex; the code relies on nod[0] never being written.
struct Lct{
Node nod[N<<1];
// Reset vertex p to an isolated vertex of weight d.
void insert(int p,int d){
nod[p] = Node(0,d);
}
// Make x a child of y.
// Precondition (as used by extend): x is the root of its represented tree and
// x, y are in different trees.
void link(int x,int y){
access(x);
access(y);
splay(x);
// access(y) leaves the top of the root path as splay root; splay(y) brings y
// there, so y's aggregate is the only one that must absorb x's tree weight.
splay(y);
// x becomes a virtual child of y: its whole-tree sum enters y's sum2.
nod[x].fa = y;
nod[y].sum2 += nod[x].sum;
nod[y].sum += nod[x].sum;
}
// Detach x (with its subtree) from its parent; no-op if x is already a root.
void cut(int x){
access(x);
splay(x);
// After access+splay, x's left splay subtree is exactly the path from the
// root to x's parent; dropping it severs the parent edge.
int u = nod[x].son[0];
if (u){
nod[x].son[0] = 0;
update(x);
nod[u].fa = 0;
}
}
// Make the path from the tree root to x the preferred path.
void access(int x){
int y = 0;
while (x){
splay(x);
// Old right (preferred) child becomes virtual: its weight moves into sum2.
int u = nod[x].son[1];
if (u) nod[x].sum2 += nod[u].sum;
// The splay processed in the previous step (y) becomes x's preferred child,
// so its weight leaves sum2. y's fa is the path-parent x whenever y != 0;
// for y == 0 the test is false because nod[0].fa stays 0.
if (nod[y].fa == x) nod[x].sum2 -= nod[y].sum;
nod[x].son[1] = y;
update(x);
y = x;x = nod[x].fa;
}
}
// Rotate x up to the root of its splay tree (zig-zig rotates the parent first).
void splay(int x){
int w;
while ((w = check(x)) != -1){
int y = nod[x].fa;
if (w == check(y)) rotate(y,w^1);
rotate(x,w^1);
}
}
// Rotate x above its splay parent y; afterwards y is x's child on side d.
// check(y) is read before relinking, so z's child slot is updated only when y
// was a real splay child; otherwise x simply inherits y's path-parent pointer.
void rotate(int x,int d){
int y = nod[x].fa,z = nod[y].fa,w = check(y);
nod[x].fa = z;
if (w != -1) nod[z].son[w] = x;
nod[y].son[d^1] = nod[x].son[d];
if (nod[x].son[d]) nod[nod[x].son[d]].fa = y;
nod[y].fa = x;
nod[x].son[d] = y;
update(y);
update(x);
}
// Recompute p's aggregate: own weight + virtual weight + both splay children.
void update(int p){
int u = nod[p].son[0],v = nod[p].son[1];
nod[p].sum = nod[p].sum1+nod[p].sum2;
if (u) nod[p].sum += nod[u].sum;
if (v) nod[p].sum += nod[v].sum;
}
// Which child x is of its splay parent: 0 (left), 1 (right), or -1 when x is
// a splay root (fa is 0 or only a path-parent pointer).
int check(int x){
int y = nod[x].fa;
if (!y) return -1;
if (nod[y].son[0] == x) return 0;
if (nod[y].son[1] == x) return 1;
return -1;
}
// Total weight of p's subtree in the represented tree; 0 for the null vertex.
// After access+splay, p has no right splay child and its left splay subtree
// holds only its ancestors, so subtracting that part leaves sum1 + sum2
// (nod[0].sum is 0 when the left child is absent).
int getans(int p){
if (p == 0) return 0;
access(p);
splay(p);
return nod[p].sum-nod[nod[p].son[0]].sum;
}
}lct;
// One suffix-automaton state.
//   go[c] : transition on letter 'A' + c (0 = no transition)
//   par   : suffix link, i.e. parent in the parent tree (0 = none)
//   val   : length of the longest string in this state
// Members default to zero, so a default-constructed State is a valid empty
// state even outside static storage (the old State(){} left them indeterminate).
struct State{
    int go[26]{};
    int par{};
    int val{};
    State() = default;
    // Fresh state with no transitions, no suffix link and longest length _val.
    State(int _val) : par(0), val(_val) {}
}state[N<<1];
// Suffix-automaton bookkeeping: state 1 is the root, `len` is the highest
// state index allocated so far, and `last` is the state reached by the whole
// string built so far.
int root = 1;
int len = 1;
int last = 1;
// Query loop state: remaining operation count, decoding mask, last answer.
int q;
int mask;
int lastans;
// Operation keyword read from input; main only inspects opt[0].
char opt[10];
void extend(int);
int trans();
// Reads the operation count and the initial string and builds the automaton.
// It then processes each operation:
//   opt[0] == 'A' : decode the operand with the current mask and append it;
//   otherwise     : decode the operand, print its occurrence count (the LCT
//                   subtree sum of the state it reaches) and XOR it into mask.
// Stops early if the input ends before all operations are read.
int main(){
    if (scanf("%d", &q) != 1) return 0;
    if (scanf("%s", str) != 1) return 0;
    const int initLen = static_cast<int>(strlen(str));
    for (int i = 0; i < initLen; i++)
        extend(str[i] - 'A');
    while (q--){
        if (scanf("%s", opt) != 1) break;
        if (opt[0] == 'A'){
            getstr(mask);
            const int addLen = static_cast<int>(strlen(str));
            for (int i = 0; i < addLen; i++)
                extend(str[i] - 'A');
        }
        else{
            getstr(mask);
            lastans = lct.getans(trans());
            // Format was previously split across two source lines, which is
            // an unterminated string literal; "\n" is the intended newline.
            printf("%d\n", lastans);
            mask ^= lastans;
        }
    }
    return 0;
}
// Append letter index `w` (0 means 'A') to the suffix automaton.
// Every suffix-link change (state[].par) is mirrored in `lct`, so the LCT
// always holds the current parent tree. New prefix states get weight 1 and
// clones weight 0, which makes an LCT subtree sum an end-position count.
void extend(int w){
    const int cur = ++len;
    state[cur] = State(state[last].val + 1);
    lct.insert(cur, 1);

    int walker = last;
    // Nothing below reads `last` again, so it can be advanced now.
    last = cur;

    // Walk the suffix-link chain, giving every state without a `w` edge an
    // edge to `cur`, until a state that already has one is found.
    for (; walker != 0 && state[walker].go[w] == 0; walker = state[walker].par)
        state[walker].go[w] = cur;

    if (walker == 0){
        // No existing state continues with `w`: hang `cur` off the root.
        state[cur].par = root;
        lct.link(cur, root);
        return;
    }

    const int next = state[walker].go[w];
    if (state[next].val == state[walker].val + 1){
        // Solid transition: `next` is directly the suffix link of `cur`.
        state[cur].par = next;
        lct.link(cur, next);
        return;
    }

    // Split `next`: a weight-0 clone takes over its shorter strings and its
    // place in the parent tree; `next` and `cur` both hang below the clone.
    const int clone = ++len;
    state[clone] = State(state[walker].val + 1);
    lct.insert(clone, 0);
    std::copy(std::begin(state[next].go), std::end(state[next].go), state[clone].go);
    state[clone].par = state[next].par;
    state[next].par = clone;
    state[cur].par = clone;
    lct.cut(next);
    lct.link(clone, state[clone].par);
    lct.link(next, clone);
    lct.link(cur, clone);
    // Redirect the chain's `w` edges that pointed at `next` to the clone.
    for (; walker != 0 && state[walker].go[w] == next; walker = state[walker].par)
        state[walker].go[w] = clone;
}
// Follow the current contents of `str` from the root along SAM transitions.
// Returns the state reached, or 0 as soon as a transition is missing (the
// string does not occur), so lct.getans(0) then reports 0 occurrences.
// Stopping early avoids walking the rest of a long operand through null
// state 0, and no longer relies on state[0] staying all-zero.
int trans(){
    const int n = static_cast<int>(strlen(str));
    // Number of transition slots per state, kept in sync with State::go.
    const int alpha = static_cast<int>(sizeof(state[0].go) / sizeof(state[0].go[0]));
    int p = root;
    for (int i = 0; i < n; i++){
        const int c = str[i] - 'A';
        // A character with no transition slot cannot be matched; treat it as
        // "not found" instead of indexing go[] out of bounds.
        if (c < 0 || c >= alpha) return 0;
        p = state[p].go[c];
        if (p == 0) return 0;
    }
    return p;
}