传送门:铁人两项
简述一下题目:
给出一个(不一定联通)的图,求有多少个三元组(s,c,f)满足s,c,f都是图中的点,且存在一条从s到c的路径和一条从c到f的路径,使得两条路径没有公共点(除c以外)。
这个题当时刚接触到圆方树,我的想法跟正解十分接近使我非常兴奋。
这个题我们想一下如果n2的话我们要怎么做:
枚举两个圆点s,f。路径上所有的点双中的点都可以作为c。如何方便地统计呢?首先我们建出圆方树,把圆点权值设为-1(因为正常计算有重复路径,这样直接免去容斥减的过程),方点权值设为点双的大小,则s到f的路径上的点(包括s,f)的权值和,也就是c的个数。这是n方的。
如果枚举中间点,则很容易求出树上有多少个圆圆路径经过这个点,通过这个点的子树大小直接O(1)进行计算即可,这样我们只用把所有的点枚举一遍即可,这样是O(n)的。
代码先咕咕咕
代码成功的没有咕咕咕:
#define B cout << "BreakPoint" << endl; #define O(x) cout << #x << " " << x << endl; #define O_(x) cout << #x << " " << x << " "; #define Msz(x) cout << "Sizeof " << #x << " " << sizeof(x)/1024/1024 << " MB" << endl; #include<cstdio> #include<cmath> #include<iostream> #include<cstring> #include<algorithm> #include<queue> #include<stack> #define LL long long #define inf 1000000009 #define N 1000005 using namespace std; inline int read() { int s = 0,w = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') w = -1; ch = getchar(); } while(ch >= '0' && ch <= '9') { s = s * 10 + ch - '0'; ch = getchar(); } return s * w; } LL ans; int vis[N]; int n,m,top,res,tot,s; int dfn[N],low[N],stk[N],val[N],sz[N]; struct Graph { int head[N],nxt[N << 1],to[N << 1]; int ecnt; inline void add(int u,int v) { to[++ecnt] = v; nxt[ecnt] = head[u]; head[u] = ecnt; return; } inline void init(int u,int v) { add(u,v); add(v,u); return; } } eold,enew; inline void cmin(int &x,int y) { if(x > y) x = y; return; } void tarjan(int u) { dfn[u] = low[u] = ++tot; stk[++top] = u; sz[u] = 1; for(int i = eold.head[u]; i; i = eold.nxt[i]) { int v = eold.to[i]; if(!dfn[v]) { tarjan(v); cmin(low[u],low[v]); if(low[v] >= dfn[u]) { int t = 0,cnt = 1; res++; while(t != v) { t = stk[top--]; cnt++; enew.init(res,t); sz[res] += sz[t]; } val[res] = cnt; sz[u] += sz[res]; enew.init(res,u); } } else { cmin(low[u],dfn[v]); } } return ; } void dfs(int u,int fa) { int x = u <= n; ans += 2ll * sz[u] * (s - sz[u]) * val[u]; for(int i = enew.head[u]; i; i = enew.nxt[i]) { int v = enew.to[i]; if(v == fa) { continue; } ans += 2ll * x * sz[v] * val[u]; x += sz[v]; dfs(v,u); } return ; } void pre() { n = read(),m = read(); res = n; memset(val,-1,sizeof(val)); for(int i = 1; i <= m; i++) { int u = read(),v = read(); eold.init(u,v); } return ; } void solve() { for(int i = 1; i <= n; i++) { if(!dfn[i]) { tarjan(i); s = sz[i]; dfs(i,-1); } } printf("%lld",ans); return ; } int main() { pre(); solve(); return 0; }