Solution
首先考虑转化题意,以下用(scc)指代强连通分量
活动的每一步相同于:如果(y,z)在同一个(scc)中,(x)向(y)有连边,那么(x)就可以向(z)连边。
也就是说,对于一个(scc),如果(x)向(scc)中连了一条边,那么(x)就可以在活动后向(scc)中的任何一个点连边。
那么分别考虑每个(scc)的贡献就是内部的贡献(siz(siz-1)+)向(scc)连了边的点的贡献(sizcdot x),其中(x)时向该(scc)连边的点的数量。
于是,对于每一个(scc)开一个(set) (pt[i])维护内部的点,一个(set) (in[i])维护向该(scc)连边的点。
对于加入的每一条边((x,y)):
(1.)((x,y))在同一个(scc)中:直接跳过
(2.)如果(y)所在(scc)本来就向(x)所在(scc)连了边,那么需要合并这两个(scc)
(3.)否则,直接在(in[y])中加入(x)
为了判断是否可以合并(scc),还需要额外维护两个(set):(ins[s])和(outs[s])表示向(s)连边的(scc)以及(s)连出去的(scc),合并时使用启发式合并,依次考虑每一个(set)即可
注意到合并后可能还会导致新的需要合并的(scc),于是开一个·队列先后合并即可。
Code
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+10;
typedef long long ll;
int n,m,fa[N];
inline int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
set<int> pt[N];//点
set<int> ins[N];//连入的强连通分量
set<int> outs[N];//连出的强连通分量
set<int> in[N];//连入的点
inline ll calc(int x){
return 1ll*pt[x].size()*(pt[x].size()-1+in[x].size());
}
typedef pair<int,int> pii;
#define mp make_pair
#define it set<int>::iterator
ll ans=0,ret=0;
queue<pii>q;
inline void merge(int x,int y){
int fx=find(x),fy=find(y);
if(fx==fy) return ;
ans-=calc(fx)+calc(fy);
if(pt[fx].size()>pt[fy].size()) swap(fx,fy),swap(x,y);
fa[fx]=fy;
if(ins[fx].count(fy)) ins[fx].erase(fy),outs[fy].erase(fx);
if(outs[fx].count(fy)) outs[fx].erase(fy),ins[fy].erase(fx);
for(it i=pt[fx].begin();i!=pt[fx].end();i++){
int s=*i;
if(in[fy].count(s)) in[fy].erase(s);
pt[fy].insert(*i);
}
for(it i=ins[fx].begin();i!=ins[fx].end();i++){
int s=*i;
if(outs[fy].count(s)) q.push(mp(s,fy));
else ins[fy].insert(s);outs[s].erase(fx);outs[s].insert(fy);
}
for(it i=outs[fx].begin();i!=outs[fx].end();i++){
int s=*i;
if(ins[fy].count(s)) q.push(mp(s,fy));
else outs[fy].insert(s);ins[s].erase(fx);ins[s].insert(fy);
}
for(it i=in[fx].begin();i!=in[fx].end();i++)
if(!pt[fy].count(*i)) in[fy].insert(*i);
ans+=calc(fy);
}
inline void work(int x,int y){
q.push(mp(x,y));
while(!q.empty())
merge(q.front().first,q.front().second),q.pop();
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) fa[i]=i,pt[i].insert(i);
for(int i=1;i<=n;++i) ans+=calc(i);
for(int i=1,a,b;i<=m;++i){
scanf("%d%d",&a,&b);
int fx=find(a),fy=find(b);
if(fx==fy){
printf("%lld
",ans);
continue;
}
if(ins[fx].count(fy)) work(a,b);
else{
ans-=calc(fy);
ins[fy].insert(fx); outs[fx].insert(fy);
in[fy].insert(a);
ans+=calc(fy);
}
printf("%lld
",ans);
}
return 0;
}