题目大意:给你一张图,点分成黑点和白点,在保证白点到离自己最近的黑点存在最短路的情况下花尽可能少的代价。
乍一看题目,好像不知所云。但是仔细研究发现,在没有对原图进行操作的情况下我们很难保证白点到离自己黑点的最短路还存在,于是乎,下面的解法就诞生了。
我们考虑添加超级点S,在S 与每个黑点间连权值为0的边,先处理从S 出发到每个点的最短距离,我们可以取出一张最短路径图,现在我们的问题就变成了取权值最小的边的集合 使得这幅图连接,那么,最小生成树算法就可以完美地解决这个问题。
怎么取出这个最短路径图?如下代码:
void get()
{
for (int i=1;i<=n;i++)
{
if (a[i]) continue;
for (int j=head[i];j;j=edge[j].next)
{
int y=edge[j].to;
if (y==0) continue;
if (dist[y]+1ll*edge[j].l==dist[i]) edge[j].fl=1;
}
}
}
也就是说我们枚举每一个白点连的边,如果dist是通过这条边去更新的那么就fl赋值为1,表示它属于最短路径图。于是对fl为1的边跑最小生成树就好了。
代码如下:
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<queue> #define ll long long using namespace std; const int maxn=100000+50; int n,m,tot,S,cnt; int a[maxn],head[maxn],vis[maxn],fa[maxn]; ll dist[maxn]; struct EDGE { int from;int to;int next;int l;int fl; }edge[maxn<<5]; inline int read() { char ch=getchar(); int s=0,f=1; while (!(ch>='0'&&ch<='9')) {if (ch=='-') f=-1;ch=getchar();} while (ch>='0'&&ch<='9') {s=(s<<3)+(s<<1)+ch-'0';ch=getchar();} return s*f; } void add(int x,int y,int l) { edge[++tot]=(EDGE){x,y,head[x],l,0}; head[x]=tot; edge[++tot]=(EDGE){y,x,head[y],l,0}; head[y]=tot; } void spfa(int x) { memset(dist,0x3f3f3f3f,sizeof(dist)); queue <int> q; dist[x]=0;vis[x]=1;q.push(x); while (!q.empty()) { int k=q.front();q.pop();vis[k]=0; for (int i=head[k];i;i=edge[i].next) { int y=edge[i].to; if (dist[y]>dist[k]+1ll*edge[i].l) { dist[y]=dist[k]+1ll*edge[i].l; if (!vis[y]) { vis[y]=1; q.push(y); } } } } } void get() { for (int i=1;i<=n;i++) { if (a[i]) continue; for (int j=head[i];j;j=edge[j].next) { int y=edge[j].to; if (y==0) continue; if (dist[y]+1ll*edge[j].l==dist[i]) edge[j].fl=1; } } } bool cmp(EDGE aa,EDGE bb) {return aa.l<bb.l;} int find(int x) { if (fa[x]!=x) fa[x]=find(fa[x]); return fa[x]; } int main() { n=read();m=read(); for (int i=1;i<=n;i++) a[i]=read(); for (int i=1;i<=m;i++) { int u=read(),v=read(),l=read(); add(u,v,l); } for (int i=1;i<=n;i++) { if (a[i]) add(S,i,0); fa[i]=i; } spfa(S); get(); sort(edge+1,edge+1+tot,cmp); ll ans=0; int p=0; for (int i=1;i<=tot;i++) { if (!edge[i].fl) continue; int x=edge[i].from,y=edge[i].to; int xx=find(x),yy=find(y); if (xx!=yy) { fa[xx]=yy; ++p; ans+=1ll*edge[i].l; if (p==n-1) break; } } if (!ans) puts("impossible"); else cout<<ans<<endl; return 0; }