P4218 [CTSC2010]珠宝商
神题...
可以想到点分治,细节不写了。。。
(学了个新姿势,sam可以在前面加字符
但是一次点分治只能做到(O(m)),考虑(sqrt n)点分治,如果子树大小(>sqrt n)就用(O(m))的点分治做法,否则用蛤希暴力。
然而块大小设为(20,30)比(sqrt n)快多了...
#include<bits/stdc++.h>
#define il inline
#define vd void
#define frog 19260817
typedef long long ll;
typedef unsigned long long ull;
il ll gi(){
ll x=0,f=1;
char ch=getchar();
while(!isdigit(ch))f^=ch=='-',ch=getchar();
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
#define qt 20
std::unordered_map<ull,int>qaq[50010/qt*2+1];
ull Base[50010];
int n,m;
ll ans;
char S[50010],T[50010];
int fir[50010],dis[100010],nxt[100010],id;
il vd link(int a,int b){nxt[++id]=fir[a],fir[a]=id,dis[id]=b;}
int siz[50010],f[50010],FA[50010],N,rt;bool vis[50010];
struct SAM{
int slink[100010],trans[100010][26],pos[100010],son[100010][26],len[100010],endpos[100010],cnt,lst,leaf[50010],flg;
char ss[100010];
SAM(){cnt=0;lst=++cnt;len[lst]=0;}
il vd extend(int ch,int i){
int p=lst,np=++cnt;len[np]=len[p]+1;lst=np;endpos[np]=1;leaf[i]=np;pos[np]=i;
while(p&&!trans[p][ch])trans[p][ch]=np,p=slink[p];
if(!p)slink[np]=1;
else{
int q=trans[p][ch];
if(len[q]==len[p]+1)slink[np]=q;
else{
int nq=++cnt;
slink[nq]=slink[q],len[nq]=len[p]+1,memcpy(trans[nq],trans[q],sizeof trans[q]);
while(p&&trans[p][ch]==q)trans[p][ch]=nq,p=slink[p];
slink[np]=slink[q]=nq;
}
}
}
int t[100010],st[100010];
il vd prepare(){
for(int i=1;i<=m;++i)ss[i]=S[i];
for(int i=1;i<=cnt;++i)++t[len[i]];
for(int i=1;i<=cnt;++i)t[i]+=t[i-1];
for(int i=cnt;i>1;--i)st[t[len[i]]--]=i;
for(int i=cnt,x;i;--i){
x=st[i];endpos[slink[x]]+=endpos[x];
if(!pos[slink[x]])pos[slink[x]]=pos[x];
if(pos[x]-len[slink[x]])son[slink[x]][ss[pos[x]-len[slink[x]]]-'a']=x;
}
}
int tag[100010];
il vd calc(){
for(int i=1,x;i<=cnt;++i)x=st[i],tag[x]+=tag[slink[x]];
}
il vd dfs3(int x,int fa,int y,int _len){
if(_len==len[y])y=son[y][T[x]-'a'];
else if(ss[pos[y]-_len]!=T[x])y=0;
if(!y)return;
++tag[y];++_len;
//printf("%d %d %d %d
",x,fa,y,_len);
for(int i=fir[x];i;i=nxt[i]){
if(fa==dis[i]||vis[dis[i]])continue;
dfs3(dis[i],x,y,_len);
}
}
}sam,rsam;
il vd getrt(int x,int fa=-1){
siz[x]=1,f[x]=0;
for(int i=fir[x];i;i=nxt[i]){
if(fa==dis[i]||vis[dis[i]])continue;
FA[dis[i]]=x;
getrt(dis[i],x);
siz[x]+=siz[dis[i]];
f[x]=std::max(f[x],siz[dis[i]]);
}
f[x]=std::max(f[x],N-siz[x]);
if(f[rt]>f[x])rt=x;
}
std::vector<int>G;
il vd dfs(int x,int fa=-1){
G.push_back(x);
for(int i=fir[x];i;i=nxt[i]){
if(fa==dis[i]||vis[dis[i]])continue;
dfs(dis[i],x);
}
}
il vd dfs2(int x,int y,int fa=-1){
if(!y)return;
ans+=sam.endpos[y];
for(int i=fir[x];i;i=nxt[i]){
if(fa==dis[i]||vis[dis[i]])continue;
dfs2(dis[i],sam.trans[y][T[dis[i]]-'a'],x);
}
}
std::vector<ull>A,B;
std::vector<int>LA,LB;
il vd dfs2_(int x,ull HA,ull HB,int len,int fa=-1){
HA=(HA+T[x])*frog,HB+=Base[++len]*T[x];
A.push_back(HA),LA.push_back(len+1),B.push_back(HB),LB.push_back(len);
for(int i=fir[x];i;i=nxt[i]){
if(fa==dis[i]||vis[dis[i]])continue;
dfs2_(dis[i],HA,HB,len,x);
}
}
il vd work(int x,int fa,ll o){
if(siz[x]<=qt){
A.clear(),B.clear();LA.clear();LB.clear();
dfs2_(x,(ull)frog*T[fa],0,0,fa);
for(int i=0;i<A.size();++i)
for(int j=0;j<B.size();++j){
ull H=A[i]+B[j]*Base[LA[i]];
if(qaq[LA[i]+LB[j]].count(H))ans+=o*qaq[LA[i]+LB[j]][H];
}
return;
}
memset(sam.tag,0,(sam.cnt+1)*4);memset(rsam.tag,0,(rsam.cnt+1)*4);
if(fa)sam.dfs3(x,fa,sam.son[1][T[fa]-'a'],1),rsam.dfs3(x,fa,rsam.son[1][T[fa]-'a'],1);
else sam.dfs3(x,fa,1,0),rsam.dfs3(x,fa,1,0);
sam.calc(),rsam.calc();
for(int i=1;i<=m;++i)ans+=o*sam.tag[sam.leaf[i]]*rsam.tag[rsam.leaf[m-i+1]];
}
il vd solve(int x){
if(siz[x]<=qt){
G.clear();dfs(x);
for(int i:G)dfs2(i,sam.trans[1][T[i]-'a']);
for(int i:G)vis[i]=1;
return;
}
work(x,0,1);
vis[x]=1;
for(int i=fir[x];i;i=nxt[i]){
if(vis[dis[i]])continue;
work(dis[i],x,-1);
}
for(int i=fir[x];i;i=nxt[i]){
if(vis[dis[i]])continue;
rt=0,N=siz[dis[i]],getrt(dis[i]),solve(rt);
}
}
int main(){
#ifdef XZZSB
freopen("in.in","r",stdin);
freopen("out.out","w",stdout);
#endif
sam.flg=1,rsam.flg=0;
n=gi(),m=gi();int a,b;
for(int i=1;i<n;++i)a=gi(),b=gi(),link(a,b),link(b,a);
scanf("%s",T+1),scanf("%s",S+1);
Base[0]=1;for(int i=1;i<=m;++i)Base[i]=Base[i-1]*frog;
for(int i=1;i<=m;++i){
ull Hash=0;
for(int j=1;i+j-1<=m&&j<=qt*2+1;++j)Hash+=Base[j]*S[i+j-1],++qaq[j][Hash];
}
for(int i=1;i<=m;++i)sam.extend(S[i]-'a',i);
sam.prepare();
std::reverse(S+1,S+m+1);
for(int i=1;i<=m;++i)rsam.extend(S[i]-'a',i);
rsam.prepare();
std::reverse(S+1,S+m+1);
N=n;f[0]=1e9,rt=0,getrt(1),solve(rt);
printf("%lld
",ans);
return 0;
}