题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入输出格式
输入格式:
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
输出格式:
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
输入输出样例
说明
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Solution:
树剖的模板题。。。
总结一波错误:若比较函数$Max$是宏定义,且比较的两个变量中含有函数,那么不要用宏定义的$Max$,因为这样函数会运行两次,白白的浪费时间。
代码:
#include<bits/stdc++.h> #define il inline #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 #define For(i,a,b) for(int (i)=(a);(i)<=(b);(i)++) #define Swap(a,b) ((a)^=(b),(b)^=(a),(a)^=(b)) using namespace std; const int N=100005,inf=233333333; int n,q,cnt,h[N],a[N]; struct node{ int to,net,w; }e[N]; int size[N],wson[N],fa[N],dep[N],top[N],pos[N],pre[N],tot; il int gi(){ int a=0;char x=getchar();bool f=0; while((x<'0'||x>'9')&&x!='-')x=getchar(); if(x=='-')x=getchar(),f=1; while(x>='0'&&x<='9')a=(a<<3)+(a<<1)+x-48,x=getchar(); return f?-a:a; } il void add(int u,int v){ e[++cnt].to=v,e[cnt].net=h[u],h[u]=cnt; e[++cnt].to=u,e[cnt].net=h[v],h[v]=cnt; } il void dfs1(int u,int f){ size[u]=1; for(int i=h[u];i;i=e[i].net){ int v=e[i].to; if(v==f)continue; dep[v]=dep[u]+1;fa[v]=u; dfs1(v,u); size[u]+=size[v]; if(size[v]>size[wson[u]])wson[u]=v; } } il void dfs2(int u,int op){ pos[u]=++tot;pre[tot]=u;top[u]=op; if(wson[u])dfs2(wson[u],op); for(int i=h[u];i;i=e[i].net){ int v=e[i].to; if(v==fa[u]||v==wson[u])continue; dfs2(v,v); } } int sum[N<<2],maxn[N<<2]; il void pushup(int rt){ sum[rt]=sum[rt<<1]+sum[rt<<1|1]; maxn[rt]=max(maxn[rt<<1],maxn[rt<<1|1]); } il void build(int l,int r,int rt){ if(l==r){sum[rt]=maxn[rt]=a[pre[l]];return;} int m=l+r>>1; build(lson),build(rson); pushup(rt); } il void update(int k,int v,int l,int r,int rt){ if(l==r){sum[rt]=maxn[rt]=v;return;} int m=l+r>>1; if(k<=m)update(k,v,lson); else update(k,v,rson); pushup(rt); } il int query1(int L,int R,int l,int r,int rt){ if(L<=l&&R>=r)return sum[rt]; int m=l+r>>1,ret=0; if(L<=m)ret+=query1(L,R,lson); if(R>m)ret+=query1(L,R,rson); return ret; } il int query2(int L,int R,int l,int r,int rt){ if(L<=l&&R>=r)return maxn[rt]; int m=l+r>>1,tmp=-inf; if(L<=m)tmp=max(tmp,query2(L,R,lson)); if(R>m)tmp=max(tmp,query2(L,R,rson)); return tmp; } il int getsum(int u,int v){ int ans=0; while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]])Swap(u,v); ans+=query1(pos[top[u]],pos[u],1,n,1); u=fa[top[u]]; } if(dep[u]<dep[v])Swap(u,v); ans+=query1(pos[v],pos[u],1,n,1); return ans; } il int getmax(int u,int v){ int tmp=-inf; while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]])Swap(u,v); tmp=max(tmp,query2(pos[top[u]],pos[u],1,n,1)); u=fa[top[u]]; } if(dep[u]<dep[v])Swap(u,v); tmp=max(tmp,query2(pos[v],pos[u],1,n,1)); return tmp; } int main(){ n=gi(); int u,v;char s[10]; For(i,1,n-1)u=gi(),v=gi(),add(u,v); For(i,1,n)a[i]=gi(); dep[1]=1,fa[1]=1; dfs1(1,-1);dfs2(1,1); build(1,n,1); q=gi(); while(q--){ scanf("%s",s),u=gi(),v=gi(); if(s[1]=='H')update(pos[u],v,1,n,1); if(s[1]=='M')printf("%d ",getmax(u,v)); if(s[1]=='S')printf("%d ",getsum(u,v)); } return 0; }