子树可以移动,唔。
还是用Splay维护DFS序即可。
子树的话直接截取出来就好了。
然后求前驱后继可能麻烦一些。
添加两个虚拟节点会比较好写。
#include <map> #include <cmath> #include <queue> #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; #define F(i,j,k) for (int i=j;i<=k;++i) #define D(i,j,k) for (int i=j;i>=k;--i) #define ll long long #define mp make_pair #define maxn 300005 int n,fa[maxn],st[maxn],top=0,id[maxn],rt,m; int pos[maxn][2],w[maxn],ch[maxn][2],pn[maxn],siz[maxn][2]; int sta[maxn],cnt=0; ll val[maxn],sum[maxn],tag[maxn]; char opt[11]; vector <int> v[maxn]; void update(int x) { sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+val[x]; siz[x][0]=siz[ch[x][0]][0]+siz[ch[x][1]][0]; siz[x][1]=siz[ch[x][0]][1]+siz[ch[x][1]][1]; if (pn[x]==-1) siz[x][0]++; else if (pn[x]==1) siz[x][1]++; } int build(int l,int r,int f) { if (l>r) return 0; int mid=(l+r)/2; fa[mid]=f; if (st[mid]<0) { pos[-st[mid]][1]=mid; val[mid]=-w[-st[mid]]; pn[mid]=-1; } else if (st[mid]>0) { pos[st[mid]][0]=mid; val[mid]=w[st[mid]]; pn[mid]=1; } else if (st[mid]==0) { pos[st[mid]][0]=mid; val[mid]=0; pn[mid]=0; } ch[mid][0]=build(l,mid-1,mid); ch[mid][1]=build(mid+1,r,mid); update(mid); return mid; } void dfs(int o) { st[++top]=o; for (int i=0;i<v[o].size();++i) dfs(v[o][i]); st[++top]=-o; } void Debug(int o) { if (!o) return ; printf("|-----------------| "); printf("Node %d : ",o); printf("Fa %d Ch l - %d Ch r - %d ",fa[o],ch[o][0],ch[o][1]); printf("Val %lld Sum %lld Positive Negetive %d ",val[o],sum[o],pn[o]); Debug(ch[o][0]); Debug(ch[o][1]); } void pushdown(int x) { if (tag[x]) { tag[ch[x][0]]+=tag[x]; tag[ch[x][1]]+=tag[x]; val[ch[x][0]]+=pn[ch[x][0]]*tag[x]; val[ch[x][1]]+=pn[ch[x][1]]*tag[x]; sum[ch[x][0]]+=1LL*(siz[ch[x][0]][1]-siz[ch[x][0]][0])*tag[x]; sum[ch[x][1]]+=1LL*(siz[ch[x][1]][1]-siz[ch[x][1]][0])*tag[x]; tag[x]=0; } } void rot(int x,int &k) { int y=fa[x],z=fa[y],l,r; if (ch[y][0]==x) l=0; else l=1; r=l^1; if (y==k) k=x; else { if (ch[z][0]==y) ch[z][0]=x; else ch[z][1]=x; } fa[x]=z; fa[y]=x;fa[ch[x][r]]=y; ch[y][l]=ch[x][r]; ch[x][r]=y; update(y);update(x); } void splay(int x,int &k) { sta[cnt=1]=x; for (int i=x;i!=k;i=fa[i]) sta[++cnt]=fa[i]; while (cnt) pushdown(sta[cnt--]); while(x!=k) { int y=fa[x]; if (y!=k) { int z=fa[y]; if (ch[z][0]==y^ch[y][0]==x) rot(x,k); else rot(y,k); } rot(x,k); } } int pre(int x) { if (ch[x][0]) { x=ch[x][0]; while (ch[x][1]) x=ch[x][1]; return x; } while (ch[fa[x]][0]==x) x=fa[x]; x=fa[x]; return x; } int nxt(int x) { if (ch[x][1]) { x=ch[x][1]; while (ch[x][0]) x=ch[x][0]; return x; } while (ch[fa[x]][1]==x) x=fa[x]; x=fa[x]; return x; } int main() { scanf("%d",&n); F(i,2,n) { int u; scanf("%d",&u); v[u].push_back(i); } sta[++top]=0;dfs(1);sta[++top]=0; F(i,1,n) scanf("%d",&w[i]); rt=build(1,top,0); scanf("%d",&m); F(i,1,m) { scanf("%s",opt); int x,y,tmp,L,R; switch(opt[0]) { case 'Q': scanf("%d",&x); splay(pre(pos[1][0]),rt); splay(nxt(pos[x][0]),ch[rt][1]); printf("%lld ",sum[ch[ch[rt][1]][0]]); break; case 'C': scanf("%d%d",&x,&y); splay(pre(pos[x][0]),rt); splay(nxt(pos[x][1]),ch[rt][1]); tmp=ch[ch[rt][1]][0]; fa[tmp]=0; ch[ch[rt][1]][0]=0; update(ch[rt][1]); update(rt); L=pos[y][0];R=pos[y][0]+1; splay(pos[y][0],rt); splay(nxt(pos[y][0]),ch[rt][1]); ch[ch[rt][1]][0]=tmp; fa[tmp]=ch[rt][1]; update(ch[rt][1]); update(rt); break; case 'F': scanf("%d%d",&x,&y); splay(pre(pos[x][0]),rt); splay(nxt(pos[x][1]),ch[rt][1]); tag[ch[ch[rt][1]][0]]+=y; tmp=ch[ch[rt][1]][0]; val[tmp]+=y*pn[tmp]; sum[tmp]+=1LL*(siz[tmp][1]-siz[tmp][0])*y; break; } } }