题目链接:
我们考虑每条边的贡献,对每个点求出能到达它的最近的感兴趣的城市(设为$f[i]$,最短距离设为$a[i]$)和它能到达的离它最近的感兴趣的城市(设为$g[i]$,最短距离设为$b[i]$)。
那么每条边$(u,v,w)$的贡献就是$a[u]+w+b[v]$,用这个值去更新答案即可(这个值代表$f[u]$到$g[v]$的最短路长度)。
但要注意一条边能更新答案需要满足$f[u] eq g[v]$,因为要保证起点和终点不同。
手画一下就可以知道最短路径上的边至少有一条会更新答案,即不可能发生最短路径上每条边的$f[u]=g[v]$。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<vector> #include<bitset> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ll long long #define pr pair<ll,int> using namespace std; int tot; int head[100010]; int to[600010]; int nex[600010]; int val[600010]; ll ans; int tim; int n,m,k; int u[500010]; int v[500010]; int w[500010]; ll d[100010]; ll c[100010]; int f[100010]; int g[100010]; int a[100010]; int vis[100010]; priority_queue< pr,vector<pr>,greater<pr> >q; void init() { memset(head,0,sizeof(head)); memset(val,0,sizeof(val)); memset(nex,0,sizeof(nex)); memset(to,0,sizeof(to)); memset(f,0,sizeof(f)); memset(c,0,sizeof(c)); memset(g,0,sizeof(g)); memset(d,0,sizeof(d)); memset(v,0,sizeof(v)); memset(w,0,sizeof(w)); memset(u,0,sizeof(u)); tot=0; } void add(int x,int y,int z) { nex[++tot]=head[x]; head[x]=tot; to[tot]=y; val[tot]=z; } void dijkstra(int opt) { tot=0; memset(vis,0,sizeof(vis)); memset(head,0,sizeof(head)); memset(f,0,sizeof(f)); for(int i=1;i<=n;i++) { d[i]=1ll<<60; } for(int i=1;i<=k;i++) { d[a[i]]=0; f[a[i]]=a[i]; q.push(make_pair(d[a[i]],a[i])); } for(int i=1;i<=m;i++) { if(!opt) { add(u[i],v[i],w[i]); } else { add(v[i],u[i],w[i]); } } while(!q.empty()) { int now=q.top().second; q.pop(); if(vis[now]) { continue; } vis[now]=1; for(int i=head[now];i;i=nex[i]) { if(d[to[i]]>d[now]+val[i]) { d[to[i]]=d[now]+val[i]; f[to[i]]=f[now]; q.push(make_pair(d[to[i]],to[i])); } } } } void solve() { scanf("%d%d%d",&n,&m,&k); for(int i=1;i<=m;i++) { scanf("%d%d%d",&u[i],&v[i],&w[i]); } for(int i=1;i<=k;i++) { scanf("%d",&a[i]); } dijkstra(0); for(int i=1;i<=n;i++) { c[i]=d[i],g[i]=f[i]; } dijkstra(1); ll ans=1ll<<60; for(int i=1;i<=m;i++) { if(g[u[i]]&&f[v[i]]&&g[u[i]]!=f[v[i]]) { ans=min(ans,c[u[i]]+d[v[i]]+w[i]); } } printf("%lld ",ans); } int main() { scanf("%d",&tim); while(tim--) { init(); solve(); } }