解题思路
树链剖分:先按重儿子优先的 DFS 序给节点重新编号,再用线段树维护该序列。线段树的每个节点同时记录区间和与区间最大值,路径查询分别由 qSum、qMax 两个函数完成。
代码
#include<algorithm>
#include<cctype>
#include<climits>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
using namespace std;
// Capacity used to size every fixed array in this file. Per-node arrays are
// indexed from 1, so at most MAXN-1 nodes fit; the edge arrays are sized
// MAXN<<1 and the segment-tree arrays MAXN<<2.
const int MAXN=30005;
// Reads the next decimal integer from stdin.
//
// Skips every character that is not a digit. The number is negated only when
// the character immediately before its first digit is '-'. Exactly one
// character after the last digit is consumed.
//
// Returns the parsed value, or 0 if end of input is reached before any digit
// (previously this case looped forever).
inline int rd(){
    // Keep getchar()'s result as int: EOF stays distinguishable from data and
    // std::isdigit never receives a negative char value.
    int ch=std::getchar();
    bool negative=false;
    while(ch!=EOF && !std::isdigit(ch)){
        negative=(ch=='-');
        ch=std::getchar();
    }
    int x=0;
    while(ch!=EOF && std::isdigit(ch)){
        x=x*10+(ch-'0');
        ch=std::getchar();
    }
    return negative?-x:x;
}
// ---- scalars ----
int n;                // number of tree nodes (read first in main)
int m;                // number of operations left to process
int num;              // last DFS-order index handed out by dfs2
int cnt;              // number of adjacency-list entries created by add

// ---- adjacency list (entry 0 acts as the end-of-list marker) ----
int head[MAXN];       // head[v]: most recently added entry leaving v
int to[MAXN<<1];      // to[e]: endpoint stored in entry e
int nxt[MAXN<<1];     // nxt[e]: next entry in the same list
int a[MAXN];          // a[v]: initial value of node v

// ---- per-node data filled by dfs1 / dfs2 ----
int fa[MAXN];         // parent
int siz[MAXN];        // subtree size
int son[MAXN];        // child with the largest subtree (0 when none)
int dep[MAXN];        // depth
int id[MAXN];         // DFS-order index of the node
int top[MAXN];        // first node of the chain containing the node
int w[MAXN];          // w[id[v]] = a[v]

