Query on the tree
时间限制:1s 内存限制: 65536K
问题描述
度度熊最近沉迷在和树有关的游戏了,他一直认为树是最神奇的数据结构。一天他遇到这样一个问题:
有一棵树,树的每个点有点权,每次有三种操作:
1. Query x 表示查询以x为根的子树的权值和。
2. Change x y 表示把x点的权值改为y。
3. Root x 表示把x变为根。
现在度度熊想请更聪明的你帮助解决这个问题。
输入
第一行为数据组数T
每组数据第一行为N ,表示树的节点数。
后面 行每行有两个数 ,表示 之间有一条边 。初始时树是以1号节点为根节点。
之后的一行为 个数表示这 个点的点权。
然后为整数Q为操作次数。
之后的Q行为描述中的三种操作。
输出
对于第k组输入数据,第一行输出Case #k接下来对于每个”Queryx”操作,输出以x为根的子数和。
样例输入
2
5
1 2
1 3
3 4
3 5
1 2 3 4 5
5
Query 1
Change 3 10
Query 1
Root 4
Query 3
8
1 2
1 3
3 4
4 5
5 6
5 7
4 8
1 2 3 4 5 6 7 8
5
Query 1
Query 3
Root 5
Query 3
Query
1样例输出
Case #1:
15
22
18
Case #2:
36
33
6
3
解题报告:
树上的查询一共有三种操作,如果只是考虑前两种操作Query和Change,则需要一种高效的数据结构来支持子树和的查询和更新。
采取类似于LCA在线算法的方式,先DFS得到树上顶点序列,通过顶点序列可以知道每颗子树的范围。利用树状数组的方式来记录每个顶点的权值的变更。那么Query和Change的时间复杂度都为O(lgn)
当引入第三种操作-变更根节点后,并不需要调整树的结构,而只是要在查询操作的时候做些处理。
如果Query x的x为当前root节点,则直接输出当前所有节点的权值和,可以用一个变量SumOfAllTree来记录整颗树所有节点的权值和,查询为O(1)的复杂度
如果x在原树上不为当前root节点的祖先,即lca(x,root)
!= x,那么直接输出x节点所在的子树和SubTree(x)。
否则可找出root到x这条路径上x的儿子节点y,那么在当前root的条件下x对应的子树的和为SumOfAllTree-SubTree(y)。
求lca和求y的时间复杂度都可以做到O(lgn)。因此加上预处理后整体的时间复杂度为O(nlgn
+ Qlgn)
解题代码:
#include "iostream" #include "cstring" #include "cstdio" #include "vector" #define F first #define S second #define PB push_back #define MP make_pair using namespace std; const int N = 10010; const int D = 20; vector<int>e[N]; int go[N][D],depth[N],l[N],r[N]; int time_stamp; void dfs(int u,int p) { depth[u]=p==-1?0:depth[p]+1; go[u][0]=p; l[u]=++time_stamp; for(int i=0;go[u][i]!=-1;i++){ go[u][i+1]=go[go[u][i]][i]; } for(int i=0;i<e[u].size();i++){ int v=e[u][i]; if(v!=p){ dfs(v,u); } } r[u]=time_stamp; } int jump(int u,int d) { for(int i=D-1;i>=0;i--){ if(d>=(1<<i)){ u=go[u][i]; d-=1<<i; } } return u; } int lca(int u,int v) { if(depth[u]<depth[v]){ swap(u,v); } u=jump(u,depth[u]-depth[v]); for(int i=D-1;i>=0;i--){ if(go[u][i]!=go[v][i]){ u=go[u][i]; v=go[v][i]; } } return u==v?u:go[u][0]; } int lowbit(int x) { return x&(-x); } char com[20]; int n; int val[N],sumroot[N],sumroad[N]; void init(int n) { time_stamp=0; memset(sumroot,0,sizeof(sumroot)); memset(sumroad,0,sizeof(sumroad)); memset(go, -1, sizeof(go)); for(int i=1;i<=n;i++){ e[i].clear(); } } void update(int a[],int x,int v) { while(x<=n){ a[x]+=v; x+=lowbit(x); } } int get(int a[],int x) { int sum=0; while(x>0){ sum+=a[x]; x-=lowbit(x); } return sum; } int getsum(int a[],int l,int r) { return get(a,r)-get(a,l-1); } int main() { int KK = 1; int T,Q,x,y,root,allval; scanf("%d",&T); while(T--){ scanf("%d",&n); init(n); root=1; allval=0; for(int i=1;i<n;i++){ scanf("%d%d",&x,&y); e[x].PB(y); e[y].PB(x); } dfs(1,-1); for(int i=1;i<=n;i++){ scanf("%d",&val[i]); update(sumroad,l[i],val[i]); update(sumroad,r[i]+1,-val[i]); update(sumroot,l[i],val[i]); allval+=val[i]; } printf("Case #%d: ", KK++); scanf("%d",&Q); while(Q--){ scanf("%s",com); if(com[0]=='Q'){ scanf("%d",&x); if(x==root){ printf("%d ",allval); }else if(lca(x,root)!=x){ printf("%d ",getsum(sumroot,l[x],r[x])); }else{ int tmp=jump(root,depth[root]-depth[x]-1); printf("%d ",allval-getsum(sumroot,l[tmp],r[tmp])); } }else if(com[0]=='C'){ scanf("%d%d",&x,&y); update(sumroad,l[x],y-val[x]); update(sumroad,r[x]+1,val[x]-y); update(sumroot,l[x],y-val[x]); allval+=y-val[x]; val[x]=y; }else{ scanf("%d",&root); } } } return 0; }