题意为求出树上任意点对的距离对3取余的和。
比赛上听到题意就知道是点分治了,但是越写越不对劲,交之前就觉得会T,果不其然T了。修修改改结果队友写了发dp直接就过了Orz。
赛后想了想维护的东西太脑残了,以为像洛谷板子题一样暴力维护就可以,实则被卡死。
赛后的想法是维护距离当前重心的距离对3取余后的距离和以及个数。然后统计的时候枚举两个点的距离取余值,然后统计贡献。
注意下传的时候要消除重复的贡献。
代码丑陋请见谅QAQ
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int maxn = 3e4 + 15; 5 const int mod = 1e9 + 7; 6 struct node { 7 int s, e, w, next; 8 }edge[maxn * 2]; 9 int head[maxn], len; 10 void add(int s, int e, int w) { 11 edge[len].e = e; 12 edge[len].w = w; 13 edge[len].next = head[s]; 14 head[s] = len++; 15 } 16 int n, root, sum; 17 int vis[maxn], f[maxn], son[maxn]; 18 ll ans[4], o[4], num[4]; 19 void getroot(int x, int fa) { 20 son[x] = 1, f[x] = 0; 21 for (int i = head[x]; i != -1; i = edge[i].next) { 22 int y = edge[i].e; 23 if (y == fa || vis[y])continue; 24 getroot(y, x); 25 son[x] += son[y]; 26 f[x] = max(f[x], son[y]); 27 } 28 f[x] = max(f[x], sum - son[x]); 29 if (f[x] < f[root])root = x; 30 } 31 void getd(int x, int dis, int fa) { 32 o[dis % 3]++; 33 num[dis % 3] += dis; 34 for (int i = head[x]; i != -1; i = edge[i].next) { 35 int y = edge[i].e; 36 if (y == fa || vis[y])continue; 37 getd(y, (dis + edge[i].w) % mod, x); 38 } 39 } 40 void cal(int x, int val, int add) { 41 getd(x, val, 0); 42 for (int i = 0; i < 3; i++) 43 for (int j = 0; j < 3; j++) { 44 ans[(i + j) % 3] = (ans[(i + j) % 3] + o[i] * num[j] * add % mod + mod) % mod; 45 ans[(i + j) % 3] = (ans[(i + j) % 3] + o[j] * num[i] * add % mod + mod) % mod; 46 } 47 for (int i = 0; i < 3; i++)o[i] = num[i] = 0; 48 49 50 } 51 void solve(int x) { 52 cal(x, 0, 1); 53 vis[x] = 1; 54 for (int i = head[x]; i != -1; i = edge[i].next) { 55 int y = edge[i].e; 56 if (vis[y])continue; 57 cal(y, edge[i].w, -1); 58 sum = son[y]; 59 root = 0; 60 getroot(y, 0); 61 solve(root); 62 } 63 } 64 int main() { 65 while (scanf("%d", &n) != EOF) { 66 len = 0; 67 for (int i = 0; i <= n + 10; i++) 68 vis[i] = 0, head[i] = -1; 69 for (int i = 1; i < n; i++) { 70 int x, y, z; 71 scanf("%d%d%d", &x, &y, &z); 72 x++, y++; 73 add(x, y, z); 74 add(y, x, z); 75 } 76 for (int i = 0; i < 3; i++) 77 ans[i] = 0; 78 root = 0, f[0] = INT_MAX - 1; 79 sum = n; 80 getroot(1, 0); 81 solve(root); 82 for (int i = 0; i <= 2; i++) { 83 if (i == 2) 84 printf("%lld ", ans[i]); 85 else 86 printf("%lld ", ans[i]); 87 } 88 } 89 }