题解:
首先我们应该注意,这道题问的是:
对于点对(a,b),存在点c在ab路径上,且a<->c和b<->c都是阴阳平衡的合法点对(a,b)有多少对。
因此这玩意是树链统计。
阴阳平衡就是$1+(-1)=0$;
用点分治搞一搞。
仔细看一看,你很快发现如果a->b和a->b->c相等的话,b<->c一定是阴阳平衡的(废话)。
所以我们将状态分为两种,路径上没有阴阳平衡的,还有路径上没有阴阳平衡的。
所以代码:
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define N 100050 #define ll long long inline int rd() { int f=1,c=0;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){c=10*c+ch-'0';ch=getchar();} return f*c; } int n,hed[N],cnt; struct EG { int to,nxt,v; }e[2*N]; void ae(int f,int t,int v) { e[++cnt].to = t; e[cnt].nxt = hed[f]; e[cnt].v = v; hed[f] = cnt; } int w[N],rt,sum,mrk[N]; int siz[N]; ll ans; void get_rt(int u,int fa) { w[u] = 0,siz[u] = 1; for(int j=hed[u];j;j=e[j].nxt) { int to = e[j].to; if(to==fa||mrk[to])continue; get_rt(to,u); siz[u] += siz[to]; if(siz[to]>w[u])w[u]=siz[to]; } w[u] = max(w[u],sum-siz[u]); if(w[u]<w[rt])rt=u; } ll f[2*N][2],g[2*N][2]; int hs[2*N],max_dep; void dfs(int u,int fa,int dep) { max_dep = max(max_dep, (dep-N) * (dep<N?-1:1)); if(hs[dep])f[dep][1]++; else f[dep][0]++; hs[dep]++; for(int j=hed[u];j;j=e[j].nxt) { int to = e[j].to; if(to==fa||mrk[to])continue; dfs(to,u,dep+e[j].v); } hs[dep]--; } void work(int u) { mrk[u] = 1;g[N][0] = 1;int mxd = 0; for(int j=hed[u];j;j=e[j].nxt) { int to = e[j].to; if(mrk[to])continue; max_dep = 0; dfs(to,u,N+e[j].v); mxd = max(max_dep,mxd); ans+=f[N][0]*(g[N][0]-1); for(int i=-max_dep;i<=max_dep;i++) ans+= f[N+i][0]*g[N-i][1]+f[N+i][1]*g[N-i][0]+f[N+i][1]*g[N-i][1]; for(int i=N-max_dep;i<=N+max_dep;i++) { g[i][0]+=f[i][0]; g[i][1]+=f[i][1]; f[i][0]=f[i][1]=0; } } for(int i=N-mxd;i<=N+mxd;i++)g[i][0]=g[i][1]=0; int sum0 = sum; for(int j=hed[u];j;j=e[j].nxt) { int to = e[j].to; if(mrk[to])continue; rt = 0,sum = (siz[to]>siz[u]?sum0-siz[u]:siz[to]); get_rt(to,0); work(rt); } } int main() { n = rd(); for(int f,t,v,i=1;i<n;i++) { f = rd(),t = rd(),v = rd(); if(!v)v=-1; ae(f,t,v),ae(t,f,v); } w[0]=0x3f3f3f3f; rt = 0,sum = n; get_rt(1,0); work(rt); printf("%lld ",ans); return 0; }