给n<=300000的树,每个点上有一个字母,一个点的权值为:从该点出发向下走到任意节点停下形成的不同字符串的数量,问最大权值。
题目本身还有一些奇怪要求在此忽略。。
Trie合并的模板题。
1 #include<stdio.h> 2 #include<string.h> 3 #include<stdlib.h> 4 #include<math.h> 5 //#include<queue> 6 #include<algorithm> 7 #include<iostream> 8 using namespace std; 9 10 bool isdigit(char c) {return c>='0' && c<='9';} 11 int qread() 12 { 13 char c;int s=0,t=1;while (!isdigit(c=getchar())) (c=='-' && (t=-1)); 14 do s=s*10+c-'0'; while (isdigit(c=getchar())); return s*t; 15 } 16 17 int n; 18 #define maxn 300011 19 struct Edge{int to,next;}edge[maxn<<1];int first[maxn],le=2; 20 void in(int x,int y) {Edge &e=edge[le];e.to=y;e.next=first[x];first[x]=le++;} 21 void insert(int x,int y) {in(x,y);in(y,x);} 22 23 int c[maxn]; 24 char s[maxn]; 25 struct Trie 26 { 27 int ch[maxn<<1][26],size,val[maxn<<1]; 28 Trie() {memset(ch[0],0,sizeof(ch[0]));size=0;} 29 int id(char c) {return c-'a';} 30 void up(int x) 31 { 32 val[x]=1; 33 for (int i=0;i<26;i++) 34 if (ch[x][i]) val[x]+=val[ch[x][i]]; 35 } 36 int New(char c) 37 { 38 size++;memset(ch[size],0,sizeof(ch[size])); 39 val[size]=2; 40 size++;memset(ch[size],0,sizeof(ch[size])); 41 ch[size-1][id(c)]=size; 42 val[size]=1; 43 return size-1; 44 } 45 int combine(int x,int y) 46 { 47 if (!x || !y) return x+y; 48 for (int i=0;i<26;i++) 49 ch[x][i]=combine(ch[x][i],ch[y][i]); 50 up(x); 51 return x; 52 } 53 }t; 54 55 int root[maxn],dif[maxn]; 56 inline void dfs(int x,int fa) 57 { 58 root[x]=t.New(s[x]); 59 int base=0; 60 for (int i=first[x];i;i=edge[i].next) 61 { 62 const Edge &e=edge[i]; if (e.to==fa) continue; 63 dfs(e.to,x); 64 if (!base) base=e.to; 65 else root[base]=t.combine(root[base],root[e.to]); 66 } 67 if (base) 68 { 69 int u=t.ch[root[x]][t.id(s[x])]; 70 for (int i=0;i<26;i++) t.ch[u][i]=t.ch[root[base]][i]; 71 t.up(u);t.up(root[x]); 72 } 73 dif[x]=t.val[root[x]]-1; 74 // cout<<x<<' '<<dif[x]<<endl; 75 } 76 77 int main() 78 { 79 n=qread(); 80 for (int i=1;i<=n;i++) c[i]=qread(); 81 scanf("%s",s+1); 82 for (int i=1,x,y;i<n;i++) 83 { 84 x=qread(),y=qread(); 85 insert(x,y); 86 } 87 dfs(1,0); 88 int ans=0,cnt=0; 89 for (int i=1;i<=n;i++) 90 { 91 if (c[i]+dif[i]>ans) ans=c[i]+dif[i],cnt=1; 92 else if (c[i]+dif[i]==ans) cnt++; 93 } 94 printf("%d %d ",ans,cnt); 95 return 0; 96 }