题目描述
给出一棵带边权的树,多次询问树上两点间的距离。
考虑用树链剖分 + 线段树:把每条边的边权下放到它深度较大的端点(即子节点)上,变成点权,然后用线段树统计路径上的点权和即可。
注意:统计路径点权和时,LCA 这个点的点权(即 LCA 与其父亲之间那条边的边权)并不在 x 到 y 的路径上,因此需要减掉。
// Tree distance queries, version 1: heavy-light decomposition + sum segment tree.
// Every edge weight is moved onto its deeper endpoint (the child), so the sum of
// vertex weights along the x..y path also counts the LCA's own parent-edge
// weight, which is not on the path; main() subtracts it.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

constexpr int N = 10010;  // vertex capacity; all per-vertex arrays are 1-indexed

// Reads one decimal integer (optional leading '-') from stdin, skipping any
// other characters before it. Returns 0 if EOF is reached before a digit:
// isdigit(EOF) is false, so a plain "skip non-digits" loop would spin forever.
// `c` is an int so every getchar() result is a valid isdigit() argument
// (a negative char that is not EOF would be undefined behaviour).
inline int read() {
    int x = 0, s = 1;
    int c = std::getchar();
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') s = -1;
        c = std::getchar();
    }
    while (std::isdigit(c)) {
        x = x * 10 + (c ^ '0');
        c = std::getchar();
    }
    return x * s;
}

// Adjacency list stored in a flat array: f[u] is the slot of u's latest edge,
// t[e].next links to the previous one, and 0 ends the chain. Each undirected
// edge occupies two slots, hence N << 1.
struct node {
    int v, w;
    int next;
} t[N << 1];
int f[N];
int bian = 0;  // number of edge slots in use

inline void add(int u, int v, int w) {
    t[++bian] = node{v, w, f[u]}, f[u] = bian;
    t[++bian] = node{u, w, f[v]}, f[v] = bian;
}

int dfn[N], cnt = 0;
int fa[N], son[N], top[N], siz[N], deth[N];
int up[N];  // up[x] = weight of the edge from x to its parent (0 for the root)
int a[N];   // a[dfn[x]] = up[x]: vertex weights laid out in DFS order
int n, m;

/* ---------- heavy-light decomposition ---------- */

// First pass: parent, depth, subtree size, heavy child and parent-edge weight.
// Called as dfs1(1, 0); deth[0] is never written, so the root gets depth 1 in
// every test case instead of drifting upward between cases.
void dfs1(int now, int father) {
    deth[now] = deth[father] + 1;
    siz[now] = 1;
    fa[now] = father;
    for (int i = f[now]; i; i = t[i].next) {
        const int v = t[i].v;
        if (v == father) continue;
        up[v] = t[i].w;
        dfs1(v, now);
        siz[now] += siz[v];
        if (siz[v] > siz[son[now]]) son[now] = v;  // siz[0] == 0 while son is unset
    }
}

// Second pass: chain tops and DFS order (heavy child first, so each chain is a
// contiguous dfn range). Also places each vertex's weight at its dfn slot,
// which replaces the original third traversal.
void dfs2(int now, int tp) {
    top[now] = tp;
    dfn[now] = ++cnt;
    a[dfn[now]] = up[now];
    if (!son[now]) return;
    dfs2(son[now], tp);
    for (int i = f[now]; i; i = t[i].next) {
        const int v = t[i].v;
        if (v != fa[now] && v != son[now]) dfs2(v, v);
    }
}

// Lowest common ancestor: repeatedly lift the vertex whose chain top is deeper.
int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (deth[top[x]] < deth[top[y]]) std::swap(x, y);
        x = fa[top[x]];
    }
    return deth[x] < deth[y] ? x : y;
}

/* ---------- sum segment tree over a[1..n] ---------- */

struct tree {
    int w;
} e[N << 2];

inline void pushup(int o) { e[o].w = e[o << 1].w + e[o << 1 | 1].w; }

void build(int o, int l, int r) {
    if (l == r) {
        e[o].w = a[l];
        return;
    }
    const int mid = (l + r) >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
    pushup(o);
}

// Sum of a[ql..qr] restricted to node o's range [l, r].
int query(int o, int l, int r, int ql, int qr) {
    if (l > qr || r < ql) return 0;
    if (l >= ql && r <= qr) return e[o].w;
    const int mid = (l + r) >> 1;
    return query(o << 1, l, mid, ql, qr) + query(o << 1 | 1, mid + 1, r, ql, qr);
}

// Sum of vertex weights on the x..y path, LCA included.
int ask_he(int x, int y) {
    int sum = 0;
    while (top[x] != top[y]) {
        if (deth[top[x]] < deth[top[y]]) std::swap(x, y);
        sum += query(1, 1, n, dfn[top[x]], dfn[x]);
        x = fa[top[x]];
    }
    if (deth[x] > deth[y]) std::swap(x, y);
    sum += query(1, 1, n, dfn[x], dfn[y]);
    return sum;
}

