这道题看起来是大规模解决树上路径的问题……那就是点分治啦。
既然我们要求的是树上长度为3的倍数的路径有多少条,那么我们不妨对每条路径的长度取模,这样的话我们实际上就获得了一堆长度为0,1,2的路径。因为点分治的性质,它每次只统计当前子树内经过重心的长度为3的倍数的路径,所以我们在每次统计之后 还要减去子树内的答案。
这样可以很容易的得出每棵子树内的答案是dis[0] * dis[0] + dis[1] * dis[2] * 2,后者很好理解,至于前者,也不难想,就是每两条长度为0或者是3的倍数的路径都可以被合并成长度为3的倍数的路径。这个不需要*2,因为计算的时候每两个点在一次路径计算中会被重复记算。
其余的就是点分治的常规操作。每次找到重心之后,首先求解本棵子树之内的所有结果。注意这里我们不需要去遍历每一棵子树,直接把她视为一个整体,把根结点(当前的重心)的父亲赋成0直接向下dfs。(然而在里面计算的时候还是要遍历子树的哈哈)然后每次这么更新答案就可以了。
看一下代码。
#include<iostream> #include<cstdio> #include<cmath> #include<algorithm> #include<queue> #include<cstring> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar(' ') using namespace std; typedef long long ll; const int M = 100005; const int N = 10000005; int read() { int ans = 0,op = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') op = -1; ch = getchar(); } while(ch >='0' && ch <= '9') { ans *= 10; ans += ch - '0'; ch = getchar(); } return ans * op; } struct node { int to,next,v; }e[M]; int head[M],dis[M],ecnt,n,m,x,y,w,ans,road[4],sum,root,size[M],maxs[M]; bool vis[M]; int gcd(int x,int y) { return !y ? x : gcd(y,x%y); } void add(int x,int y,int z) { e[++ecnt].to = y; e[ecnt].v = z; e[ecnt].next = head[x]; head[x] = ecnt; } void getroot(int x,int fa) { size[x] = 1,maxs[x] = 0; for(int i = head[x];i;i = e[i].next) { int t = e[i].to; if(t == fa || vis[t]) continue; getroot(t,x); size[x] += size[t]; maxs[x] = max(maxs[x],size[t]); } maxs[x] = max(maxs[x],sum - size[x]); if(maxs[x] < maxs[root]) root = x; } void getdis(int x,int fa) { road[dis[x]]++; for(int i = head[x];i;i = e[i].next) { int t = e[i].to; if(t == fa || vis[t]) continue; dis[t] = (dis[x] + e[i].v) % 3; getdis(t,x); } } int calc(int x,int leng) { int cur = 0;rep(i,0,3) road[i] = 0;//这里相当于是开桶记录 dis[x] = leng,getdis(x,0); cur += (road[1] * road[2]) << 1; cur += road[0] * road[0]; return cur; } void solve(int x) { vis[x] = 1,ans += calc(x,0); for(int i = head[x];i;i = e[i].next) { int t = e[i].to; if(vis[t]) continue; ans -= calc(t,e[i].v); sum = size[t],maxs[root = 0] = n; getroot(t,0),solve(root); } } int main() { n = read(); rep(i,1,n-1) x = read(),y = read(),w = read() % 3,add(x,y,w),add(y,x,w); sum = maxs[root] = n,getroot(1,0); solve(root); int d = gcd(ans,n*n); printf("%d/%d ",ans/d,n*n/d); return 0; }