传送门
从题意可以看出,要从每个叶节点开始遍历给定的树,然后在这个过程中建广义SAM,最后统计不同子串数量。
注意下一边遍历树一边建广义 SAM 的方法,把父节点在 SAM 上的位置作为 last。
还要注意一下这个广义 SAM 的写法,特判一下已经存在的情况。
其实不特判基本上不会出问题,但是毕竟还是不标准。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=1e5+10;
const int M=1e6+10;
int n,k,clr[N],deg[N];
int head[N],to[N*2],nxt[N*2],tot;
void add(int u,int v){to[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
struct SuffixAutoMachine{
int tot=1,fa[M*2],len[M*2],ch[M*2][10];
int newnode(int x){fa[++tot]=fa[x];len[tot]=len[x];memcpy(ch[tot],ch[x],sizeof(ch[tot]));return tot;}
int extend(int p,int c){
int q=ch[p][c],nq=newnode(q);
len[nq]=len[p]+1;fa[q]=nq;
for(;p&&ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
return nq;
}
int append(int p,int c){
if(ch[p][c]) if(len[ch[p][c]]==len[p]+1) return ch[p][c];else return extend(p,c);
int np=newnode(0);len[np]=len[p]+1;
for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=np;
if(!p) {fa[np]=1;return np;}
if(len[ch[p][c]]==len[p]+1) fa[np]=ch[p][c];
else fa[np]=extend(p,c);
return np;
}
void solve(){
LL ans=0;
for(int i=2;i<=tot;i++) ans+=len[i]-len[fa[i]];
printf("%lld
",ans);
}
}sam;
void dfs(int u,int fa,int p){
p=sam.append(p,clr[u]);
for(int i=head[u];i;i=nxt[i]) if(to[i]!=fa) dfs(to[i],u,p);
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++) scanf("%d",&clr[i]);
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v);add(v,u);deg[u]++;deg[v]++;
}
for(int i=1;i<=n;i++) if(deg[i]==1) dfs(i,0,1);
sam.solve();
return 0;
}