题意
一棵边权为-1或1的树,求满足条件的路径数:路径边权和为0,存在路径上异于端点的一点到端点的边权和为0。
考虑点分治,设路径端点为$u,v$,中间满足条件的点为$k$,则$dis[u]+dis[v]=0$ 且 $dis[u]=dis[k]$或$dis[v]=dis[k]$。
搜距离的时候将一棵子树中的距离按照是否有相等的前缀距离分成两组,记为$f[i][0],f[i][1]$,与前面几颗子树的和$g[i][0],g[i][1]$
计算这一颗子树与之前子树的路径数 $ans+=f[i][0]*g[-i][1]+f[i][1]*g[-i][0]+f[i][1]*g[-i][1]$
计算子树内到分治中心的路径数 $ans+=f[0][1]$
距离枚举范围用深度进行限制,记录$maxdepth$
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 using namespace std; 5 const int N=100010; 6 const int inf=0x3f3f3f3f; 7 typedef long long ll; 8 int n,rt,S,mf[N],size[N],x,y,z,dis[N],d[N],st[N<<1],f[N<<1][2],g[N<<1][2],md; 9 bool vis[N],use[N<<1]; 10 int p,head[N],to[N<<1],nxt[N<<1],w[N<<1]; 11 ll ans; 12 inline int read() { 13 int re=0; char ch=getchar(); 14 while (ch<'0'||ch>'9') ch=getchar(); 15 while (ch>='0'&&ch<='9') re=re*10+ch-48,ch=getchar(); 16 return re; 17 } 18 inline void add(int x,int y,int z) { 19 to[++p]=y; nxt[p]=head[x]; w[p]=z; head[x]=p; 20 to[++p]=x; nxt[p]=head[y]; w[p]=z; head[y]=p; 21 } 22 void get_rt(int x,int fa) { 23 size[x]=1; mf[x]=0; 24 for (int i=head[x]; i; i=nxt[i]) { 25 if (to[i]==fa || vis[to[i]]) continue; 26 get_rt(to[i],x); 27 size[x]+=size[to[i]]; 28 if (size[to[i]]>mf[x]) mf[x]=size[to[i]]; 29 } 30 mf[x]=max(mf[x],S-size[x]); 31 if (mf[x]<mf[rt]) rt=x; 32 } 33 void get_dis(int x,int fa,int dep) { 34 dis[++dis[0]]=d[x]; md=max(md,dep); 35 if (st[d[x]]) f[d[x]][1]++; else f[d[x]][0]++; 36 st[d[x]]++; 37 for (int i=head[x]; i; i=nxt[i]) { 38 if (to[i]==fa || vis[to[i]]) continue; 39 d[to[i]]=d[x]+w[i]; 40 get_dis(to[i],x,dep+1); 41 } 42 st[d[x]]--; 43 } 44 void dfs(int x) { 45 vis[x]=1; int maxd=0; 46 for (int i=head[x]; i; i=nxt[i]) { 47 if (vis[to[i]]) continue; 48 dis[0]=0; d[to[i]]=n+w[i]; 49 md=0; get_dis(to[i],x,1); maxd=max(maxd,md); 50 f[n][0]+=f[n][1]; ans+=f[n][1]; f[n][1]=0; 51 ans+=1ll*f[n][0]*g[n][0]; g[n][0]+=f[n][0],f[n][0]=0; 52 for (int i=n+md; i>n; i--) { 53 ans+=1ll*f[i][0]*g[(n<<1)-i][1]+1ll*f[i][1]*g[(n<<1)-i][0] 54 +1ll*f[i][1]*g[(n<<1)-i][1]; 55 ans+=1ll*g[i][0]*f[(n<<1)-i][1]+1ll*g[i][1]*f[(n<<1)-i][0] 56 +1ll*g[i][1]*f[(n<<1)-i][1]; 57 g[i][0]+=f[i][0]; g[i][1]+=f[i][1]; 58 g[(n<<1)-i][0]+=f[(n<<1)-i][0]; g[(n<<1)-i][1]+=f[(n<<1)-i][1]; 59 f[i][0]=f[i][1]=f[(n<<1)-i][0]=f[(n<<1)-i][1]=0; 60 } 61 } 62 g[n][0]=0; 63 for (int i=n+maxd; i>n; i--) g[i][0]=g[i][1]=g[(n<<1)-i][0]=g[(n<<1)-i][1]=0; 64 for (int i=head[x]; i; i=nxt[i]) { 65 if (vis[to[i]]) continue; 66 S=size[to[i]]; mf[rt=0]=inf; 67 get_rt(to[i],x); dfs(rt); 68 } 69 } 70 int main() { 71 n=read(); 72 for (int i=1; i<n; i++) x=read(),y=read(),z=read() ? 1:-1,add(x,y,z); 73 S=n; mf[rt=0]=inf; 74 get_rt(1,0); dfs(rt); 75 printf("%lld\n",ans); 76 return 0; 77 }