人生的第一道树分治,要是早点学我南京赛就不用那么挫了,树分治的思路其实很简单,就是对子树找到一个重心(Centroid),实现重心分解,然后递归的解决分开后的树的子问题,关键是合并,当要合并跨过重心的两棵子树的时候,需要有一个接近O(n)的方法,因为f(n)=kf(n/k)+O(n)解出来才是O(nlogn).在这个题目里其实就是将第一棵子树的集合里的每个元素,判下有没符合条件的,有就加上,然后将子树集合压进大集合,然后继续搞第二棵乃至第n棵.我的过程用了map,合并是nlogn的所以代码速度颇慢,大概6s,题目时限10s,可以改成hash应该会快许多,毕竟用map实在太慢,用vector也可以,具体可以参见挑战程序设计竞赛代码.下面的代码查找重心用了挑战的代码.
#pragma comment(linker, "/STACK:102400000,102400000") #include<iostream> #include<cstring> #include<string> #include<cstdio> #include<algorithm> #include<map> #include<vector> #define maxv 50000 #define ll long long using namespace std; int n,k; vector<int> G[maxv+50]; ll val[maxv+50]; ll prime[maxv+50]; ll convert_three(ll v) { ll bas=1;ll res=0; for(int i=0;i<k;++i){ int num=0; while(v%prime[i]==0){ v/=prime[i]; num++; } num%=3;res+=num*bas; bas*=3; } return res; } ll xor(ll x,ll y) { ll res=0;ll bas=1; for(int i=0;i<k;++i){ res+=((x%3)+(y%3))%3*bas; x/=3;y/=3; bas*=3; } return res; } ll inv(ll x) { ll res=0;ll bas=1; for(int i=0;i<k;++i){ res+=((3-(x%3))%3)*bas; x/=3; bas*=3; } return res; } void print(ll x){ while(x){ cout<<x%3; x/=3; } cout<<endl; } bool centroid[maxv+50]; int ssize[maxv+50]; int ans; map<ll,int> sta; map<ll,int>::iterator it; int compute_ssize(int v,int p) { int c=1; for(int i=0;i<G[v].size();++i){ int w=G[v][i]; if(w==p||centroid[w]) continue; c+=compute_ssize(G[v][i],v); } ssize[v]=c; return c; } pair<int,int> search_centroid(int v,int p,int t) { pair<int,int> res=make_pair(INT_MAX,-1); int s=1,m=0; for(int i=0;i<G[v].size();++i){ int w=G[v][i]; if(w==p||centroid[w]) continue; res=min(res,search_centroid(w,v,t)); m=max(m,ssize[w]); s+=ssize[w]; } m=max(m,t-s); res=min(res,make_pair(m,v)); return res; } void enumerate_mul(int v,int p,ll d,map<ll,int> &ds) { if(ds.count(d)) ds[d]++; else ds[d]=1; for(int i=0;i<G[v].size();++i){ int w=G[v][i]; if(w==p||centroid[w]) continue; enumerate_mul(w,v,xor(d,val[w]),ds); } } void solve(int v) { compute_ssize(v,-1); int s=search_centroid(v,-1,ssize[v]).second; centroid[s]=true; for(int i=0;i<G[s].size();++i){ if(centroid[G[s][i]]) continue; solve(G[s][i]); } sta.clear(); sta[val[s]]=1;map<ll,int> tds; for(int i=0;i<G[s].size();++i){ if(centroid[G[s][i]]) continue; tds.clear(); enumerate_mul(G[s][i],s,val[G[s][i]],tds); it=tds.begin(); while(it!=tds.end()){ ll rev=inv((*it).first); if(sta.count(rev)){ ans+=sta[rev]*(*it).second; } ++it; } it=tds.begin(); while(it!=tds.end()){ ll vv=xor((*it).first,val[s]); if(sta.count(vv)){ sta[vv]+=(*it).second; } else{ sta[vv]=(*it).second; } ++it; } } centroid[s]=false; } int main() { while(cin>>n>>k){ ans=0; for(int i=0;i<k;++i){ scanf("%I64d",&prime[i]); } G[0].clear(); for(int i=1;i<=n;++i){ scanf("%I64d",&val[i]); val[i]=convert_three(val[i]); if(val[i]==0) ans++; //print(val[i]); G[i].clear(); } int u,v; for(int i=0;i<n-1;++i){ scanf("%d%d",&u,&v); G[u].push_back(v); G[v].push_back(u); } memset(centroid,0,sizeof(centroid)); solve(1); printf("%d ",ans); } return 0; }