这道题如果暴力的话是平方级别的,然后我们可以看出来,显然能够使用换根dp来做
只需要维护做两边dfs,分别维护向上向下的答案即可,不但要维护距离,也要维护点数。
维护的时候都是维护mod3状态下的答案
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=5e5+10; const int mod=1e9+7; int h[N],ne[N],e[N],idx; ll w[N]; ll f[N][3];//以i为根节点的子树到儿子的距离mod 3为j的距离和 ll g[N][3]; ll cnt1[N][3]; ll cnt2[N][3]; void add(int a,int b,ll c){ e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++; } void dfs(int u,int fa){ int i; cnt1[u][0]=1; for(i=h[u];i!=-1;i=ne[i]){ int j=e[i]; if(j==fa) continue; dfs(j,u); f[u][(0+w[i])%3]+=(cnt1[j][0]*w[i]%mod+f[j][0])%mod; f[u][(1+w[i])%3]+=(cnt1[j][1]*w[i]%mod+f[j][1])%mod; f[u][(2+w[i])%3]+=(cnt1[j][2]*w[i]%mod+f[j][2])%mod; for(int x=0;x<3;x++) f[u][x]%=mod; cnt1[u][(0+w[i])%3] += cnt1[j][0]; cnt1[u][(1+w[i])%3] += cnt1[j][1]; cnt1[u][(2+w[i])%3] += cnt1[j][2]; } } void get(int u,int fa){ for(int i=h[u];i!=-1;i=ne[i]){ int j=e[i]; if(j==fa) continue; int c[4]={0}; ll d[4]={0}; c[0]=cnt1[u][0]+cnt2[u][0]; c[1]=cnt1[u][1]+cnt2[u][1]; c[2]=cnt1[u][2]+cnt2[u][2]; c[(0+w[i])%3]-=cnt1[j][0]; c[(1+w[i])%3]-=cnt1[j][1]; c[(2+w[i])%3]-=cnt1[j][2]; d[0]=(f[u][0]+g[u][0])%mod; d[1]=(f[u][1]+g[u][1])%mod; d[2]=(f[u][2]+g[u][2])%mod; d[(0+w[i])%3]=(d[(0+w[i])%3]-f[j][0]+mod-cnt1[j][0]*w[i]%mod+mod)%mod; d[(1+w[i])%3]=(d[(1+w[i])%3]-f[j][1]+mod-cnt1[j][1]*w[i]%mod+mod)%mod; d[(2+w[i])%3]=(d[(2+w[i])%3]-f[j][2]+mod-cnt1[j][2]*w[i]%mod+mod)%mod; g[j][(0+w[i])%3]=(c[0]*w[i]%mod+d[0])%mod; g[j][(1+w[i])%3]=(c[1]*w[i]%mod+d[1])%mod; g[j][(2+w[i])%3]=(c[2]*w[i]%mod+d[2])%mod; cnt2[j][(0+w[i])%3]+=c[0]; cnt2[j][(1+w[i])%3]+=c[1]; cnt2[j][(2+w[i])%3]+=c[2]; get(j,u); } } int main(){ ios::sync_with_stdio(false); int n; while(cin>>n){ int i; for(i=1;i<=n;i++){ h[i]=-1; f[i][0]=f[i][1]=f[i][2]=0; g[i][0]=g[i][1]=g[i][2]=0; cnt1[i][0]=cnt1[i][1]=cnt1[i][2]=0; cnt2[i][0]=cnt2[i][1]=cnt2[i][2]=0; } idx=0; for(i=1;i<n;i++){ int a,b,c; cin>>a>>b>>c; a++,b++; add(a,b,c); add(b,a,c); } dfs(1,-1); get(1,-1); ll ans1=0,ans2=0,ans3=0; for(i=1;i<=n;i++){ //cout<<f[i][0]<<" "<<f[i][1]<<" "<<f[i][2]<<endl; ans1=(ans1+f[i][0]+g[i][0])%mod; ans2=(ans2+f[i][1]+g[i][1])%mod; ans3=(ans3+f[i][2]+g[i][2])%mod; } cout<<ans1<<" "<<ans2<<" "<<ans3<<endl; } }