[BZOJ3697]采药人的路径
试题描述
采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。
输入
第1行包含一个整数N。
接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。
输出
输出符合采药人要求的路径数目。
输入示例
7 1 2 0 3 1 1 2 4 0 5 2 0 6 3 1 5 7 1
输出示例
1
数据规模及约定
对于100%的数据,N ≤ 100,000。
题解
很明显这是一个求点对数的问题,所以想到点分治。那么对于跨重心的合法链我们怎么求呢?
首先考虑没有休息点的情况,那么显然我们可以把权值 0 看成 -1 统计一下每条从重心往下搜的链的权值和,那么对于一个子树 i,我们设权值和为 x 的链有 f[x] 条,那么我们只需要找到之前所有子树中权值和为 -x 的链的条数(设它为 g[-x])那么权值 x 对答案的贡献为 f[x] * g[-x],因为权值种数总是与子树大小相关的,所以直接暴力累加对于所有的 x 的贡献即可。
那么现在考虑上有休息点的情况。显然休息点可以在重心、重心左边或是重心右边。所以现在问题的关键在于如何得到 f[0~1][x],f[0][x] 表示起点为子树中的一个节点,终点为重心,权值和为 x 且没有休息点的链的条数,f[1][x] 表示起点为子树中的一个节点,终点为重心,权值和为 x 且有休息点的链的条数。(请慢慢理解。。。)仔细想想发现若是有休息点,那么子树中那个节点到休息点的权值和一定等于 0,意味着休息点到重心的权值和等于 x,那么这个事情就很好统计了。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 100010 #define maxm 200010 #define LL long long int n, m, head[maxn], next[maxm], to[maxm], dist[maxm]; LL ans; void AddEdge(int a, int b, int c) { to[++m] = b; dist[m] = c; next[m] = head[a]; head[a] = m; swap(a, b); to[++m] = b; dist[m] = c; next[m] = head[a]; head[a] = m; return ; } bool vis[maxn]; int root, size, f[maxn], siz[maxn]; void getroot(int u, int fa) { siz[u] = 1; f[u] = 0; for(int e = head[u]; e; e = next[e]) if(to[e] != fa && !vis[to[e]]) { getroot(to[e], u); siz[u] += siz[to[e]]; f[u] = max(f[u], siz[to[e]]); } f[u] = max(f[u], size - siz[u]); if(f[root] > f[u]) root = u; return ; } int has[maxn<<1], A[2][maxn<<1], B[2][maxn<<1], mxd, mnd; void dfs(int u, int fa, int d) { // printf("(d)%d(%d) ", d, has[d+n]); mxd = max(mxd, d); mnd = min(mnd, d); A[has[d+n]?1:0][d+n]++; has[d+n]++; for(int e = head[u]; e; e = next[e]) if(to[e] != fa && !vis[to[e]]) dfs(to[e], u, d + dist[e]); has[d+n]--; return ; } void solve(int u) { // printf("u: %d ", u); vis[u] = 1; bool fir = 1; int Mxd = -n - 1, Mnd = n + 1; for(int e = head[u]; e; e = next[e]) if(!vis[to[e]]) { mxd = -n - 1; mnd = n + 1; dfs(to[e], u, dist[e]); Mxd = max(Mxd, mxd); Mnd = min(Mnd, mnd); if(fir) ; else { ans += (LL)A[0][n] * B[0][n]; // printf("%d ", A[0][n] * B[0][n]); for(int i = n + mnd; i <= n + mxd; i++) { int d = i - n; ans += (LL)A[0][i] * B[1][n-d] + A[1][i] * B[0][n-d] + A[1][i] * B[1][n-d]; // printf("(%d)%d ", d, A[0][i] * B[1][n-d] + A[1][i] * B[0][n-d] + A[1][i] * B[1][n-d]); } } ans += (LL)A[1][n]; fir = 0; for(int i = n + mnd; i <= n + mxd; i++) B[0][i] += A[0][i], B[1][i] += A[1][i], A[0][i] = A[1][i] = 0, has[i] = 0; // putchar(' '); } for(int i = n + Mnd; i <= n + Mxd; i++) B[0][i] = B[1][i] = 0; for(int e = head[u]; e; e = next[e]) if(!vis[to[e]]) { root = 0; f[0] = n + 1; size = siz[u]; getroot(to[e], u); solve(root); } return ; } int main() { n = read(); for(int i = 1; i < n; i++) { int a = read(), b = read(), c = read() ? 1 : -1; AddEdge(a, b, c); } root = 0; f[0] = n + 1; size = n; getroot(1, 0); solve(root); printf("%lld ", ans); return 0; }