题面
https://www.luogu.org/problem/P3346
题解
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>

using namespace std;

// Capacity of the per-vertex arrays (c[] and to[]); vertices are 1-based.
const int N = 100050;
// Width of one transition row: every colour must lie in [0, SIGMA).
const int SIGMA = 10;
// Capacity of the automaton's state arrays.
// NOTE(review): presumably sized for ~20 leaves * 2 states per inserted
// vertex plus slack -- confirm against the problem's constraints.
const int MAX_STATES = 42 * N;

typedef long long LL;

int n, m, c[N];       // vertex count, alphabet size as read from input, vertex colours
vector<int> to[N];    // adjacency lists of the undirected tree

// Suffix automaton shared by several insertion passes.
// State 0 is the "null" sentinel and state 1 is the root; the caller must set
// tot = 1 before the first extend().
struct SAM {
    int len[MAX_STATES];        // len[s]: length of the longest string of state s
    int ff[MAX_STATES];         // ff[s]: suffix link of state s (0 for the root)
    int ch[MAX_STATES][SIGMA];  // ch[s][x]: transition of s on symbol x, 0 if absent
    int tot;                    // index of the last allocated state

    // Make state nq a clone of q: same suffix link and same outgoing transitions.
    // The whole row is copied, so the clone does not depend on the global m and
    // never indexes past the SIGMA columns of ch.
    void copy(int nq, int q) {
        ff[nq] = ff[q];
        for (int i = 0; i < SIGMA; i++) ch[nq][i] = ch[q][i];
    }

    // Append symbol c after state p and return the state to continue from.
    // A fresh state is always allocated. When p already has a transition on c,
    // the fresh state ends up with len equal to that of its suffix link, so it
    // adds nothing to work().
    inline int extend(int p, int c) {
        int np = ++tot;
        len[np] = len[p] + 1;
        // Add the missing transitions along the suffix-link chain of p.
        while (p && !ch[p][c]) {
            ch[p][c] = np;
            p = ff[p];
        }
        if (!p) {
            // Ran off the chain: no proper suffix has a c-transition.
            ff[np] = 1;
            return np;
        }
        int q = ch[p][c];
        if (len[p] + 1 == len[q]) {
            ff[np] = q;
            return np;
        }
        // q is too long to serve as the suffix link: split off a clone of
        // length len[p] + 1 and redirect the transitions that pointed at q.
        int nq = ++tot;
        copy(nq, q);
        len[nq] = len[p] + 1;
        ff[q] = ff[np] = nq;
        while (p && ch[p][c] == q) {
            ch[p][c] = nq;
            p = ff[p];
        }
        return np;
    }

    // Sum of len[s] - len[ff[s]] over all allocated states, i.e. the number of
    // distinct non-empty strings the automaton recognises.
    LL work() {
        LL ret = 0;
        for (int i = 1; i <= tot; i++) ret += len[i] - len[ff[i]];
        return ret;
    }
} sam;

// Depth-first walk of the tree starting at `now` (reached from `ff`; pass -1
// for the start vertex). `pre` is the automaton state of the path walked so
// far; the colour of `now` is appended to it before descending.
// Recursion depth equals the longest path from the start vertex.
void build(int ff, int now, int pre) {
    int t = sam.extend(pre, c[now]);
    for (size_t i = 0; i < to[now].size(); i++) {
        int nxt = to[now][i];
        if (nxt != ff) build(now, nxt, t);
    }
}

// Reads n, m, the n colours and the n-1 edges, inserts every root-to-vertex
// path of the tree rooted at each leaf into the automaton, and prints the
// resulting count of distinct strings.
int main() {
    scanf("%d %d", &n, &m);
    sam.tot = 1;  // state 1 is the root
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        to[u].push_back(v);
        to[v].push_back(u);
    }
    // Start one pass from every leaf. "<= 1" also covers the single-vertex
    // tree, whose only vertex has degree 0 and would otherwise be skipped.
    for (int i = 1; i <= n; i++)
        if (to[i].size() <= 1) build(-1, i, 1);
    printf("%lld ", sam.work());
}