树上交集 / 路径计数机
题目链接:ybt高效进阶 21165 / 150C / nowcoder 1103B
题目大意
给你一棵树,问你能找到多少个四元组 (a,b,c,d),满足 a 到 b 边数为 p,c 到 d 边数为 q,而且两条路径没有交。
思路
考虑求不交比较难,我们搞有交的。
那不难想出两条路径就两种情况,一个是公用同一个 LCA,要么是有一条路径穿过了另一条路径的 LCA。
然后我们就以 LCA 为中心去搞,考虑求出这四个东西:(fp_i,fq_i,gp_i,gq_i),分别表示从 (i) 出发,经过 (p/q) 条边,然后到的是 (i) 子树内 / 外的点。
那不难想到所有的四元组个数就是:(sumlimits_{i=1}^n fp_isumlimits_{i=1}^n fq_i)
然后有交的:(sumlimits_{i=1}^n(fp_ifq_i+fp_igq_i+gp_ifq_i))。
(后面两个都是第二个情况,只是谁穿过不同而已)
那接下来就是要求这四个数组。
那不难想出可以先搞路径一段是在 (i) 上的,然后把两个路径拼上得到上面的数组。
(两条路径的长度之和已经确定,就直接枚举一条的长度)
然后设 (f_{i,j},g_{i,j}) 为有多少条路径一端是 (i),另一端在 (i) 子树内 / 外,然后路径长度是 (j) 的路径数。
(f_{i,j}) 可以直接 DP 下去,(f_{i,j}=sum_{k=son_i} f_{k,j-1},f_{i,0}=1)。
至于 (g_{i,j}),我们想,如果我们求出以 (i) 为根的时候的 (f_{i,j}),那这个就是不管子树内外的,减去在子树内的(一开始算出的 (f_{i,j})),就是在子树外的了。
那我们考虑换根一下就好了。(记得跑完换回来)
然后就好啦。
代码
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
struct node {
int to, nxt;
}e[6001];
int n, p, q, x, y;
int le[3001], KK;
ll ans, f[3001][3001], g[3001][3001];//这个是一个端点是 i,另一个端点在 i 子树内 / 外,路径长度为 j 的个数
ll fp[3001], fq[3001], gp[3001], gq[3001];//我们要求的东西
ll nw[3001][3001];
void add(int x, int y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
e[++KK] = (node){x, le[y]}; le[y] = KK;
}
void dfs_f(int now, int father) {
f[now][0] = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
dfs_f(e[i].to, now);
for (int j = 1; j <= p; j++) {//把两个从 now 出发的拼起来
fp[now] += (j == 0 ? 1 : f[e[i].to][j - 1]) * f[now][p - j];
}
for (int j = 1; j <= q; j++) {
fq[now] += (j == 0 ? 1 : f[e[i].to][j - 1]) * f[now][q - j];
}
for (int j = 1; j < n; j++)
f[now][j] += f[e[i].to][j - 1];
}
}
void dfs_g(int now, int father) {
for (int i = 0; i < n; i++)
g[now][i] = nw[now][i] - f[now][i];
g[now][0] = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
for (int j = 1; j < n; j++)//换根
nw[now][j] -= nw[e[i].to][j - 1];
for (int j = 1; j < n; j++)
nw[e[i].to][j] += nw[now][j - 1];
dfs_g(e[i].to, now);
for (int j = 1; j < n; j++)//换回来
nw[e[i].to][j] -= nw[now][j - 1];
for (int j = 1; j < n; j++)
nw[now][j] += nw[e[i].to][j - 1];
}
for (int i = 1; i <= p; i++)//计算(两个都是从 now 出发,一个向 now 子树,一个往外)
gp[now] += g[now][i] * f[now][p - i];
for (int i = 1; i <= q; i++)
gq[now] += g[now][i] * f[now][q - i];
}
int main() {
// freopen("intersection.in", "r", stdin);
// freopen("intersection.out", "w", stdout);
scanf("%d %d %d", &n, &p, &q);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
add(x, y);
}
dfs_f(1, 0);
for (int i = 1; i <= n; i++)
for (int j = 0; j < n; j++)
nw[i][j] = f[i][j];
dfs_g(1, 0);
ll lsum = 0, rsum = 0;//按上面进行计算
for (int i = 1; i <= n; i++)
lsum += fp[i], rsum += fq[i];
ans = lsum * rsum;
for (int i = 1; i <= n; i++)
ans -= fp[i] * fq[i] + fp[i] * gq[i] + gp[i] * fq[i];
printf("%lld", ans * 4ll);//记得要乘4(ab可以互换,cd可以互换)
return 0;
}