// ---- segment tree over DFS-order positions 1..n ----
int sum[MAXN<<2];     // range sums
int Max[MAXN<<2];     // range maxima
int lazy[MAXN<<2];    // not referenced by the code visible in this file
// Prepends a directed entry bg -> ed to bg's adjacency list.
inline void add(int bg,int ed){
    ++cnt;                 // allocate the next entry; index 0 stays unused
    to[cnt]=ed;
    nxt[cnt]=head[bg];     // link in front of the current list
    head[bg]=cnt;
}
// First pass over the tree rooted at x.
//
// x: current node; f: its parent (skipped while iterating neighbours);
// d: depth assigned to x.
//
// Fills dep[x], fa[x], siz[x] (subtree size including x) and son[x], the
// child with the strictly largest subtree seen so far, so the first such
// child wins ties. son[x] is left untouched when x has no children.
void dfs1(int x,int f,int d){
    dep[x]=d,fa[x]=f,siz[x]=1;
    int maxson=-1;
    // `register` dropped: the storage class was removed in C++17.
    for(int i=head[x];i;i=nxt[i]){
        int u=to[i];if(u==f) continue;
        dfs1(u,x,d+1);
        siz[x]+=siz[u];
        if(siz[u]>maxson) {maxson=siz[u];son[x]=u;}
    }
}
// Second pass: assigns DFS-order indices and chain heads.
//
// x: current node; topf: first node of the chain x belongs to.
//
// Sets top[x], gives x the next index in id[x], and copies a[x] into
// w[id[x]]. The child in son[x] is visited first and inherits topf, so the
// nodes of one chain receive consecutive indices. Every other child starts
// a new chain with itself as head.
void dfs2(int x,int topf){
    top[x]=topf,id[x]=++num,w[num]=a[x];
    if(!son[x]) return;           // no recorded child: nothing below x
    dfs2(son[x],topf);
    // `register` dropped: the storage class was removed in C++17.
    for(int i=head[x];i;i=nxt[i]){
        int u=to[i];if(u==son[x] || u==fa[x]) continue;
        dfs2(u,u);
    }
}
//--------------------xds (segment tree)----------------------
// Recomputes node x's sum and maximum from its two children.
inline void pushup(int x){
    const int lc=x<<1;
    const int rc=lc|1;
    sum[x]=sum[lc]+sum[rc];
    Max[x]=std::max(Max[lc],Max[rc]);
}
// Builds the segment-tree node x covering positions [l,r] from w[].
void build(int x,int l,int r){
    if(l!=r){
        // Internal node: build both halves, then combine them.
        const int mid=(l+r)>>1;
        build(x<<1,l,mid);
        build(x<<1|1,mid+1,r);
        pushup(x);
        return;
    }
    // Leaf: sum and maximum are both the single stored value.
    const int value=w[l];
    sum[x]=value;
    Max[x]=value;
}
// Assigns k to every position in [L,R].
//
// x: current node, covering [l,r]; L,R: target range; k: new value.
//
// The assignment is applied at leaves only. The earlier version wrote
// sum[x]=k on any fully covered node, which is correct only for a range of
// length one and left the children stale. Point updates (L==R) behave exactly
// as before; a wider range now costs time proportional to its length.
void update(int x,int l,int r,int L,int R,int k){
    if(R<l || r<L) return;        // node does not touch the target range
    if(l==r){
        sum[x]=k;
        Max[x]=k;
        return;
    }
    const int mid=(l+r)>>1;
    if(L<=mid) update(x<<1,l,mid,L,R,k);
    if(R>mid) update(x<<1|1,mid+1,r,L,R,k);
    pushup(x);
}
// Returns the sum of positions [L,R], starting from node x covering [l,r].
int query_Sum(int x,int l,int r,int L,int R){
    if(L<=l && r<=R) return sum[x];   // node lies fully inside the query
    const int mid=(l+r)>>1;
    // Only descend into a half that the query reaches.
    const int leftPart=(L<=mid)?query_Sum(x<<1,l,mid,L,R):0;
    const int rightPart=(R>mid)?query_Sum(x<<1|1,mid+1,r,L,R):0;
    return leftPart+rightPart;
}
// Returns the maximum over positions [L,R], starting from node x covering
// [l,r].
//
// The neutral element is INT_MIN. The previous hard-coded -30001 clamped the
// result whenever stored values were smaller than that.
int query_Max(int x,int l,int r,int L,int R){
    if(L<=l && r<=R) return Max[x];   // node lies fully inside the query
    const int mid=(l+r)>>1;
    int ret=INT_MIN;
    if(L<=mid) ret=std::max(ret,query_Max(x<<1,l,mid,L,R));
    if(R>mid) ret=std::max(ret,query_Max(x<<1|1,mid+1,r,L,R));
    return ret;
}
//---------------------------------------------------------
// Returns the maximum node value on the tree path between x and y.
//
// While x and y are on different chains, the endpoint whose chain head is
// deeper is lifted: its chain segment [id[top[x]], id[x]] is queried and x
// moves to the parent of that chain head. Once both are on one chain, the
// remaining segment between them is queried.
//
// The accumulator starts at INT_MIN instead of the former -30001, so the
// result no longer depends on an assumed lower bound for node values.
int qMax(int x,int y){
    int ret=INT_MIN;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) std::swap(x,y);
        ret=std::max(ret,query_Max(1,1,n,id[top[x]],id[x]));
        x=fa[top[x]];
    }
    // Same chain: order the endpoints so the shallower one comes first.
    if(dep[x]>dep[y]) std::swap(x,y);
    ret=std::max(ret,query_Max(1,1,n,id[x],id[y]));
    return ret;
}
// Returns the sum of node values on the tree path between x and y.
int qSum(int x,int y){
    int total=0;
    // Lift the endpoint whose chain head is deeper, one chain per step.
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[x]]<dep[top[y]]) std::swap(x,y);
        total+=query_Sum(1,1,n,id[top[x]],id[x]);
    }
    // Both endpoints share a chain now; query from the shallower to the deeper.
    if(dep[x]>dep[y]) std::swap(x,y);
    total+=query_Sum(1,1,n,id[x],id[y]);
    return total;
}
int main(){
n=rd();
int x,y;char c[15];
for(register int i=1;i<n;i++){
x=rd(),y=rd();
add(x,y),add(y,x);
}
for(register int i=1;i<=n;i++) a[i]=rd();
dfs1(1,0,1),dfs2(1,1);
build(1,1,n);
m=rd();
while(m--){
scanf("%s",c+1);
if(c[2]=='M'){
x=rd(),y=rd();
printf("%d
",qMax(x,y));
}
else if(c[2]=='S'){
x=rd(),y=rd();
printf("%d
",qSum(x,y));
}
else {
x=rd(),y=rd();
update(1,1,n,id[x],id[x],y);
}
}
return 0;
}