emmm, 比赛的时候没有想到如何利用非树边。
其实感觉很简单。。
对于一个询问答案分为两部分求:
第一部分:只经过树边,用倍增就能求出来啦。
第二部分:经过至少一条非树边, 如果经过一个树边那么必定经过其两个端点,暴力的求出这些端点为起始点的最短路。
#include<bits/stdc++.h> #define LL long long #define fi first #define se second #define mk make_pair #define PII pair<int, int> #define PLI pair<LL, int> using namespace std; const int N = 2e5 + 7; const int inf = 0x3f3f3f3f; const LL INF = 0x3f3f3f3f3f3f3f3f; const int mod = 1e9 + 7; int n, m, tot, cnt, head[N], f[N][20], d[N], depth[N]; LL cost[N][20], dis[45][N]; bool vis[N]; struct Edge { int from, to, w, nx; } edge[N << 1]; void add(int u, int v, int w) { edge[tot].from = u; edge[tot].to = v; edge[tot].w = w; edge[tot].nx = head[u]; head[u] = tot++; } void Dij(int S, LL d[N]) { memset(d, INF, N * sizeof(LL)); priority_queue<PLI, vector<PLI>, greater<PLI> > que; d[S] = 0; que.push(mk(0, S)); while(!que.empty()) { int u = que.top().se; LL dis = que.top().fi; que.pop(); if(dis > d[u]) continue; for(int i = head[u]; ~i; i = edge[i].nx) { int v = edge[i].to, w = edge[i].w; if(dis + w < d[v]) { d[v] = dis + w; que.push(mk(d[v], v)); } } } } void dfs(int u, int fa, int w, int deep) { vis[u] = true; depth[u] = deep; f[u][0] = fa; cost[u][0] = w; for(int j = 1; j < 20; j++) { f[u][j] = f[f[u][j-1]][j-1]; cost[u][j] = cost[f[u][j-1]][j-1] + cost[u][j-1]; } for(int i = head[u]; ~i; i = edge[i].nx) { int v = edge[i].to; if(v == fa) continue; else if(vis[v]) { d[cnt++] = u; d[cnt++] = v; } else dfs(v, u, edge[i].w, deep + 1); } } int getLCA (int u, int v){ if(depth[u] < depth[v]) swap(u, v); for(int j = 19; j >= 0; j--) if(depth[u] - depth[v] >= (1 << j)) u = f[u][j]; for(int j = 19; j >= 0; j--) if(f[u][j] != f[v][j]) u = f[u][j], v = f[v][j]; return u == v ? u : f[u][0]; } LL getDis(int u, int v) { LL ans = 0; for(int j = 19; j >= 0; j--) if(depth[u] - depth[v] >= (1 << j)) ans += cost[u][j], u = f[u][j]; return ans; } int main() { memset(head, -1, sizeof(head)); scanf("%d%d", &n, &m); for(int i = 1; i <= m; i++) { int u, v, w; scanf("%d%d%d", &u, &v, &w); add(u, v, w); add(v, u, w); } dfs(1, 0, 0, 0); sort(d, d + cnt); cnt = unique(d, d + cnt) - d; for(int i = 0; i < cnt; i++) Dij(d[i], dis[i]); int q; scanf("%d", &q); while(q--) { int u, v, lca; scanf("%d%d", &u, &v); lca = getLCA(u, v); LL ans = getDis(u, lca) + getDis(v, lca); for(int i = 0; i < cnt; i++) ans = min(ans, dis[i][u] + dis[i][v]); printf("%lld ", ans); } return 0; } /* */