题面
https://www.luogu.org/problem/CF161D
题解
#include<cstdio> #include<iostream> #include<algorithm> #include<vector> #include<cstring> using namespace std; int n,k,sum,p,root; long long ans=0; vector<int> to[50050]; int cnt[500],cnt0[500]; int maxson[50050],siz[50050],dis[50050]; int q[50050]; bool vis[50050]; void getroot(int now,int ff) { int i,l=to[now].size(); maxson[now]=0; siz[now]=1; for (i=0;i<l;i++) if (to[now][i]!=ff && !vis[to[now][i]]) { getroot(to[now][i],now); maxson[now]=max(maxson[now],siz[to[now][i]]); siz[now]+=siz[to[now][i]]; } maxson[now]=max(maxson[now],sum-siz[now]); if (maxson[now]<maxson[root]) root=now; } void getdis(int x,int ff) { int i,l=to[x].size(); for (i=0;i<l;i++) if (to[x][i]!=ff && !vis[to[x][i]]) { dis[to[x][i]]=dis[x]+1; if (dis[to[x][i]]<=k) q[++p]=dis[to[x][i]],cnt0[dis[to[x][i]]]++; getdis(to[x][i],x); } } void solve(int x){ vis[x]=true; int i,j,l=to[x].size(); p=0; cnt[0]=1; for (i=0;i<l;i++) if (!vis[to[x][i]]) { dis[to[x][i]]=1; cnt0[1]++,q[++p]=dis[1]; getdis(to[x][i],x); for (j=0;j<=k;j++) ans+=cnt0[j]*1LL*cnt[k-j]; for (j=0;j<=k;j++) cnt[j]+=cnt0[j]; for (j=0;j<=k;j++) cnt0[j]=0; } memset(cnt,0,sizeof(cnt)); for (i=0;i<l;i++) if (!vis[to[x][i]]) { sum=siz[to[x][i]]; maxson[root=0]=99999; getroot(to[x][i],x); solve(root); } } int main() { int i,u,v; scanf("%d %d",&n,&k); for (i=1;i<n;i++) { scanf("%d %d",&u,&v); to[u].push_back(v); to[v].push_back(u); } sum=n; maxson[root=0]=99999; getroot(1,-1); solve(root); cout<<ans<<endl; }