看别人代码学习的,想了好久才看懂代码,自己写了一遍,又写点注释。
1 #include <iostream> 2 #include <cstdio> 3 #include <vector> 4 using namespace std; 5 int N,k; 6 long long ans = 0; // 答案会爆掉 long long 7 int sum[222222],num[222222][6],in[222222]; 8 //sum子树节点个数 num[u][i] u子树中,距离根节点距离mod k == i 的节点个数 9 vector<int> E[222222]; 10 void dfs(int u,int depth) //深度用来求两片区域相隔的距离 11 { 12 sum[u] = num[u][depth%k] = 1; 13 for (unsigned int i = 0; i<E[u].size();i++) 14 { 15 int v = E[u][i]; 16 if (in[v]==1) continue; 17 in[v] = 1; 18 dfs(v,depth+1); 19 for (int j = 0; j<k ; j++) 20 for (int kk = 0; kk<k; kk++) 21 { 22 int t = (k- ((( ( j + kk- (2*depth) ) % k ) + k ) % k )) % k; 23 ans+=t*(long long)num[u][j]*(long long)num[v][kk]; // 将走的过程中 多余(浪费)的边计数 24 } 25 for (int j = 0; j<k; j++) 26 num[u][j]+=num[v][j]; 27 sum[u]+=sum[v]; 28 } 29 ans+=(long long)sum[u]*(long long)(N-sum[u]); // 统计每条边需要走的次数 30 } 31 int main() 32 { 33 cin >> N >> k; 34 for (int i = 1; i<N; i++) 35 { 36 int a,b; 37 scanf("%d%d",&a,&b); 38 E[a].push_back(b); 39 E[b].push_back(a); 40 } 41 in[1] = 1; 42 dfs(1,0); 43 cout << ans/k <<endl; 44 }