poj3417
题意
给出一颗 n 个节点, n - 1 条边的树,再加上 m 条新边,允许删掉树边和新边各一条,问能使树分为两部分的方案数。
分析
在树的基础上加上不重复的新边一定会构成环,那么考虑的就是怎么拆分环。
对于给出的新边(u, v),构成的环就是,u -> LCA(u, v) -> v -> u,将环上的边都标记加1,最后统计每条边的标记值,
如果一条边未被标记过,那么只要拆掉这条边就分成两部分了,即有 m 中方案数了;如果被标记过一次,那么在拆掉这条边的同时,一定要拆掉构成这个环的新边,即有 1 种方案;如果标记数大于 1,也就是说这条边被两个环同时标记过,根据题目的条件,无法分成两块了,即没有这种方案。
在求标记值的时候,要用到树形DP,设 hide[u] 为 u 到它的父节点所连边被标记过的次数,对于读入的新边 (u, v), hide[u]++ ,hide[v]++,hide[LCA(u, v)] -= 2,这个技巧在求区间覆盖时很常用。
最后,建边要用到链式前向星,向量超时了 (ง •̀_•́)ง┻━┻。
code
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
typedef pair<int, int> P;
typedef long long ll;
const int MAXN = 1e5 + 5; // 最大节点数
const int LOG_N = 60; // 树的最大深度
int n, m;
int head[MAXN];
struct Edge
{
int to, next;
}edge[MAXN * 2];
int depth[MAXN]; // 节点深度
int parent[LOG_N][MAXN]; // parent[k][i]表示 i 向上走 2^k 步能到达的节点
int hide[MAXN]; // u到它的父亲节点所连边被覆盖过几次
int cnt;
void add(int u, int v)
{
edge[cnt].to = v;
edge[cnt].next = head[u];
head[u] = cnt++;
}
void dfs(int pre, int u, int d)
{
parent[0][u] = pre;
depth[u] = d;
for(int i = head[u]; ~i; i = edge[i].next)
{
int v = edge[i].to;
if(v != pre) dfs(u, v, d + 1);
}
}
void init()
{
int root = 1;
dfs(-1, root, 0);
for(int k = 1; k < LOG_N; k++)
{
for(int i = 1; i <= n; i++)
{
if(parent[k - 1][i] < 0) parent[k][i] = -1;
else parent[k][i] = parent[k - 1][parent[k - 1][i]];
}
}
}
int lca(int u, int v)
{
if(depth[u] > depth[v]) swap(u, v);
for(int i = 0; i < LOG_N; i++) // u 和 v 向上走到同一深度
{
if((depth[v] - depth[u]) >> i & 1) // 把 (depth[v] - depth[i]) 化成二进制后可以看到,就是找到所有 1 的位置
{
v = parent[i][v];
}
}
if(v == u) return u;
for(int i = LOG_N - 1; i >= 0; i--) // 找 lca
{
if(parent[i][u] != parent[i][v]) // 如果相同,那么一定是公共祖先或公共祖先之上的节点
{
u = parent[i][u];
v = parent[i][v];
}
}
return parent[0][u];
}
void dfs2(int pre, int u)
{
for(int i = head[u]; ~i; i = edge[i].next)
{
int v = edge[i].to;
if(v != pre)
{
dfs2(u, v);
hide[u] += hide[v];
}
}
}
int main()
{
while(~scanf("%d%d", &n, &m))
{
memset(head, -1, sizeof head);
for(int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y);
add(y, x);
}
init();
for(int i = 0; i < m; i++)
{
int x, y;
scanf("%d%d", &x, &y);
int node = lca(x, y);
hide[x]++;
hide[y]++;
hide[node] -= 2;
}
dfs2(-1, 1);
int ans = 0;
for(int i = 2; i <= n; i++)
{
if(hide[i] == 0) ans += m;
else if(hide[i] == 1) ans++;
}
printf("%d
", ans);
}
return 0;
}