Description
给你一张\(~n~\)个点\(~m~\)条边的无向图,求有多少个三元组\(~(x, ~y, ~z)~\)满足存在一条从\(~x~\)到\(~z~\)并且经过\(~y~\)的路径。保证两点之间最多只有一条边连接。
Solution
考虑对这张图建圆方树,每个方点的权值记录该点双的点数,每个圆点的权值为\(-1\)。这样先确定\(~x, ~z~\)之后, 其路径上的点权和就是满足条件的\(~y~\)的个数 (因为一个圆点的贡献会算进两个相邻的方点中,所以每个圆点的权值是\(-1\)).
现在考虑优化这个思路,可以发现每个点的点权对答案的贡献次数是该点出现在合法路径上出现的次数,所以可以\(~O(n)~\)快速求出该点的出现总次数。注意图不一定联通,\(~sum~\)为该联通块的大小。
\[Ans_u = val[u] \times [(sum - siz[u]) \times siz[u] + \sum_{v = son_x} siz[v] \times (siz[u] - siz[v])]
\]
Code
#include<bits/stdc++.h>
#define Set(a, b) memset(a, b, sizeof (a))
#define For(i, j, k) for(int i = j; i <= k; ++i)
#define Forr(i, j, k) for(int i = j; i >= k; --i)
#define Travel(i, u, G) for(int i = G.beg[u], v = G.to[i]; i; i = G.nex[i], v = G.to[i])
using namespace std;
inline int read() {
int x = 0, p = 1; char c = getchar();
for(; !isdigit(c); c = getchar()) if(c == '-') p = -1;
for(; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
return x *= p;
}
template<typename T> inline bool chkmin(T &a, T b) { return a > b ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return a < b ? a = b, 1 : 0; }
inline void File() {
#ifndef ONLINE_JUDGE
freopen("P4630.in", "r", stdin);
freopen("P4630.out", "w", stdout);
#endif
}
typedef long long ll;
const int N = 2e5 + 10, M = N << 2;
int n, m, cnt, dfn[N], low[N], val[N], u, v, fa[N], clk, siz[N], sum, vis[N];
struct edge {
int e = 1, beg[N], nex[M], to[M];
inline void add(int x, int y) { to[++ e] = y, nex[e] = beg[x], beg[x] = e; }
} G1, G2;
stack<int> S; ll ans;
inline void tarjan(int u, int f) {
dfn[u] = low[u] = ++ clk, S.push(u);
++ sum, val[u] = -1;
Travel(i, u, G1) if (v != f) {
if (!dfn[v]) {
tarjan(v, u), chkmin(low[u], low[v]);
if (low[v] >= dfn[u]) {
++ val[++ cnt], fa[cnt] = u, G2.add(u, cnt);
while (!S.empty()) {
int x = S.top(); S.pop();
G2.add(cnt, x), fa[x] = cnt, ++ val[cnt];
if (x == v) break;
}
}
} else chkmin(low[u], dfn[v]);
}
}
inline void dfs(int u) {
if (u <= n) siz[u] = 1, vis[u] = 1;
ll res = 0;
Travel(i, u, G2) {
dfs(v), siz[u] += siz[v];
res += 2ll * siz[v] * (siz[u] - siz[v]);
}
res += 2ll * (sum - siz[u]) * (siz[u]);
ans += res * val[u];
}
int main() {
File();
cnt = n = read(), m = read();
For(i, 1, m) u = read(), v = read(), G1.add(u, v), G1.add(v, u);
For(i, 1, n) if (!vis[i]) sum = 0, tarjan(i, 0), dfs(i);
cout << ans << endl;
return 0;
}