Description
采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。
Input
第1行包含一个整数N。
接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。
Output
输出符合采药人要求的路径数目。
Sample Input
7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1
Sample Output
1
HINT
对于100%的数据,N ≤ 100,000。
这题就恶心了呀
还是点分治
考虑经过点x的路径,f[i][0/1]表示当前子树到根的路径为i,存在/不存在休息点的方案数,g[i][0/1]表示前几棵子树到根的路径为i,存在/不存在休息点的方案数
那么对于一个子树,它对答案的贡献是f[0][0]*g[0][0]+Σf[i][0]*g[-i][1]+f[i][1]*g[-i][0]+f[i][1]*g[-i][1]
第一维搞去掉负数调了半天……
#include<cstdio> #include<iostream> #define N 200010 #define LL long long using namespace std; inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } struct edge{int to,next,v;}e[2*N]; int head[N],son[N],f[N],mrk[N],dep[N]; LL s[2*N][2],t[2*N][2],dis[N]; int n,cnt,root,sum,mxd; bool vis[N]; LL ans; inline void ins(int u,int v,int w) { e[++cnt].to=v; e[cnt].v=w; e[cnt].next=head[u]; head[u]=cnt; } inline void insert(int u,int v,int w) { ins(u,v,w); ins(v,u,w); } inline void getroot(int x,int fa) { son[x]=1;f[x]=0; for (int i=head[x];i;i=e[i].next) if (!vis[e[i].to]&&fa!=e[i].to) { getroot(e[i].to,x); son[x]+=son[e[i].to]; f[x]=max(f[x],son[e[i].to]); } f[x]=max(f[x],sum-son[x]); if (f[x]<f[root])root=x; } inline void dfs(int x,int fa) { mxd=max(mxd,dep[x]); if (mrk[dis[x]])s[dis[x]][1]++; else s[dis[x]][0]++; mrk[dis[x]]++; for(int i=head[x];i;i=e[i].next) if (!vis[e[i].to]&&fa!=e[i].to) { dep[e[i].to]=dep[x]+1; dis[e[i].to]=dis[x]+e[i].v; dfs(e[i].to,x); } mrk[dis[x]]--; } inline void calc(int x) { int mx=0; t[n][0]=1; for (int i=head[x];i;i=e[i].next) if (!vis[e[i].to]) { dis[e[i].to]=n+e[i].v; dep[e[i].to]=1; mxd=1; dfs(e[i].to,0); mx=max(mx,mxd); ans+=(t[n][0]-1)*s[n][0]; for (int j=-mxd;j<=mxd;j++) ans+=t[n-j][1]*s[n+j][1]+t[n-j][0]*s[n+j][1]+t[n-j][1]*s[n+j][0]; for (int j=n-mxd;j<=n+mxd;j++) { t[j][0]+=s[j][0]; t[j][1]+=s[j][1]; s[j][0]=s[j][1]=0; } } for (int i=n-mx;i<=n+mx;i++) t[i][0]=t[i][1]=0; } inline void solve(int x) { vis[x]=1;calc(x); for (int i=head[x];i;i=e[i].next) if (!vis[e[i].to]) { sum=son[e[i].to]; root=0; getroot(e[i].to,0); solve(root); } } int main() { n=read(); for (int i=1;i<n;i++) { int x=read(),y=read(),z=read(); if (!z)z--; insert(x,y,z); } f[0]=n+1;sum=n; getroot(1,0); solve(root); printf("%lld ",ans); return 0; }