// Resets per-test state. Stale t[] slots need no clearing: they are only
// reached through f[] chains, which are reset here, and every slot is
// overwritten by add() before it is linked again.
inline void clean() {
    std::memset(son, 0, sizeof(son));
    std::memset(dfn, 0, sizeof(dfn));
    cnt = 0;
    std::memset(f, 0, sizeof(f));
    bian = 0;
}

int main() {
    int T = read();
    while (T--) {
        clean();
        n = read(), m = read();
        // A tree has at least one vertex; n < 1 means truncated or malformed
        // input, and build(1, 1, 0) would recurse forever.
        if (n < 1) break;
        for (int i = 1; i < n; i++) {
            const int x = read(), y = read(), w = read();
            add(x, y, w);
        }
        dfs1(1, 0);
        dfs2(1, 1);
        build(1, 1, n);
        while (m--) {
            const int x = read(), y = read();
            const int LCA = lca(x, y);
            // NOTE(review): answers are space-separated, one " " line per test
            // case; confirm this matches the judge's expected format.
            std::printf("%d ", ask_he(x, y) - a[dfn[LCA]]);
        }
        std::puts(" ");
    }
    return 0;
}
但是啊,悄悄告诉你:这个方法超慢哒!(好奇怪的语气)边权根本不会被修改,每次询问却要在线段树上做 O(log² n) 的查询,完全没有必要。
根据 ZHT 大佬的思路:先求出每个点到根的带权距离 $\mathrm{dis}$,每次询问的答案就是 $\mathrm{dis}_x + \mathrm{dis}_y - 2\,\mathrm{dis}_{\mathrm{lca}(x,y)}$。
在此再次%%% + orz
树剖求 LCA 版:
// Tree distance queries, version 2: answer = dis[x] + dis[y] - 2 * dis[lca],
// where dis[] is the weighted distance from the root (vertex 1) and the LCA is
// found by heavy-light decomposition. No segment tree is needed because edge
// weights never change.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

constexpr int N = 10010;  // vertex capacity; all per-vertex arrays are 1-indexed

// Reads one decimal integer (optional leading '-') from stdin, skipping any
// other characters before it. Returns 0 if EOF is reached before a digit
// (isdigit(EOF) is false, so an unguarded skip loop would never terminate).
// `c` is an int so that passing it to isdigit() is always well-defined.
inline int read() {
    int x = 0, s = 1;
    int c = std::getchar();
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') s = -1;
        c = std::getchar();
    }
    while (std::isdigit(c)) {
        x = x * 10 + (c ^ '0');
        c = std::getchar();
    }
    return x * s;
}

// Adjacency list in a flat array: f[u] is the slot of u's latest edge,
// t[e].next links to the previous one, and 0 ends the chain; two slots per
// undirected edge.
struct node {
    int v, w;
    int next;
} t[N << 1];
int f[N];
int bian = 0;  // number of edge slots in use

inline void add(int u, int v, int w) {
    t[++bian] = node{v, w, f[u]}, f[u] = bian;
    t[++bian] = node{u, w, f[v]}, f[v] = bian;
}

int fa[N], son[N], top[N], siz[N], deth[N];
int dis[N];  // weighted distance from the root; dis[1] is never written, stays 0

/* ---------- heavy-light decomposition ---------- */

// First pass: parent, depth, subtree size, heavy child and root distance.
// Called as dfs1(1, 0); deth[0] is never written, so the root has depth 1.
void dfs1(int now, int father) {
    deth[now] = deth[father] + 1;
    siz[now] = 1;
    fa[now] = father;
    for (int i = f[now]; i; i = t[i].next) {
        const int v = t[i].v;
        if (v == father) continue;
        dis[v] = dis[now] + t[i].w;
        dfs1(v, now);
        siz[now] += siz[v];
        if (siz[v] > siz[son[now]]) son[now] = v;  // siz[0] == 0 while son is unset
    }
}

// Second pass: chain tops (heavy child continues the current chain). The
// original also computed a DFS order here, but LCA never reads it.
void dfs2(int now, int tp) {
    top[now] = tp;
    if (!son[now]) return;
    dfs2(son[now], tp);
    for (int i = f[now]; i; i = t[i].next) {
        const int v = t[i].v;
        if (v != fa[now] && v != son[now]) dfs2(v, v);
    }
}

// Lowest common ancestor: repeatedly lift the vertex whose chain top is deeper.
int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (deth[top[x]] < deth[top[y]]) std::swap(x, y);
        x = fa[top[x]];
    }
    return deth[x] < deth[y] ? x : y;
}

// Resets per-test state. dis/deth/fa/top/siz are overwritten by the DFS
// passes; son must be zeroed because dfs1 only ever raises it.
inline void clean() {
    std::memset(son, 0, sizeof(son));
    std::memset(f, 0, sizeof(f));
    bian = 0;
}

