struct PAM {
static const int MAXN = 1e6 + 10;
char s[MAXN];
int ch[MAXN][26], len[MAXN], fail[MAXN], dep[MAXN];
int cnt, slen, lst;
int R[MAXN];
void Init() {
ms(ch[0]), len[0] = 0, fail[0] = 1, dep[0] = 0;
ms(ch[1]), len[1] = -1, fail[1] = 0, dep[1] = 0;
cnt = 1, slen = 0, lst = 0;
}
int NewNode() {
int o = ++cnt;
ms(ch[o]);
return o;
}
int Fail(int u) {
while(s[slen - len[u] - 1] != s[slen])
u = fail[u];
return u;
}
void Extend(char c) {
s[++slen] = c;
int u = Fail(lst), o = ch[u][c - 'a'];
if(o == 0) {
o = NewNode();
len[o] = len[u] + 2;
fail[o] = ch[Fail(fail[u])][c - 'a'];
dep[o] = dep[fail[o]] + 1;
ch[u][c - 'a'] = o;
}
lst = o;
R[slen] = dep[o];
}
} pam;
时间消耗较大,空间消耗较小的版本。
struct PAM {
static const int MAXN = 1e6 + 10;
char s[MAXN];
vector<pii> ch[MAXN];
int len[MAXN], fail[MAXN], dep[MAXN];
int cnt, slen, lst;
int R[MAXN];
void Init() {
ch[0].clear(), len[0] = 0, fail[0] = 1, dep[0] = 0;
ch[1].clear(), len[1] = -1, fail[1] = 0, dep[1] = 0;
cnt = 1, slen = 0, lst = 0;
}
int NewNode() {
int o = ++cnt;
ch[o].clear();
return o;
}
int Fail(int u) {
while(s[slen - len[u] - 1] != s[slen])
u = fail[u];
return u;
}
int getCh(int u, int c) {
for(int i = 0; i < ch[u].size(); ++i) {
if(ch[u][i].first == c)
return ch[u][i].second;
}
return 0;
}
void setCh(int u, int c, int o) {
ch[u].eb(c, o);
}
void Extend(char c) {
s[++slen] = c;
int u = Fail(lst), o = getCh(u, c);
if(o == 0) {
o = NewNode();
len[o] = len[u] + 2;
fail[o] = getCh(Fail(fail[u]), c);
dep[o] = dep[fail[o]] + 1;
setCh(u, c, o);
}
lst = o;
R[slen] = dep[o];
}
} pam;
加入一个新字符之后,pam的节点停留在lst,表示包含这个新字符的最长的回文后缀的节点,这个时候可以对lst进行一些统计频次之类的操作。
例:统计回文子串的数量:
pam中的一个节点代表一类本质相同的回文子串,那么Extend一个新字符之后,pam的lst就停留在新节点,然后多出来的回文子串就是lst的fail高度,由于pam的fail是不会变的,所以可以新增节点时直接继承自fail父亲的高度+1,即dep[o] = dep[fail[o]] + 1;
,然后,无论是否新增节点,统计以 (i) 位置为右端点的回文子串的数量,即 R[slen] = dep[o];
,然后全部把所有位置加起来就可以了。
统计相交的回文子串对的数量:
由上一例的方法统计回文子串的数量,然后求出不相交的回文子串对的数量,减去即可。问题转化成怎么求不相交的对,注意到一个以 (i) 位置为右端点的子串,和所有 (i+1,i+2,...,n) 位置为左端点的子串均不相交,所以先正序一次找出R,然后反序一次找出L,然后把L翻转之后从后往前dp。