SAM hot tea %%%%%%%
首先我们显然可以将所有能够得到的字符串分成六类:(varnothing, ext{*},s, ext{*}s,s ext{*},s ext{*}t),其中 (s,t) 分别代表某个非空字符串,( ext{*}) 则代表题目中的星号,显然前两种情况的贡献都是 (1),算出后几种情况的答案后直接加 (2) 即可,第三种情况也异常简单,相当于求 (s) 中本质不同的子串个数,SAM 基操,相信来挑战这道题的人都会。第四种情况可以看作在某个 (s[1…n-1]) 的子串后面添上一个 ( ext{*}),因此第四种情况合法的子串数就是 (s[1…n-1]) 本质不同子串个数,这个一波 SAM 带走即可,同理第五种情况合法的子串数就是 (s[2…n]) 本质不同子串数,同样一波 SAM 带走,比较棘手的是第六种情况。
由于我们要求本质不同子串个数,因此枚举星号代替了原字符串中的哪个字符是不明智的,因为这样会算重。因此我们考虑退一步,我们不是刚才对 (s[1…n-1]) 建了一个 SAM 吗,那么我们就考虑枚举星号左边那部分对应的子串在 SAM 上哪个节点对应的字符串集合中,假设为 (x),那么根据 SAM 那一套理论,这个字符串所有出现位置的右端点都在 (x) 对应的 ( ext{endpos}) 中,也就是 (x) 在 parent tree 上子树内 (s) 所有前缀对应的节点。换句话说,所有星号可以替换的位置都在集合 ({y+1|yin ext{endpos}(x)}) 中,考虑星号右边的字符串有哪些取值,那么显然,星号右边的字符串必须是这些星号可以放的位置右边的字符组成的后缀的某个前缀,也就是 (S={s[y+2…n]|yin ext{endpos}(x)}) 中某个字符串的某个前缀,也就是说我们只需求 (S={s[y+2…n]|yin ext{endpos}(x)}) 有多少个本质不同的前缀即可。直接枚举前缀显然不可取,不过注意到如果我们对反串也建一个 SAM,那么反串的 SAM 上某个节点 (z) 表示的字符串符合要求,当且仅当存在一个 (yin ext{endpos}(x)),满足 (s[y+2…n]) 在反串的 SAM 上对应的节点在 (z) 的子树内,也就是说如果我们找出这样的 (| ext{endpos}(x)|) 个后缀在反串 SAM 上对应的节点,那么所有符合条件的点组成的集合就是它们到根节点路径上节点的并,而我们要求的,就是这个并中所有节点表示的字符串集合的大小之和。然后到这里我就不会做了,一直在想怎么剖。事实上这东西长得一脸虚树,根据虚树那一套理论,这东西节点并的字符串集合的大小,就是它们按照 DFS 序排序之后,DFS 序列相邻两个点+一头一尾路径上点的权值之和扣掉它们的 LCA 的权值的和除以 (2)(在本题中就是它们对应的等价类的大小,即 (len_x-len_{lnk_x})),外加它们的 LCA 到根节点这段路径上权值之和。然后我们考虑线段树合并维护这货,具体来说就开一棵线段树,以 DFS 序为下标,如果一个叶节点存在则表示其表示的区间的 DFS 序在该节点的子树内存在,然后对于线段树上每个节点,维护该区间中,在 SAM 上存在的 DFS 序的最小值、最大值,以及两两之间的距离和,每次转移时将一个点为根的线段树与其儿子节点表示的线段树合并一波即可算出答案。初始状态大概就对于所有 (iin[1,n-2]),我们找出长度为 (i) 的前缀在原串 SAM 上表示的节点,设其为 (x),然后假设 (s[i+2…n]) 在反串上表示的节点为 (y),那么我们就将 (y) 在反串上的 DFS 序加入 (x) 表示的线段树即可。使用 ST 表求 LCA 可以实现 (mathcal O(nlog n)) 的复杂度。
const int MAXN=1e5;
const int MAXP=2e5;
const int MAX_ND=MAXP<<5;
const int LOG_N=19;
char str[MAXN+5];int n;
struct graph{
int hd[MAXP+5],to[MAXP+5],nxt[MAXP+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
};
struct SAM{
struct node{int ch[27],len,lnk;} s[MAXP+5];
int cur,ncnt,ed[MAXN+5];graph t;
SAM(){cur=ncnt=1;}
void extend(char c,int ps){
int id=c-'a',nw=++ncnt,p=cur;
s[nw].len=s[cur].len+1;ed[ps]=nw;cur=nw;
while(p&&!s[p].ch[id]) s[p].ch[id]=nw,p=s[p].lnk;
if(!p) return s[nw].lnk=1,void();
int q=s[p].ch[id];
if(s[q].len==s[p].len+1) return s[nw].lnk=q,void();
int cl=++ncnt;s[cl].len=s[p].len+1;
s[cl].lnk=s[q].lnk;s[q].lnk=cl;s[nw].lnk=cl;
for(int i=0;i<26;i++) s[cl].ch[i]=s[q].ch[i];
while(p&&s[p].ch[id]==q) s[p].ch[id]=cl,p=s[p].lnk;
}
void build(){for(int i=2;i<=ncnt;i++) t.adde(s[i].lnk,i);}
ll calc(){
ll res=0;
for(int i=1;i<=ncnt;i++) res+=s[i].len-s[s[i].lnk].len;
return res;
}
} s_tot,fw,bk;//forward backward
#define g1 fw.t
#define g2 bk.t
int dep[MAXP+5];pii st[MAXP*2+5][LOG_N+2];
int dfn[MAXP+5],rid[MAXP+5],tim=0,eul[MAXP+5],tim_eul=0;
void dfs0(int x,int f){
rid[dfn[x]=++tim]=x;
st[eul[x]=++tim_eul][0]=mp(dep[x],x);
for(int e=g2.hd[x];e;e=g2.nxt[e]){
int y=g2.to[e];if(y==f) continue;
dep[y]=dep[x]+1;dfs0(y,x);
st[eul[x]=++tim_eul][0]=mp(dep[x],x);
}
}
int getlca(int x,int y){
x=eul[x];y=eul[y];if(x>y) swap(x,y);
int k=31-__builtin_clz(y-x+1);
return min(st[x][k],st[y-(1<<k)+1][k]).se;
}
int getdis(int x,int y){
int lc=getlca(x,y);
return bk.s[x].len+bk.s[y].len-(bk.s[lc].len<<1);
}
struct data{
int lc,lft,rit;ll sum;
data(int _lc=0,int _lft=0,int _rit=0,ll _sum=0):
lc(_lc),lft(_lft),rit(_rit),sum(_sum){}
data operator +(const data &rhs){
if(!lc) return rhs;
if(!rhs.lc) return data(lc,lft,rit,sum);
return data(getlca(lc,rhs.lc),lft,rhs.rit,
sum+rhs.sum+getdis(rid[rit],rid[rhs.lft]));
}
};
namespace segtree{
struct node{int ch[2];data d;} s[MAX_ND+5];
int ncnt=0;
void pushup(int k){s[k].d=s[s[k].ch[0]].d+s[s[k].ch[1]].d;}
void insert(int &k,int l,int r,int x){
if(!k) k=++ncnt;if(l==r) return s[k].d=data(rid[x],x,x,0),void();
int mid=l+r>>1;
(x<=mid)?insert(s[k].ch[0],l,mid,x):insert(s[k].ch[1],mid+1,r,x);
pushup(k);
}
int merge(int x,int y,int l,int r){
if(!x||!y) return x+y;int z=++ncnt;
if(l==r) return s[z].d=data(rid[l],l,l,0),z;
int mid=l+r>>1;
s[z].ch[0]=merge(s[x].ch[0],s[y].ch[0],l,mid);
s[z].ch[1]=merge(s[x].ch[1],s[y].ch[1],mid+1,r);
return pushup(z),z;
}
}
using segtree::insert;
using segtree::merge;
int rt[MAXP+5];ll res=0;
void dfs(int x,int f){
for(int e=g1.hd[x];e;e=g1.nxt[e]){
int y=g1.to[e];if(y==f) continue;
dfs(y,x);rt[x]=merge(rt[x],rt[y],1,bk.ncnt);
} data d=segtree::s[rt[x]].d;
if(d.lc) res+=1ll*(fw.s[x].len-fw.s[fw.s[x].lnk].len)*
(d.sum+getdis(rid[d.lft],rid[d.rit])+(getdis(d.lc,1)<<1));
}
int main(){
scanf("%s",str+1);n=strlen(str+1);
for(int i=1;i<=n;i++) s_tot.extend(str[i],/*19260817*/i);
for(int i=1;i<n;i++) fw.extend(str[i],i);
for(int i=n;i>1;i--) bk.extend(str[i],i);
fw.build();bk.build();
dfs0(1,0);
for(int i=1;i<=LOG_N;i++) for(int j=1;j+(1<<i)-1<=tim_eul;j++)
st[j][i]=min(st[j][i-1],st[j+(1<<i-1)][i-1]);
for(int i=1;i<n-1;i++) insert(rt[fw.ed[i]],1,bk.ncnt,dfn[bk.ed[i+2]]);
dfs(1,0);res>>=1;res+=2+s_tot.calc()+fw.calc()+bk.calc();
printf("%lld
",res);
return 0;
}