题目大意
给出一棵(n)个点的树,每条边的权值是1或0,一条路径合法的条件是:路径上存在一个休息点(不能是起点也不能是终点),使得起点到该点路径上0和1的个数相等,该点到终点的路径上0和1的个数也相等。求合法路径条数。
分析
求满足条件的树上路径条数显然是点分治。
考虑分治中心(x),对于两条路径(x->u)和(x->v),路径(u->v)合法有三种情况:
- 休息点在(x->u)上
- 休息点在(x->v)上
- 休息点在(u)
将0视作-1,将1视作1,一条路径0和1的个数相等等价于路径权值和为0,那么我们只需要预处理每个点(u),路径(x->u)是否有休息点,然后用两个桶统计一下答案就行了。为了便于计算答案,我们一棵一棵子树做,就不用容斥了。
Code
#include <cstdio>
#include <cstring>
typedef long long ll;
const int N = 300007;
int max(int a, int b) { return a > b ? a : b; }
ll ans;
int n;
int sum, p, tot, st[N], to[N << 1], nx[N << 1], len[N << 1], siz[N], maxsiz[N], del[N];
void add(int u, int v, int w) { to[++tot] = v, nx[tot] = st[u], len[tot] = (w == 1 ? -1 : 1), st[u] = tot; }
void getp(int u, int from)
{
siz[u] = 1, maxsiz[u] = 0;
for (int i = st[u]; i; i = nx[i]) if (to[i] != from && !del[to[i]]) getp(to[i], u), siz[u] += siz[to[i]], maxsiz[u] = max(maxsiz[u], siz[to[i]]);
maxsiz[u] = max(maxsiz[u], sum - siz[u]);
if (maxsiz[u] < maxsiz[p]) p = u;
}
int cnt, arr[N], dis[N], ok[N], buc[N * 4], b[N * 4], b0[N * 4];
void getdis(int u, int from)
{
if (b0[dis[u] + N]) ok[u] = 1;
else ok[u] = 0;
arr[++cnt] = u, b0[dis[u] + N]++;
for (int i = st[u]; i; i = nx[i]) if (to[i] != from && !del[to[i]]) dis[to[i]] = dis[u] + len[i], getdis(to[i], u);
b0[dis[u] + N]--;
}
void solve(int u)
{
del[u] = 1, dis[u] = 0;
for (int i = st[u]; i; i = nx[i])
if (!del[to[i]])
{
ll ret = 0;
cnt = 0, dis[to[i]] = len[i], getdis(to[i], u);
for (int j = 1; j <= cnt; j++)
{
if (ok[arr[j]])
{
ret += b[N - dis[arr[j]]] + buc[N - dis[arr[j]]];
if (!dis[arr[j]]) ret++;
}
else
{
ret += buc[N - dis[arr[j]]];
if (!dis[arr[j]]) ret += b[N];
}
}
for (int j = 1; j <= cnt; j++)
{
if (ok[arr[j]]) buc[dis[arr[j]] + N]++;
else b[dis[arr[j]] + N]++;
}
ans += ret;
}
dis[u] = 0, getdis(u, 0);
for (int i = 1; i <= cnt; i++) b[dis[arr[i]] + N] = 0, buc[dis[arr[i]] + N] = 0, ok[arr[i]] = 0;
for (int i = st[u]; i; i = nx[i]) if (!del[to[i]]) sum = siz[to[i]], p = 0, getp(to[i], 0), solve(p);
}
int main()
{
scanf("%d", &n);
for (int i = 1, u, v, w; i < n; i++) scanf("%d%d%d", &u, &v, &w), add(u, v, w), add(v, u, w);
sum = n, maxsiz[0] = N, getp(1, 0), solve(p);
printf("%lld
", ans);
return 0;
}