题目描述
Description
给定一棵 n 个点的树和树上的 m 条链,请问有多少对链有至少一个公共点。
Input
第一行两个正整数 n,m。 第 2~n 行,每行两个整数 s,t,表示有一条边连接 s 点和 t 点。
接下来 m 行,每行两个整数 a,b,表示有一条链的起点、终点分别为 a,b。(链的起点、 终点可能相同)
Output
一行一个整数,表示答案。
Sample Input
5 3
1 2
1 3
2 4
2 5
4 5
1 4
3 5
Sample Output
3
Hint
对于 30%的数据,$n,m leq 100$
对于 50%的数据,$n,m leq 2000$
对于另 20%的数据,读入的 $s,t$ 满足 $t=s+1$
对于 100%的数据,$1 leq n,m leq 100000$
时间限制:$1s$ 空间限制:$512MB$
题目分析
这一题的关键在于判断两条链是否重合的方法。首先,要分析一下两条链相交的情况。如图所示:
(图A:链 $S1-T1$ 与 $S2-T2$ 有一个交点O)
(图B:链 $S1-T1$ 与 $S2-T2$ 有两个或更多交点)
由上图可以得出,两条链相交,当且仅当一条链两端点的LCA在另一条链上。
现在我们需要分类讨论。对于第一种情况,我们只需要计算有多少条链的LCA在某个点上,设$tot_i$为LCA在x点上的链的个数。那么,如果$tot_i>1$,那就说明有多条链经过这一点,也就是说,有$n imes (n - 1) div 2$对链相交。这个过程用代码写出来就是:
for(int i=1;i<=n;i++) ans+=tot[i]*(tot[i]-1)/2;
对于第二种情况,就复杂一些。我们要考虑一条链上有多少个其它链的LCA。我们需要一个前缀和数组$sum$,其递推公式为$sum_v = sum_u + tot_v (v为u的子节点)$,最后要计算每一条链上的和,只需要算出$sum_u + sum_v - 2 imes sum_lca(u,v) (u,v为链的端点)$。这个过程用代码写出来就是:
void solve(int u,int fa) { sum[u]=sum[fa]+tot[u]; //递推 for(auto &v: g[u]) //遍历子节点 if(v!=fa) solve(v,u); }
for(int i=1;i<=m;i++) ans+=sum[c[i].u]+sum[c[i].v]-2*sum[c[i].lca];
最后,就只需要输出ans的值了。
完整代码
不贴代码的题解不是好题解。
1 #include <bits/stdc++.h> 2 #define int long long 3 using namespace std; 4 int n,m,ans,dep[100005],f[100005][20],sum[100005],tot[100005]; 5 vector<int> g[100005]; //用vector存图,搭配C++11食用效果更佳 6 struct Chain 7 { 8 int u,v; //链的起点和终点 9 int lca; //链的端点的LCA 10 }c[100005]; 11 void dfs(int u,int fa) //LCA预处理 12 { 13 dep[u]=dep[fa]+1; 14 f[u][0]=fa; 15 for(int i=1;(1<<i)<=dep[u];i++) 16 f[u][i]=f[f[u][i-1]][i-1]; 17 for(auto &v: g[u]) 18 if(v!=fa) 19 dfs(v,u); 20 } 21 int lca(int a,int b) //求最近公共祖先 22 { 23 if(dep[a]<dep[b]) swap(a,b); 24 for(int i=17;i>=0;i--) 25 { 26 if(dep[f[a][i]]>=dep[b]) 27 a=f[a][i]; 28 if(a==b) return a; 29 } 30 for(int i=17;i>=0;i--) 31 if(f[a][i]!=f[b][i]) 32 a=f[a][i],b=f[b][i]; 33 return f[a][0]; 34 } 35 void solve(int u,int fa) //求前缀和数组 36 { 37 sum[u]=sum[fa]+tot[u]; 38 for(auto &v: g[u]) 39 if(v!=fa) 40 solve(v,u); 41 } 42 signed main() 43 { 44 scanf("%lld%lld",&n,&m); //输入数据 45 for(int i=1;i<n;i++) 46 { 47 int u,v; 48 scanf("%lld%lld",&u,&v); 49 g[u].push_back(v); 50 g[v].push_back(u); 51 } 52 dfs(1,0); 53 for(int i=1;i<=m;i++) 54 { 55 scanf("%lld%lld",&c[i].u,&c[i].v); 56 c[i].lca=lca(c[i].u,c[i].v); 57 } 58 for(int i=1;i<=m;i++) tot[c[i].lca]++; //记录tot 59 solve(1,0); 60 for(int i=1;i<=n;i++) ans+=tot[i]*(tot[i]-1)/2; //情况一计算 61 for(int i=1;i<=m;i++) ans+=sum[c[i].u]+sum[c[i].v]-2*sum[c[i].lca]; //情况二计算 62 printf("%lld",ans); 63 while(1); //不加防抄袭神器的题解也不是好题解…… 64 return 0; 65 }