【问题描述】
给一包含 N 个节点的树,以 1 号点为根。 求 u 号点的子树上,在子树中深度为 c 的点的权值和(u 号点深度为 0)。 包含权值修改操作。
【输入格式】
输入共 N+Q+2 行。
第 1 行包含 1 个正整数 N,表示共有 N 个节点。
第 2 行包含 N 个由空格隔开的非负的整数 V1, V2, V3 ... VN,其中 Vi 表示编号为 i 的点的权值。
第 2 +(1) 至 2 +(N-1) 行,每行包含 2 个正整数 fr,to,表示一边连接 fr,to 两点。
第 N+2 行包含 1 个正整数 Q,表示接下来共有 Q 个 修改 / 询问。
第 N+2 +(1) 至 N+2 +(Q) 行,每行包含 1 个 修改 / 询问。
修改 格式:
1 u v
表示将 u 号点的权值改为 v。
询问 格式:
2 u c 求 u 号点的子树上,在子树中深度为 c 的点的权值和(u 号点深度为 0)。
【输出格式】
输出多行(<= Q)。
每行包含 1 个非负的整数,表示其对应 询问 的答案,按照输入的先后顺序依次回答。
【样例输入】
7 1 1 1 1 1 1 1 1 2 1 3 2 7 3 4 3 6 4 5 7 2 1 1 2 1 2 2 3 6 1 4 5 2 4 0 2 1 2 2 3 1
【样例输出】
2 3 0 5 7 6
【数据规模与约定】
对于全部测试点:
编号为 1 的点是根; fa 是 to 的父亲节点; to > fa; to - fa <= 50
0 <= Vi <= 3000; 0 <= v <= 3000
c < N
题解:
这个题目先将我的做法吧!
我们对于每个点求出他的dfn,deep,和子树内节点的最大的dfn(记为ed)。
我们对于每个点,将他们按照dep作为第一关键字,dfn作为第二关键子排序,对于deep相同的点,他们在线段树中的位置一定连续,并且满足dfn递增的,我们把这些节点按顺序插到线段树之中去,那么我们要统计的一定是连续的一段区间,所以对于一个点我们要判断他十分在子树内只需要判断他的dfn十分>=dfn[x]&&<=ed[x],所以我们二分出满足条件的区间的左右断点,线段树查询就行了,修改是一样的。
还有一个比较好想的做法,就是开深度棵线段树,查询也是一样的,只不过要动态开点,代码是LJ的,也提供作为参考。
代码:算法1
#include <cstdio> #include <iostream> #include <algorithm> #include <cstring> #include <cmath> #include <iostream> #define MAXN 300100 using namespace std; struct edge{ int first,next,to; }a[MAXN*2]; struct tree{ int l,r,zhi; }tr[MAXN*4]; struct no{ int dfn,dep,v; }node[MAXN*2]; int dep[MAXN],dfn[MAXN],ed[MAXN],v[MAXN]; int depl[MAXN],depr[MAXN]; int n,q,num=0; void addedge(int from,int to){ a[++num].to=to; a[num].next=a[from].first; a[from].first=num; } void dfs1(int now,int f){ dep[now]=dep[f]+1,dfn[now]=++num; for(int i=a[now].first;i;i=a[i].next){ int to=a[i].to; if(to==f) continue; dfs1(to,now); ed[now]=max(ed[now],ed[to]); } ed[now]=max(ed[now],dfn[now]); } bool cmp(no x,no y){ if(x.dep==y.dep) return x.dfn<y.dfn; return x.dep<y.dep; } int erfen(int hh,int l,int r){ int mid,ans=-1; while(l<=r){ mid=(l+r)/2; if(node[mid].dfn>=hh) ans=mid,r=mid-1; else l=mid+1; } return ans; } int erfen2(int l,int r,int minn,int maxx){ int mid,ans=-1; while(l<=r){ mid=(l+r)/2; if(node[mid].dfn<minn) l=mid+1; else if(node[mid].dfn>maxx) r=mid-1; else ans=mid,r=mid-1; } return ans; } int erfen3(int l,int r,int minn,int maxx){ int mid,ans=-1; while(l<=r){ mid=(l+r)/2; if(node[mid].dfn<minn) l=mid+1; else if(node[mid].dfn>maxx) r=mid-1; else ans=mid,l=mid+1; } return ans; } void build(int xv,int l,int r){ if(l==r){ tr[xv].l=tr[xv].r=l; tr[xv].zhi=node[l].v; return; } tr[xv].l=l,tr[xv].r=r; int mid=(l+r)/2; build(xv*2,l,mid),build(xv*2+1,mid+1,r); tr[xv].zhi=tr[xv*2].zhi+tr[xv*2+1].zhi; } void change(int xv,int ps,int x){ int l=tr[xv].l,r=tr[xv].r,mid=(l+r)/2; if(l==r){ tr[xv].zhi=x; return; } if(ps<=mid) change(xv*2,ps,x); else change(xv*2+1,ps,x); tr[xv].zhi=tr[xv*2].zhi+tr[xv*2+1].zhi; } int query(int xv,int l,int r){ int L=tr[xv].l,R=tr[xv].r,mid=(L+R)/2; if(L==l&&r==R){ return tr[xv].zhi; } if(r<=mid) return query(xv*2,l,r); else if(l>mid) return query(xv*2+1,l,r); else return query(xv*2,l,mid)+query(xv*2+1,mid+1,r); } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d",&v[i]); for(int i=1;i<=n-1;i++){ int x,y;scanf("%d%d",&x,&y); addedge(x,y),addedge(y,x); } num=0;dfs1(1,0); for(int i=1;i<=n;i++) node[i].dfn=dfn[i],node[i].dep=dep[i],node[i].v=v[i]; sort(node+1,node+n+1,cmp); memset(depl,127,sizeof(depl)); for(int i=1;i<=n;i++){ depl[node[i].dep]=min(depl[node[i].dep],i); depr[node[i].dep]=max(depr[node[i].dep],i); } build(1,1,n); scanf("%d",&q); while(q--){ int id,x,y;scanf("%d%d%d",&id,&x,&y); if(id==1){ int ps=erfen(dfn[x],depl[dep[x]],depr[dep[x]]); change(1,ps,y); } else{ int deep=dep[x]+y; int ll=erfen2(depl[deep],depr[deep],dfn[x],ed[x]),rr=erfen3(depl[deep],depr[deep],dfn[x],ed[x]); if(ll==-1||rr==-1){ printf("0 ");continue; } printf("%d ",query(1,ll,rr)); } } return 0; }
代码:算法2
#include<iostream> #include<cstdlib> #include<cstdio> #include<cmath> #include<algorithm> #include<cstring> #include<queue> #include<vector> #include<stack> #include<map> #define ls k<<1 #define rs k<<1|1 #define RG register #define MAXN 400010 #define LL long long int using namespace std; const int INF=1e9; struct node{ int next; int to; }t[MAXN*2]; int head[MAXN*2],num; int n,m; int sum[MAXN*20],lson[MAXN*20],rson[MAXN*20]; int val[MAXN]; int dfn[MAXN],low[MAXN],dep[MAXN],tot; int root[MAXN],sz; void add(int from,int to) { t[++num].next=head[from]; t[num].to=to; head[from]=num; } void dfs(int u,int f,int k) { dfn[u]=++tot;dep[u]=k; for(int i=head[u];i;i=t[i].next) { int v=t[i].to; if(v==f) continue; dfs(v,u,k+1); } low[u]=tot; } void update(int L,int R,int id,int x,int &k) { if(!k) k=++sz; if(L==R){sum[k]=x;return;} int mid=(L+R)>>1; if(id<=mid) update(L,mid,id,x,lson[k]); else update(mid+1,R,id,x,rson[k]); sum[k]=sum[lson[k]]+sum[rson[k]]; } int query(int L,int R,int l,int r,int k) { if(!k) return 0; if(l<=L&&R<=r) return sum[k]; int mid=(L+R)>>1; if(r<=mid) return query(L,mid,l,r,lson[k]); else{ if(l>mid) return query(mid+1,R,l,r,rson[k]); else return query(L,mid,l,mid,lson[k])+query(mid+1,R,mid+1,r,rson[k]); } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d",&val[i]); int x,y; for(int i=1;i<n;i++){ scanf("%d%d",&x,&y);add(x,y);add(y,x); } dfs(1,0,1); for(int i=1;i<=n;i++) update(1,n,dfn[i],val[i],root[dep[i]]); scanf("%d",&m); int ch,c1,c2; for(int i=1;i<=m;i++) { scanf("%d%d%d",&ch,&c1,&c2); if(ch==1) update(1,n,dfn[c1],c2,root[dep[c1]]); else printf("%d ",query(1,n,dfn[c1],low[c1],root[dep[c1]+c2])); } return 0; }