LCA
题意:先给出一棵无根树,然后下面再给出m条边,把这m条边连上,然后每次你能毁掉两条边,规定一条是树边,一条是新边,问有多少种方案能使树断裂。
我们知道,这m条边连上后这颗树必将成环,假设新边为(u,v),那么环为u---->LCA(u,v)------->v-------->u,我们给这个环上的边计数1,表示这些边被一个环覆盖了一次。添加了多条新边后,可知树上有些边是会被多次覆盖的,画图很容易发现,但一个树边被覆盖了2次或以上,它就是一条牢固的边,就是说毁掉它再毁掉任何一条新边都好,树都不会断裂,这个结论也是很容易证明的,画图更明显,所以不累述
所以这启发了我们,要统计所有的边被覆盖了几次,我们分情况来讨论
1.覆盖0次,说明这条边不在任何一个环上,这样的边最脆弱,单单是毁掉它就已经可以使树断裂了,这时候只要任意选一条新边去毁,树还是断裂的,所以这样的树边,就产生m种方案(m为新边条数)
2.覆盖1次,说明这条边在一个环上,且,仅在一个环上,那么要使树断裂,就毁掉这条树边,并且毁掉和它对应的那条新边(毁其他的新边无效),就一定能使树断裂,这种树边能产生的方案数为1,一条这样的树边只有唯一解
3.覆盖2次或以上,无论怎么样都不能使树断裂,产生的方案数为0
所以,如果我们能知道所有的树边的覆盖,那么统计一次就行了,所以问题只剩下,怎么每条边被覆盖了几次?
需要用到树DP。
首先我们定义dp[u]的意义为,u所对应的那条父边(u和它父亲连接的那条边)被覆盖的次数
对应一条新边(u,v),我们知道是要求LCA(u,v)的,这时候我们计数dp[u]++ , dp[v]++ , dp[lca]-=2
为什么这样计数?我们试着看看,点u和点v和点lca,都试着沿路径一直回到树根处(注意不是回到LCA而是树根),u的路径中每经过一个点,就将这些点上的值加上dp[u],同样v的路径上没经过一个点就将这些点上的值加上dp[v],lca也是这样。你会发现,lca回到树根的部分,其实被抵消掉了,dp值没有变化,而u到lca,v到lca部分的值都已经分别加上了dp[u],dp[v]
这启发了我们,我们在求完所有m对顶点的LCA后,每个u和v都做一次dp[u]++,dp[v]++,dp[lca]-=2,然后我从树根开始向下遍历一次整棵树,在回溯的时候就执行累加dp[u],dp[v]的操作
这其实就是树DP的过程,好像说得有点复杂了,不过其实不难理解
#include <iostream> #include <cstdio> #include <cstring> using namespace std; const int N = 100005; const int Q = 100005; int tote,totea; int head[N],__head[N],dp[N],fa[N],vis[N]; struct edge{ int u,v,next; }e[2*N]; struct ask{ int u,v,lca,next; }ea[2*Q]; void add_edge(int u ,int v){ e[tote].u = u; e[tote].v = v; e[tote].next = head[u]; head[u] = tote++; } void add_ask(int u ,int v){ ea[totea].u = u; ea[totea].v = v; ea[totea].lca = -1; ea[totea].next = __head[u]; __head[u] = totea++; } int find(int x){ return x == fa[x] ? x : fa[x] = find(fa[x]); } void Tarjan(int u) { vis[u] = 1; for(int k = __head[u]; k!=-1; k=ea[k].next) if(vis[ea[k].v]) { int v = ea[k].v; int lca = find(v); ea[k].lca = ea[k^1].lca = lca; } for(int k=head[u]; k!=-1; k=e[k].next) if(!vis[e[k].v]) { int v = e[k].v; Tarjan(v); fa[v] = u; } } void DP(int u) { vis[u] = 1; for(int k=head[u]; k!=-1; k=e[k].next) if(!vis[e[k].v]) { int v = e[k].v; DP(v); dp[u] += dp[v]; } } int main() { int n,q; while(scanf("%d%d",&n,&q)!=EOF) { tote = totea = 0; memset(head,-1,sizeof(head)); memset(__head,-1,sizeof(__head)); memset(dp,0,sizeof(dp)); memset(vis,0,sizeof(vis)); for(int i=1; i<=n; i++) fa[i] = i; for(int i=1; i<n; i++){ int u,v; scanf("%d%d",&u,&v); add_edge(u,v); add_edge(v,u); } for(int i=0; i<q; i++){ int u,v; scanf("%d%d",&u,&v); add_ask(u,v); add_ask(v,u); dp[u]++; dp[v]++; } Tarjan(1); for(int i=0; i<q; i++) { int s = i*2 , lca = ea[s].lca; dp[lca] -= 2; } memset(vis,0,sizeof(vis)); DP(1); int res = 0; for(int i=2; i<=n; i++) if(dp[i] == 1) res++; else if(dp[i] == 0) res += q; printf("%d\n",res); } return 0; }