int main() {
    int T = read();
    while (T--) {
        clean();
        const int n = read();
        int m = read();
        if (n < 1) break;  // truncated or malformed input: a tree has >= 1 vertex
        for (int i = 1; i < n; i++) {
            const int x = read(), y = read(), w = read();
            add(x, y, w);
        }
        dfs1(1, 0);
        dfs2(1, 1);
        while (m--) {
            const int x = read(), y = read();
            // NOTE(review): answers are space-separated, one " " line per test
            // case; confirm this matches the judge's expected format.
            std::printf("%d ", dis[x] + dis[y] - 2 * dis[LCA(x, y)]);
        }
        std::puts(" ");
    }
    return 0;
}
倍增求 LCA 版:
// Tree distance queries, version 3: answer = dis[x] + dis[y] - 2 * dis[lca],
// with the LCA found by binary lifting.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

constexpr int N = 10010;  // vertex capacity; all per-vertex arrays are 1-indexed
constexpr int LOG = 21;   // lifting levels 0..LOG-1; the table bound and the
                          // LCA loop bounds both derive from this one constant

// Reads one decimal integer (optional leading '-') from stdin, skipping any
// other characters before it. Returns 0 if EOF is reached before a digit
// (isdigit(EOF) is false, so an unguarded skip loop would never terminate).
// `c` is an int so that passing it to isdigit() is always well-defined.
inline int read() {
    int x = 0, s = 1;
    int c = std::getchar();
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') s = -1;
        c = std::getchar();
    }
    while (std::isdigit(c)) {
        x = x * 10 + (c ^ '0');
        c = std::getchar();
    }
    return x * s;
}

// Adjacency list in a flat array: f[u] is the slot of u's latest edge,
// t[e].next links to the previous one, and 0 ends the chain; two slots per
// undirected edge.
struct node {
    int v, w;
    int next;
} t[N << 1];
int f[N];
int bian = 0;  // number of edge slots in use

inline void add(int u, int v, int w) {
    t[++bian] = node{v, w, f[u]}, f[u] = bian;
    t[++bian] = node{u, w, f[v]}, f[v] = bian;
}

int a[N][LOG];  // a[x][i] = 2^i-th ancestor of x; 0 where not set by dfs
int dis[N];     // weighted distance from the root
int deth[N];    // depth, root = 1; deth[0] stays 0 so jumps to 0 are rejected

// Fills depth, root distance and the lifting table. Called as dfs(1, 1), so
// the root is its own parent (a[1][0] == 1); deth is cleared per test, so
// deth[1] becomes 1.
void dfs(int now, int father) {
    a[now][0] = father;
    deth[now] = deth[father] + 1;
    for (int i = 1; i < LOG && (1 << i) <= deth[now]; i++)
        a[now][i] = a[a[now][i - 1]][i - 1];
    for (int i = f[now]; i; i = t[i].next) {
        const int v = t[i].v;
        if (v == father) continue;
        dis[v] = dis[now] + t[i].w;
        dfs(v, now);
    }
}

// Lowest common ancestor by binary lifting.
int LCA(int x, int y) {
    if (deth[x] < deth[y]) std::swap(x, y);
    // Lift x to y's depth. A missing ancestor is 0 with deth[0] == 0, which is
    // always shallower than y, so such jumps are skipped.
    for (int i = LOG - 1; i >= 0; i--)
        if (deth[a[x][i]] >= deth[y]) x = a[x][i];
    if (x == y) return x;
    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int i = LOG - 1; i >= 0; i--)
        if (a[x][i] != a[y][i]) x = a[x][i], y = a[y][i];
    return a[x][0];
}

// Resets per-test state. The lifting table must be zeroed because dfs only
// fills levels that exist for each vertex and LCA reads the rest.
inline void clean() {
    std::memset(a, 0, sizeof(a));
    std::memset(deth, 0, sizeof(deth));
    std::memset(dis, 0, sizeof(dis));
    std::memset(f, 0, sizeof(f));
    bian = 0;
}

int main() {
    int T = read();
    while (T--) {
        clean();
        const int n = read();
        int m = read();
        if (n < 1) break;  // truncated or malformed input: a tree has >= 1 vertex
        for (int i = 1; i < n; i++) {
            const int x = read(), y = read(), w = read();
            add(x, y, w);
        }
        dfs(1, 1);  // also sets a[1][0] = 1; dis[1] is already 0 from clean()
        while (m--) {
            const int x = read(), y = read();
            // NOTE(review): answers are space-separated, one " " line per test
            // case; confirm this matches the judge's expected format.
            std::printf("%d ", dis[x] + dis[y] - dis[LCA(x, y)] * 2);
        }
        std::puts(" ");
    }
    return 0;
}