Description
这是一道模板题。
给定一棵 n 个节点的树,初始时该树的根为 1 号节点,每个节点有一个给定的权值。下面依次进行 m 个操作,操作分为如下五种类型:
换根:将一个指定的节点设置为树的新根。
修改路径权值:给定两个节点,将这两个节点间路径上的所有节点权值(含这两个节点)增加一个给定的值。
修改子树权值:给定一个节点,将以该节点为根的子树内的所有节点权值增加一个给定的值。
询问路径:询问某条路径上节点的权值和。
询问子树:询问某个子树内节点的权值和。
Input
第一行为一个整数 nnn,表示节点的个数。
第二行 nnn 个整数表示第 iii 个节点的初始权值 ai 。
第三行 n−1 个整数,表示 i+1号节点的父节点编号 fi+1 (1⩽fi+1⩽n)。
第四行一个整数 m,表示操作个数。
接下来 m 行,每行第一个整数表示操作类型编号:(1⩽u,v⩽n)
若类型为 1,则接下来一个整数 u,表示新根的编号。
若类型为 2,则接下来三个整数 u,v,k,分别表示路径两端的节点编号以及增加的权值。
若类型为 3,则接下来两个整数 u,k,分别表示子树根节点编号以及增加的权值。
若类型为 4,则接下来两个整数 u,v,表示路径两端的节点编号。
若类型为 5,则接下来一个整数 u,表示子树根节点编号。
Output
对于每一个类型为 4 或 5 的操作,输出一行一个整数表示答案。
思路
- 重构两遍40分,后发现getson函数写错,正确:
if(fa[top[u]]==v) return top[u];
错误:if(fa[u]==v) return u;
- 树剖换根相关: root(当前根),x(询问),lca最近公共祖先,以询问总和sum为例:
- 1. root==x:sum=整棵树之和;
- 2. root在x的子树中(lca(root,x)==x):sum=整棵树之和-以(x~root路径上靠近x的点)为根的子树;
- 3. root不在x的子树中(lca(root,x)!=x):sum=以1为根(原树)时x的子树;
- 以1为根时x的子树在线段树中为连续的一段区间(num[x],num[x]+siz[x]-1);
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#define int long long
#define maxn 100005
using namespace std;
int n,m,cnt,head[maxn],val[maxn],root;
int fa[maxn],top[maxn],deep[maxn],siz[maxn],son[maxn],num[maxn],fnum[maxn];
struct node{int next,to;}e[maxn<<1];
struct fdfdfd{int l,r,flag,sum,len;}a[maxn<<2];
void addedge(int x,int y){e[++cnt].to=y; e[cnt].next=head[x]; head[x]=cnt;}
void dfs_1(int u)
{
deep[u]=deep[fa[u]]+1; siz[u]=1;
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to; dfs_1(v); siz[u]+=siz[v];
if(son[u]==-1||siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs_2(int u,int topp)
{
top[u]=topp; num[u]=++cnt; fnum[cnt]=u;
if(son[u]!=-1) dfs_2(son[u],topp);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v!=son[u]) dfs_2(v,v);
}
}
void pushup(int x){a[x].sum=a[x<<1].sum+a[x<<1|1].sum;}
void pushdown(int x)
{
if(a[x].flag==0) return;
a[x<<1].flag+=a[x].flag; a[x<<1].sum+=a[x<<1].len*a[x].flag;
a[x<<1|1].flag+=a[x].flag; a[x<<1|1].sum+=a[x<<1|1].len*a[x].flag;
a[x].flag=0;
}
void build(int x,int left,int right)
{
a[x].l=left; a[x].r=right; a[x].len=right-left+1;;
if(left==right) {a[x].sum=val[fnum[left]]; return;}
int mid=(left+right)>>1;
build(x<<1,left,mid); build(x<<1|1,mid+1,right);
pushup(x);
}
void modify(int x,int left,int right,int d)
{
if(a[x].r<left||a[x].l>right) return;
if(left<=a[x].l&&right>=a[x].r) {a[x].flag+=d,a[x].sum+=a[x].len*d; return;}
pushdown(x);
modify(x<<1,left,right,d); modify(x<<1|1,left,right,d);
pushup(x);
}
void change_uv(int u,int v,int k)
{
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
modify(1,num[top[u]],num[u],k);
u=fa[top[u]];
}
if(deep[u]>deep[v]) swap(u,v);
modify(1,num[u],num[v],k);
}
int query(int x,int left,int right)
{
if(a[x].r<left||a[x].l>right) return 0;
if(left<=a[x].l&&right>=a[x].r) return a[x].sum;
pushdown(x);
return query(x<<1,left,right)+query(x<<1|1,left,right);
}
int ask_uv(int u,int v)
{
int ans=0;
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
ans+=query(1,num[top[u]],num[u]);
u=fa[top[u]];
}
if(deep[u]>deep[v]) swap(u,v);
ans+=query(1,num[u],num[v]);
return ans;
}
int getlca(int u,int v)
{
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
u=fa[top[u]];
}
return deep[u]<deep[v]?u:v;
}
int getson(int u,int v)
{
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
if(fa[top[u]]==v) return top[u];
u=fa[top[u]];
}
return deep[u]<deep[v]?son[u]:son[v];
}
void change_u(int u,int k)
{
int lca=getlca(u,root),son;
if(u==root) return modify(1,1,n,k);
if(u!=lca) return modify(1,num[u],num[u]+siz[u]-1,k);
if(u==lca) modify(1,1,n,k),son=getson(u,root),modify(1,num[son],num[son]+siz[son]-1,-k);
}
int ask_u(int u)
{
int lca=getlca(u,root),son,ans=0;
if(u==root) return query(1,1,n);
if(u!=lca) return query(1,num[u],num[u]+siz[u]-1);
if(u==lca) ans=query(1,1,n),son=getson(u,root),ans-=query(1,num[son],num[son]+siz[son]-1);
return ans;
}
signed main()
{
memset(son,-1,sizeof(son));
scanf("%lld",&n);
for(int i=1;i<=n;++i) scanf("%lld",&val[i]);
for(int i=1,u;i<n;++i) scanf("%lld",&u),fa[i+1]=u,addedge(u,i+1);
dfs_1(1); cnt=0; dfs_2(1,1); build(1,1,n); root=1;
scanf("%lld",&m);
while(m--)
{
int op,u,v,k; scanf("%lld%lld",&op,&u);
if(op==1) root=u;
else if(op==2) scanf("%lld%lld",&v,&k),change_uv(u,v,k);
else if(op==3) scanf("%lld",&k),change_u(u,k);
else if(op==4) scanf("%lld",&v),printf("%lld
",ask_uv(u,v));
else printf("%lld
",ask_u(u));
}
return 0;
}