例题:考虑一颗边权为1的树上有多少个路径正好为k的点对。
我们考虑一个这样的树,现在问,这个树上有多少个点对之间的距离为k。
首先,我们从根结点开始考虑。
那么我们可以把所有的路径划分为两个部分
1,经过根结点的路径。2,不经过根结点的路径。
对于第一种路径,经过根节点,那么就是x->root->y。
也就是说这条路径是root的两个不同子树的链组成。
那么不就是考虑d[x] + d[y] == k的点对吗。
我们可以求的root到每个结点的距离,存放到d数组里面。
同时,保存每个结点是root的哪个子树下面的点 用b数组保存,保存root能到那些结点,用point数组保存。
那么我们可以把point数组根据距离进行排序。
从而用两个指针的方式将其进行统计。
对于第二种路径来说,
不就是递归第一种路径嘛。
例题链接:https://www.luogu.com.cn/problem/CF161D
#include"stdio.h" #include"string.h" #include"algorithm" using namespace std; inline int read(){ int x=0,f=1; char c=getchar(); while(c<'0'||c>'9'){ if(c=='-')f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<3)+(x<<1)+c-'0'; c=getchar(); } return x*f; } const int N = 100010; int head[N],ver[N],Next[N],edge[N],tot; int n,m; int v[N],Size[N],ans,root;///找到树的重心 int vis[N]; int d[N],b[N],point[N],top; int cnt[N]; int num,k; void add(int x,int y,int w){ ver[++ tot] = y; edge[tot] = w; Next[tot] = head[x]; head[x] = tot; } void get_root(int x,int far,int n){///求子树的重心 Size[x] = 1; int max_part = 0; for(int i = head[x]; i; i = Next[i]){ int y = ver[i]; if(vis[y] || y == far) continue; get_root(y,x,n); Size[x] += Size[y]; max_part = max(max_part,Size[y]); } max_part = max(max_part,n - Size[x]); if(max_part < ans || root == 0) { ans = max_part; root = x; } return ; } void get_dist(int x,int far,int ww,int from){ point[++ top] = x; b[x] = from;d[x] = ww; cnt[from] ++; for(int i = head[x]; i; i = Next[i]){ int y = ver[i]; if(y == far || vis[y]) continue; // d[y] = d[far] + edge[i]; get_dist(y,x,ww + edge[i],from); } } int cmp(int x,int y){ if(d[x] == d[y]) return b[x] < b[y]; return d[x] < d[y]; } void calc(int root) { top = 0; point[++ top] = root; d[root] = 0; b[root] = root; cnt[root] = 1; for(int i = head[root]; i; i = Next[i]) { int y = ver[i]; if(vis[y]) continue; cnt[y] = 0; // d[y] = edge[i]; get_dist(y,root,edge[i],y); } sort(point + 1,point + top + 1,cmp); int left = 1,right = top; while(left < right){ if(d[point[left]] + d[point[right]] < k) left ++; else if(d[point[left]] + d[point[right]] > k) right --; else { int xx = 0; int r = right; while(r > left){ if(d[point[r]] + d[point[left]] == k) { if(b[point[r]] != b[point[left]]) xx ++; } else break; r --; } num += xx; left ++; } } } void solve(int u) { vis[u] = 1; top = 0; calc(u); for(int i = head[u]; i; i = Next[i]){ int y = ver[i]; if(vis[y]) continue; ans = n; root = 0; get_root(y,0,Size[y]); solve(root); } } int main() { n = read();k = read(); for(int i = 1; i <= n - 1; i ++){ int x,y,w; x = read(); y = read();w = 1; add(x,y,w); add(y,x,w); } ans = n; get_root(1,0,n); solve(root); printf("%d ",num); }