题目
Description
一棵树上有 n 个节点,编号分别为 1 到 n,每个节点都有一个权值 w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t
:把结点 u 的权值改为 t。
II. QMAX u v
: 询问从点 u 到点 v 的路径上的节点的最大权值。
III. QSUM u v
: 询问从点 u 到点 v 的路径上的节点的权值和。
注意:从点 u 到点 v 的路径上的节点包括 u 和 v 本身。
Input
输入文件的第一行为一个整数 n,表示节点的个数。
接下来 n-1 行,每行 2 个整数 a 和 b,表示节点 aa 和节点 bb 之间有一条边相连。
接下来一行 n 个整数,第 i 个整数 wi 表示节点 i 的权值。
接下来 1 行,为一个整数 q,表示操作的总数。
接下来 q 行,每行一个操作,以 CHANGE u t
或者 QMAX u v
或者 QSUM u v
的形式给出。
Output
对于每个 QMAX
或者 QSUM
的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4 1 2 2 3 4 1 4 2 1 3 12 QMAX 3 4 QMAX 3 3 QMAX 3 2 QMAX 2 3 QSUM 3 4 QSUM 2 1 CHANGE 1 5 QMAX 3 4 CHANGE 3 6 QMAX 3 4 QMAX 2 4 QSUM 3 4
Sample Output
4 1 2 2 10 6 5 6 5 16
思路
这是一道对于刚学树链剖分的读者们,的很好的锻炼题;
有些不懂的读者可以看看
没什么思路技巧,就是将树分成若干条链,然后将这些链储存在线段树中;
首先要将树,分成重边和轻边,再将连续的重边组成重链,连续的轻边组成轻链;
再将这些链上的点,附上不同的编号,放在线段树中;
那么还是具体上代码吧,代码注释会将得很清楚的;
代码
#pragma GCC optimize(3) #pragma GCC target("avx") #pragma GCC optimize("Ofast") #pragma GCC optimize("inline") #pragma GCC optimize("-fgcse") #pragma GCC optimize("-fgcse-lm") #pragma GCC optimize("-fipa-sra") #pragma GCC optimize("-ftree-pre") #pragma GCC optimize("-ftree-vrp") #pragma GCC optimize("-fpeephole2") #pragma GCC optimize("-ffast-math") #pragma GCC optimize("-fsched-spec") #pragma GCC optimize("unroll-loops") #pragma GCC optimize("-falign-jumps") #pragma GCC optimize("-falign-loops") #pragma GCC optimize("-falign-labels") #pragma GCC optimize("-fdevirtualize") #pragma GCC optimize("-fcaller-saves") #pragma GCC optimize("-fcrossjumping") #pragma GCC optimize("-fthread-jumps") #pragma GCC optimize("-funroll-loops") #pragma GCC optimize("-fwhole-program") #pragma GCC optimize("-freorder-blocks") #pragma GCC optimize("-fschedule-insns") #pragma GCC optimize("inline-functions") #pragma GCC optimize("-ftree-tail-merge") #pragma GCC optimize("-fschedule-insns2") #pragma GCC optimize("-fstrict-aliasing") #pragma GCC optimize("-fstrict-overflow") #pragma GCC optimize("-falign-functions") #pragma GCC optimize("-fcse-skip-blocks") #pragma GCC optimize("-fcse-follow-jumps") #pragma GCC optimize("-fsched-interblock") #pragma GCC optimize("-fpartial-inlining") #pragma GCC optimize("no-stack-protector") #pragma GCC optimize("-freorder-functions") #pragma GCC optimize("-findirect-inlining") #pragma GCC optimize("-fhoist-adjacent-loads") #pragma GCC optimize("-frerun-cse-after-loop") #pragma GCC optimize("inline-small-functions") #pragma GCC optimize("-finline-small-functions") #pragma GCC optimize("-ftree-switch-conversion") #pragma GCC optimize("-foptimize-sibling-calls") #pragma GCC optimize("-fexpensive-optimizations") #pragma GCC optimize("-funsafe-loop-optimizations") #pragma GCC optimize("inline-functions-called-once") #pragma GCC optimize("-fdelete-null-pointer-checks") #pragma GCC optimize(2) //为什么要加这么一大堆,我只是为了让看起来冗长的代码,变得更长 #include<bits/stdc++.h>//头文件 #define re register//宏定义 typedef long long ll; using namespace std; inline ll read()//快读 { ll a=0,f=1; char c=getchar();//a 是数字大小,f 是判正负 //???为什么快读也要写注释 while (c<'0'||c>'9') {if (c=='-') f=-1; c=getchar();} while (c>='0'&&c<='9') {a=a*10+c-'0'; c=getchar();} return a*f; } ll n,m; ll w[200010];//w 记录权值 ll head[200010]; ll size[200010],dep[200010],top[200010]; //size 记录子树 节点个数 ,dep 记录深度, top 记录这条链的顶部 ll id[200010],aa[200010];//id 是在线段树中的编号,aa 是在线段树中的权值 ll f[200010],son[200010];// f 是父节点,son 是重子节点 struct ljj { ll to,stb; }e[200010];//to 表示这条边到达的点,stb 表示上一条边 struct ljq { ll l,r,mx,v; }a[200010];//线段树基本变量 inline ll L(ll x) { return 2*x; }//线段树中左儿子的编号 inline ll R(ll x) { return 2*x+1; }//线段树中右儿子的编号 ll s=0; inline void insert(ll x,ll y) { s++; e[s].stb=head[x]; e[s].to=y; head[x]=s; }//前向星连边 inline void dfs(ll x,ll fa)//找重子节点 { size[x]=1; f[x]=fa;//记录父节点 for(re ll i=head[x];i;i=e[i].stb) { ll xx=e[i].to; if(xx==fa)//不能遍历到父节点 continue; dep[xx]=dep[x]+1;//统计深度 dfs(xx,x); size[x]+=size[xx];//统计子树节点数 if(!son[x]||size[xx]>size[son[x]]) son[x]=xx;//找重子节点,也就是子树节点数最多的子节点 } } ll tot=0;//统计在线段树中的编号 inline void DFS(ll x,ll t)//t 表示这条链的顶部 { top[x]=t;//记录 id[x]=++tot;//记录在线段树中的编号 aa[tot]=w[x];//记录在线段树中的权值 if(!son[x])//如果没有重子节点 return;//返回 DFS(son[x],t);//先遍历重子节点 for(re ll i=head[x];i;i=e[i].stb) { ll xx=e[i].to; if(xx==f[x]||xx==son[x])//遍历轻子节点 continue; DFS(xx,xx);//每个开始的轻子节点的链顶就是自己 } } inline void doit(ll p)//维护区间 { a[p].v=a[L(p)].v+a[R(p)].v;//sum和 a[p].mx=max(a[L(p)].mx,a[R(p)].mx);//最大值 } inline void build(ll p,ll l,ll r)//建树 { a[p].l=l; a[p].r=r; if(l==r) { a[p].v=aa[l]; a[p].mx=aa[l]; return; } ll mid=(l+r)>>1; build(L(p),l,mid); build(R(p),mid+1,r); doit(p); } inline void change(ll p,ll x,ll y)//单点修改 { if(a[p].l==a[p].r) { a[p].v=y; a[p].mx=y; return; } ll mid=(a[p].l+a[p].r)>>1; if(x<=mid) change(L(p),x,y); else change(R(p),x,y); doit(p); } inline ll findsum(ll p,ll l,ll r)//找区间sum { if(l<=a[p].l&&a[p].r<=r) return a[p].v; ll sum=0; ll mid=(a[p].l+a[p].r)>>1; if(l<=mid) sum+=findsum(L(p),l,r); if(r>mid) sum+=findsum(R(p),l,r); return sum; } inline ll qsum(ll x,ll xx) { ll sum=0; while(top[x]!=top[xx])//我们需要是 x 节点跳到与 xx 节点在同一条链上 { if(dep[top[x]]<dep[top[xx]])//深度大的往上跳 swap(x,xx); sum+=findsum(1,id[top[x]],id[x]);//统计 x 到链顶的 sum x=f[top[x]];// 跳到下一个区间,也就是在 top[x] 上面的链 } if(dep[x]<dep[xx]) swap(x,xx); sum+=findsum(1,id[xx],id[x]);//在统计下 x 到 xx 的区间sum //此时 x 与 xx 是在同一条链上 return sum; } inline ll findmax(ll p,ll l,ll r)//区间最大值 { if(l<=a[p].l&&a[p].r<=r) return a[p].mx; ll sum=-(1<<30); ll mid=(a[p].l+a[p].r)>>1; if(l<=mid) sum=max(sum,findmax(L(p),l,r)); if(r>mid) sum=max(sum,findmax(R(p),l,r)); return sum; } inline ll qmax(ll x,ll xx) { ll sum=-(1<<30); while(top[x]!=top[xx])//我们需要是 x 节点跳到与 xx 节点在同一条链上 { if(dep[top[x]]<dep[top[xx]])//深度大的往上跳 swap(x,xx); sum=max(sum,findmax(1,id[top[x]],id[x]));//统计 x 到链顶的 最大值 x=f[top[x]];// 跳到下一个区间,也就是在 top[x] 上面的链 } if(dep[x]<dep[xx]) swap(x,xx); sum=max(sum,findmax(1,id[xx],id[x]));//在统计下 x 到 xx 的区间最大值 //此时 x 与 xx 是在同一条链上 return sum; } int main() { n=read();//读入 for(re ll i=1;i<n;i++) { ll x=read(),y=read(); insert(x,y); insert(y,x);//连边 } for(re ll i=1;i<=n;i++) w[i]=read();//读入 dfs(1,0);//找重子节点 DFS(1,1);//分成链 build(1,1,n);//建树 m=read(); for(re ll i=1;i<=m;i++) { char c[5]; scanf("%s",c); if(c[1]=='H') { ll x=read(),y=read(); change(1,id[x],y);//单点修改 } else if(c[1]=='S') { ll x=read(),y=read(); ll ans=qsum(x,y);//求区间sum printf("%lld ",ans); } else { ll x=read(),y=read(); ll ans=qmax(x,y);//求区间最大值 printf("%lld ",ans); } } //return 0; }