题目描述
在神秘的东方有一棵奇葩的树,它有一个固定的根节点(编号为1)。树的每条边上都是一个字符,字符为a,b,c中的一个,你可以从树上的任意一个点出发,然后沿着远离根的边往下行走,在任意一个节点停止,将你经过的边的字符依次写下来,就能得到一个字符串,例如:
在这棵树中我们能够得到的字符串是:
c, cb, ca, a, b, a
现在pty得到了一棵树和一个字符串S。如果S的一个子串[l,r]和树上某条路径所得到的字符串完全相同,则我们称这个子串和该路径匹配。现在pty想知道,S的所有子串和树上的所有路径的匹配总数是多少?
输入格式
第一行:n
接下来n-1行,每行一个整数fa, 一个字符c,字符和整数之间用一个空格隔开
第i行fa代表第i号节点的父亲,c表示第i号节点和fa的连边的字符
最后一行为字符串S
输出格式
输出共一行,表示匹配总数,因为评测系统暂未确定,所以C/C++选手请使用cout输出。
样例
样例输入
5
1 c
2 b
1 a
2 a
cba
样例输出
5
数据范围与提示
【样例说明】
单个字符匹配的对数为4对,两个字符匹配的对数为1对:cb
solution
广义后缀自动机
其实是在trie树上建。
具体看
记录 val[i] 表示i这个状态所代表的所有串的出现次数和
val=(max-min)*right
记sum为i到根的val之和。
做匹配的时候,我们考虑新加入一个字符。
假设这个字符的匹配点是i,之前匹配长度为len。
那么这个字符的答案就是(len-s[fa].max)*ri[i]+s[fa].sum.
也就是该字符串在这个点的匹配长度,加上父亲点的sum,也就是以这个字符结尾的其他后缀的贡献。
1 格式化代码 2 #include<cstdio> 3 #include<iostream> 4 #include<cstdlib> 5 #include<cstring> 6 #include<algorithm> 7 #include<cmath> 8 #define maxn 800005 9 #define ll long long 10 using namespace std; 11 int n,id[maxn],rt=1,cnt=1,la=1,a[maxn]; 12 struct node{ 13 int nex[26],par,Max,ri; 14 ll val,sum; 15 }s[maxn*3]; 16 char ch[maxn*10]; 17 ll tax[maxn]; 18 void ins(int fa,int c,int i){ 19 20 int np=++cnt,p=id[fa];s[np].ri=1; 21 //cout<<i<<' '<<c<<' '<<p<<endl; 22 s[np].Max=s[la].Max+1;id[i]=la=np;//这里还不是很懂为什么是la 23 24 for(;p&&!s[p].nex[c];p=s[p].par)s[p].nex[c]=np; 25 if(!p)s[np].par=rt; 26 else { 27 int q=s[p].nex[c]; 28 if(s[q].Max==s[p].Max+1)s[np].par=q; 29 else { 30 int nq=++cnt; 31 for(int i=0;i<26;i++)s[nq].nex[i]=s[q].nex[i]; 32 s[nq].par=s[q].par;s[nq].Max=s[p].Max+1; 33 s[q].par=s[np].par=nq; 34 for(;s[p].nex[c]==q;p=s[p].par)s[p].nex[c]=nq; 35 } 36 } 37 } 38 void Calc(){ 39 for(int i=1;i<=cnt;i++)tax[s[i].Max]++; 40 for(int i=1;i<=cnt;i++)tax[i]+=tax[i-1]; 41 for(int i=1;i<=cnt;i++)a[tax[s[i].Max]--]=i; 42 for(int i=cnt;i>=1;i--){ 43 int k=a[i],f=s[k].par; 44 s[f].ri+=s[k].ri; 45 s[k].val=s[k].ri*(s[k].Max-s[f].Max); 46 } 47 for(int i=1;i<=cnt;i++){ 48 int k=a[i]; 49 s[k].sum=s[k].val+s[s[k].par].sum; 50 } 51 } 52 int main(){ 53 cin>>n;id[1]=1; 54 for(int i=2,fa;i<=n;i++){ 55 char c;scanf("%d %c",&fa,&c); 56 ins(fa,c-'a',i); 57 } 58 Calc(); 59 scanf("%s",ch);int l=strlen(ch),len=0; 60 int p=rt;ll ans=0; 61 for(int i=0;i<l;i++){ 62 for(;!s[p].nex[ch[i]-'a']&&p;p=s[p].par,len=s[p].Max); 63 if(!p)p=rt,len=0; 64 else { 65 p=s[p].nex[ch[i]-'a'];len++; 66 int f=s[p].par; 67 ans=ans+s[f].sum+(len-s[f].Max)*s[p].ri; 68 } 69 } 70 cout<<ans<<endl; 71 return 0; 72 }