题意:给出一个n个点m条边的无向边,q次询问每次询问把一条边权值增大后问新的MST是多少,输出Sum(MST)/q。
解法:一开始想的是破圈法,后来想了想应该不行,破圈法应该只能用于加边的情况而不是修改边,因为加边可以保证以前MST不用的边加边之后也一定不用,但是修改边不能保证以前不用的边修改边之后会不会再用。
正解是参考https://blog.csdn.net/Ramay7/article/details/52236040这位大佬的。
大佬真的分析得巨好。我的理解就是:假如我们要计算dp[u][v]代表去掉MST上u-v这条边之后能替代的最好边,设u这一边的连通点集是(u1,u2,u3...),v这一边的点集是(v1,v2,v3...),那么我们朴素算法是暴力枚举每一对ui和vi然后取最小值,显然这样超时。用树形DP的方法是,我们枚举一个根节点rt,然后dfs一边计算以rt为根节点的时候的MST的所有边的dp值,计算方式就是dp[u][v]=min(dis[y][rt])(即子树中到根rt的最小值)。枚举完根节点rt之后我们的dp数组就出来了。 为什么这样能达到朴素算法一样的效果呢?因为考虑我们每一次rt对dp[u][v]的贡献,显然每一个rt都是在点集(u1,u2,u3...)中的,然后这次dfs可以计算(u1,u2,u3...)中的某一个ui和所有的(v1,v2,v3...)的点对对答案的贡献,然后所以的rt加起来必定等于(u1,u2,u3...)。 这里说得有点乱了,就是这样:
(rt=u1)x(v1,v2,v3...)+(rt=u2)*(v1,v2,v3...)+(rt=u3)*(v1,v2,v3...)+....(rt=ui)*(v1,v2,v3...) == (u1,u2,u3...)*(v1,v2,v3...) (上面说的一大堆想说的就是这个等式的意思qwq)
最后处理下询问看看修改边在不在原MST上,就可以获得AC了。
#include<bits/stdc++.h> using namespace std; typedef long long LL; const int N=3e3+10; int n,m,fa[N],dis[N][N],dp[N][N]; LL sum,ans; struct edge{ int x,y,z; bool operator < (const edge &rhs) const { return z<rhs.z; } }e[N*N]; bool mst[N][N]; int cnt,head[N],nxt[N<<1],to[N<<1]; void add_edge(int x,int y) { nxt[++cnt]=head[x]; to[cnt]=y; head[x]=cnt; } int getfa(int x) { return x==fa[x] ? x : fa[x]=getfa(fa[x]); } void Kruskal() { sort(e+1,e+m+1); for (int i=1;i<=n;i++) fa[i]=i; int num=1; for (int i=1;i<=m;i++) { int fx=getfa(e[i].x),fy=getfa(e[i].y); if (fx==fy) continue; fa[fx]=fa[fy]; add_edge(e[i].x,e[i].y); add_edge(e[i].y,e[i].x); mst[e[i].x][e[i].y]=mst[e[i].y][e[i].x]=1; sum+=e[i].z; if (++num==n) break; } } int dfs(int rt,int x,int fa) { int Min=0x3f3f3f3f; for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (y==fa) continue; int tmp=dfs(rt,y,x); Min=min(Min,tmp); dp[x][y]=min(dp[x][y],tmp); //用子树的Min更新dp[][] dp[y][x]=min(dp[y][x],tmp); } if (dis[rt][x] && !mst[rt][x]) Min=min(Min,dis[rt][x]); //更新Min return Min; } int main() { while (scanf("%d%d",&n,&m) && n) { cnt=1; for (int i=1;i<=n;i++) head[i]=0; for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) dis[i][j]=mst[i][j]=0,dp[i][j]=0x3f3f3f3f; for (int i=1;i<=m;i++) { scanf("%d%d%d",&e[i].x,&e[i].y,&e[i].z); e[i].x++; e[i].y++; dis[e[i].x][e[i].y]=dis[e[i].y][e[i].x]=e[i].z; } sum=0; ans=0; Kruskal(); for (int i=1;i<=n;i++) dfs(i,i,0); //每个点做根节点dfs一次 int q; scanf("%d",&q); for (int i=1;i<=q;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); x++; y++; if (!mst[x][y]) ans+=sum; //不在MST上 else { //在MST上 LL tdis=sum-dis[x][y]+min(z,dp[x][y]); ans+=tdis; } } printf("%.4lf ",(double)ans/q); } return 0; }