[BZOJ3697]采药人的路径(点分治+树形dp)
题面
采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。
分析
显然可以点分治。设阴性和阳性的边权分别为1和-1,那么平衡就是路径长度为0
考虑点分治的时候dfs每一颗分治子树,用(mark_i)来标记当前节点到分治中心的路径上距离的出现情况,如果(mark_{deep_x}=1),就说明(x)与某个祖先到中心出现过一个相同的距离,也就是说这一段上一定存在休息点。
然后考虑合并分治中心出来的若干子树。设(g_{0/1,i})表示当前子树中子树中距离为i的路径数目,0/1表示是否有休息点。合并的时候类似树形dp,为了防止一个子树内重复合并,要再定义一个数组(f_{0/1,i})表示当前已经合并过子树中距离为i的路径数目.每新来一个子树,先dfs一遍算出g,再把g和f合并。
每棵子树对答案的贡献是
[egin{aligned}&g_{0,0} imes (f_{0,0}-1) ext{(两条路径都平衡,分治中心是休息站,-1是去掉路径经过点数为0的情况)} \ &+ sum_{i}g_{1,-i} imes f_{1,i}+ g_{0,-i} imes f_{1,i}+ g_{1,-i} imes f_{0,i} ext{两条路径和为0,平衡,且至少一条路径上存在休息站}end{aligned}
]
实现上要注意负下标的问题
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 200000
using namespace std;
typedef long long ll;
int n;
struct edge{
int from;
int to;
int len;//两种边,距离分别为-1和1
int next;
}E[maxn*2+5];
int esz=1;
int head[maxn*2+5];
void add_edge(int u,int v,int w){
esz++;
E[esz].from=u;
E[esz].to=v;
E[esz].len=w;
E[esz].next=head[u];
head[u]=esz;
}
bool vis[maxn+5];
int sz[maxn+5];
int ms[maxn+5];
void dfs_root(int x,int fa,int tot_sz,int &root){
sz[x]=1;
ms[x]=0;
for(int i=head[x];i;i=E[i].next){
int y=E[i].to;
if(y!=fa&&!vis[y]){
dfs_root(y,x,tot_sz,root);
sz[x]+=sz[y];
ms[x]=max(ms[x],sz[y]);
}
}
ms[x]=max(ms[x],tot_sz-sz[x]);
if(ms[x]<ms[root]) root=x;
}
int get_root(int x,int tot_sz){
int root=0;
ms[0]=tot_sz+1;
dfs_root(x,0,tot_sz,root);
return root;
}
struct marray{//支持负数下标的数组
ll cnt[maxn*2+5];
inline ll & operator [] (int id){
return cnt[id+maxn];
}
};
marray f[2];//当前已经合并的子树中距离为i的路径数目,0/1表示是否有休息点
marray g[2];//当前子树中子树中距离为i的路径数目
marray mark;//标记某个距离是否出现
int maxd=0;
void get_dist(int x,int fa,int dist){
maxd=max(maxd,abs(dist));
if(mark[dist]) g[1][dist]++;//祖先中有距离为dist的点,那么当前节点与那个祖先之间必有一个休息点
else g[0][dist]++;
mark[dist]++;
for(int i=head[x];i;i=E[i].next){
int y=E[i].to;
if(y!=fa&&!vis[y]){
get_dist(y,x,dist+E[i].len);
}
}
mark[dist]--;
}
ll ans=0;
void solve(int x){
vis[x]=1;
f[0][0]=1;
int mmaxd=0;
for(int p=head[x];p;p=E[p].next){
int y=E[p].to;
if(!vis[y]){
maxd=0;
get_dist(y,x,E[p].len);
// mmind=min(mind,mmind);
mmaxd=max(maxd,mmaxd);
ans+=g[0][0]*(f[0][0]-1);//两边距离为0,说明两条路径都平衡,休息点在x
//-1是去掉路径经过点数为0的情况,即一开始的f[0][0]=1
for(int i=-mmaxd;i<=maxd;i++) ans+=g[1][-i]*f[1][i]+g[0][-i]*f[1][i]+g[1][-i]*f[0][i];
//两边距离总和为0,整条路径平衡,且至少一条路径上有休息点
for(int i=-maxd;i<=maxd;i++){//合并子树
f[0][i]+=g[0][i];
f[1][i]+=g[1][i];
g[0][i]=g[1][i]=0;
}
}
}
for(int i=-mmaxd;i<=mmaxd;i++) f[0][i]=f[1][i]=0;
for(int i=head[x];i;i=E[i].next){
int y=E[i].to;
if(!vis[y]) solve(get_root(y,sz[y]));
}
}
int main(){
int u,v,w;
scanf("%d",&n);
for(int i=1;i<n;i++){
scanf("%d %d %d",&u,&v,&w);
if(w==1){
add_edge(u,v,1);
add_edge(v,u,1);
}else{
add_edge(u,v,-1);
add_edge(v,u,-1);
}
}
solve(get_root(1,n));
printf("%lld
",ans);
}