前言
严格次小生成树,顾名思义,就是在联通图上选择一些边构成一棵树,使这棵树边权和严格次小。
第一步:求出最小生成树
要求严格次小生成树,我们就要先求出最小生成树((Prim)和(Kruskal)等算法都可以,我用的是(Kruskal)),在求最小生成树的过程中,还要给每条使用过的边打一个标记,代码如下:
inline void Kruskal()//求最小生成树
{
register int res=0,tot=1;//res记录最小生成树的边权和,tot记录已经在树上的点数
for(sort(e+1,e+m+1,cmp),i=1;i<=m&&tot<n;++i)//先将边按权值排序,只要已经在树上的点数小于n,就不断枚举下一条边
{
static LL fx,fy;
if((fx=getfa(e[i].from))^(fy=getfa(e[i].to))) f[fy]=fx,res+=e[i].val,++tot,e[i].used=1,v[e[i].from].push_back(i),v[e[i].to].push_back(i);//合并两个联通快,更新res和tot,并标记这条边已使用
}
}
最小生成树的主要思想
讲如何求严格次小生成树之前,还要先回顾一下最小生成树的主要思想。
比如说(Kruskal),它的思想就是,贪心每次选择权值最小的一条边,判断这条边两边的点是否在同一个联通块内,也就是判断加上这条边之后是否会构成环。
从最小生成树的主要思想而引发得到的思路
首先,我们应该有一个比较显然的想法:如果我们要求出严格次小生成树,那么肯定就是用一条边去替换最小生成树上的一条边。
现在假设我们已经选好了一条边(这可以拿来枚举),那么应该替换掉那一条边呢?
添上一条边肯定会使原来的树上形成一个环,要让这棵树重新变回树,那么显然要删掉这个环上的一条边,而无论删掉环上的哪一条边,结果都是一棵树。
既然删哪条边都无所谓,那么由于贪心的思想,显然是尽量删掉边权较大的边。
又因为我们是求严格次小生成树,因此,只要删掉边权严格小于新加上的那条边的边权的边权最大的一条边。
如何确定这条边——倍增(LCA)
现在我们已经知道应该删掉的边是什么样的了,但是如何确定这条边到底是哪条边呢?
我们可以假设新加上的这条边的左右端点分别为(u)和(v),那么,要删掉的边必然位于最小生成树上(u)到(v)的路径之间。也就是肯定位于(u)到(lca(u,v))和(v)到(lca(u,v))的其中一条路径上。
那么我们就可以考虑倍增(LCA)。
同经典的倍增(LCA)一样,我们用(fa_{i,j})来记录(i)的第(2^j)个祖先,然后还要额外加上两个数组(Max1_{i,j})和(Max2_{i,j})分别表示节点(i)到其第(2^j)个祖先的路径上边权最大的边的边权和边权严格次大的边的边权。
为什么只需记录最大和严格次大的边权,而不需记录第三大、第四大、第五大?
证明: 由于最小生成树基于贪心,因此,这些边中最大的边也不会大于新加入的这条边,所以严格次大的边一定小于新加入的这条边。
既然这样,对于新加入的一条边((u,v)),我们只需在求出它们的(LCA)的同时计算出(u)和(v)到(lca(u,v))的路径上小于((u,v))权值的最大边权即可。
代码如下:
inline LL MAX(LL x,LL a,LL b,LL val)//求出a到其第2^b个父亲的路径上边权小于val的最大值与x的最大值
{
if(val>Max1[a][b]) return max(x,Max1[a][b]);
return max(x,Max2[a][b]);
}
inline LL Get(LL x,LL y,LL val)//求出x和y路径之间小于val的最大边权
{
register LL i,res=0;//res记录答案
if(Depth[x]<Depth[y]) swap(x,y);//以下为倍增LCA的模板,只不过在每次更新节点的同时加上一个更新小于val的最大边权的操作即可
for(i=0;Depth[x]^Depth[y];++i) if((Depth[x]^Depth[y])&(1<<i)) res=MAX(res,x,i,val),x=fa[x][i];
if(!(x^y)) return res;
for(i=0;fa[x][i]^fa[y][i];++i);
for(;i>=0;--i) if(fa[x][i]^fa[y][i]) res=MAX(MAX(res,x,i,val),y,i,val),x=fa[x][i],y=fa[y][i];
return MAX(MAX(res,x,0,val),y,0,val);
}
代码
最后,贴一下洛谷板子题的代码:
#include<bits/stdc++.h>
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)<(y)?(x):(y))
#define abs(x) ((x)<0?-(x):(x))
#define LL long long
#define ull unsigned long long
#define swap(x,y) (x^=y,y^=x,x^=y)
#define tc() (A==B&&(B=(A=ff)+fread(ff,1,100000,stdin),A==B)?EOF:*A++)
#define pc(ch) (pp_<100000?pp[pp_++]=(ch):(fwrite(pp,1,100000,stdout),pp[(pp_=0)++]=(ch)))
#define N 100000
#define M 300000
#define MOD 31011
LL pp_=0;char ff[100000],*A=ff,*B=ff,pp[100000];
using namespace std;
LL n,m,Depth[N+5],f[N+5],fa[N+5][20],Max1[N+5][20],Max2[N+5][20];
struct edge
{
LL from,to,val,used;
}e[M+5];
vector<LL> v[N+5];
inline void read(LL &x)
{
x=0;LL f=1;static char ch;
while(!isdigit(ch=tc())) f=ch^'-'?1:-1;
while(x=(x<<3)+(x<<1)+ch-48,isdigit(ch=tc()));
x*=f;
}
inline void write(LL x)
{
if(x<0) pc('-'),x=-x;
if(x>9) write(x/10);
pc(x%10+'0');
}
inline LL getfa(LL x)
{
return f[x]^x?f[x]=getfa(f[x]):x;
}
inline bool cmp(edge x,edge y)
{
return x.val<y.val||(x.val==y.val&&x.to<y.to);
}
inline void GetRt(LL x)
{
register LL i,sz=v[x].size();
for(i=0;i<sz;++i)
{
register LL nxt=e[v[x][i]].from^x?e[v[x][i]].from:e[v[x][i]].to,val=e[v[x][i]].val;
if(!(fa[x][0]^nxt)) continue;
fa[nxt][0]=x,Max1[nxt][0]=val,Depth[nxt]=Depth[x]+1,GetRt(nxt);
}
}
inline void Init()
{
register LL i,j;
for(GetRt(Depth[1]=1),j=1;j<20;++j)
{
for(i=1;i<=n;++i)
{
fa[i][j]=fa[fa[i][j-1]][j-1],Max1[i][j]=max(Max1[i][j-1],Max1[fa[i][j-1]][j-1]),Max2[i][j]=max(Max2[i][j-1],Max2[fa[i][j-1]][j-1]);
if(Max1[i][j-1]^Max1[fa[i][j-1]][j-1]) Max2[i][j]=max(Max2[i][j],min(Max1[i][j-1],Max1[fa[i][j-1]][j-1]));
}
}
}
inline LL MAX(LL x,LL a,LL b,LL val)
{
if(val>Max1[a][b]) return max(x,Max1[a][b]);
return max(x,Max2[a][b]);
}
inline LL Get(LL x,LL y,LL val)
{
register LL i,res=0;
if(Depth[x]<Depth[y]) swap(x,y);
for(i=0;Depth[x]^Depth[y];++i) if((Depth[x]^Depth[y])&(1<<i)) res=MAX(res,x,i,val),x=fa[x][i];
if(!(x^y)) return res;
for(i=0;fa[x][i]^fa[y][i];++i);
for(;i>=0;--i) if(fa[x][i]^fa[y][i]) res=MAX(MAX(res,x,i,val),y,i,val),x=fa[x][i],y=fa[y][i];
return MAX(MAX(res,x,0,val),y,0,val);
}
int main()
{
register LL i,j,tot=1,ans=1e18,res=0,t;
for(read(n),read(m),i=1;i<=m;++i) read(e[i].from),read(e[i].to),read(e[i].val);
for(i=1;i<=n;++i) f[i]=i;
for(sort(e+1,e+m+1,cmp),i=1;i<=m&&tot<n;++i)
{
static LL fx,fy;
if((fx=getfa(e[i].from))^(fy=getfa(e[i].to))) f[fy]=fx,res+=e[i].val,++tot,e[i].used=1,v[e[i].from].push_back(i),v[e[i].to].push_back(i);
}
for(Init(),i=1;i<=m;++i)
if(!e[i].used) t=Get(e[i].from,e[i].to,e[i].val),ans=min(ans,res+e[i].val-t);
return write(ans),fwrite(pp,1,pp_,stdout),0;
}