神仙题,考虑暴力dp,f[i][j]代表以i为根的子树,到i距离为j的点的权值和。由于这个dp跟深度有关,显然可以用长链剖分优化到O(n)
考虑修改,我们可以将操作序列分块,对于每个块跑一遍长链剖分,统计一个点的询问时暴力进行在这个询问之前的所有修改,处理完一个块之后再把块内的所有修改加上。
我们需要维护一个数据结构支持单点修改,区间异或和,我们发现修改有O(nsqrt(n))次而查询只有O(n)次,分块即可做到O(nsqrt(n))。
#include<bits/stdc++.h> #define file(s) freopen(s".in","r",stdin);freopen(s".out","w",stdout); #define P 998244353 #define mid (l+r>>1) #define N 1100000 #define lb(x) (x&(-x)) #define inf 999999999 #define M 1658561 #define int long long #define mem(x) memset(x,0,sizeof(x)); using namespace std; int s[N],sl[N],siz[N],dep[N],to[N],nxt[N],head[N],cnt,sz,n,q,son[N],dfn[N],id[N],tot,c[N],tc[N],ans[N],op[N]; int u[N],v[N],h[N]; struct nd1{int t,x,v;}; vector<nd1> ch; struct nd2{int t,l;}; vector<nd2> ask[N],back; void ins(int x,int y){ sl[x/sz]^=s[x],s[x]+=y,sl[x/sz]^=s[x]; } void add(int x,int y){ to[++cnt]=y; nxt[cnt]=head[x]; head[x]=cnt; } int query(int x,int y){ int res=0; if(x/sz==y/sz) for(int i=x;i<=y;i++) res^=s[i]; else{ for(int i=x;i<(x/sz+1)*sz;i++) res^=s[i]; for(int i=x/sz+1;i<=y/sz-1;i++) res^=sl[i]; for(int i=y/sz*sz;i<=y;i++) res^=s[i]; } return res; } void dfs1(int x,int fa,int dp){ dep[x]=dp;siz[x]=1; for(int i=head[x];i;i=nxt[i]){ if(to[i]==fa) continue; dfs1(to[i],x,dp+1);siz[x]+=siz[to[i]];h[x]=max(h[x],h[to[i]]); if(h[son[x]]<=h[to[i]]) son[x]=to[i]; } h[x]++; } void dfs2(int x,int fa){ dfn[++tot]=x;id[x]=tot; if(son[x]) dfs2(son[x],x); for(int i=head[x];i;i=nxt[i]){ if(to[i]==fa||to[i]==son[x]) continue; dfs2(to[i],x); } } void dfs(int x,int fa){ ins(id[x],c[x]); if(son[x]) dfs(son[x],x); for(int i=head[x];i;i=nxt[i]){ if(to[i]==fa||to[i]==son[x]) continue; dfs(to[i],x); for(int j=1;j<=h[to[i]];j++){ ins(id[x]+j,s[id[to[i]]+j-1]); } } int p=0; for(auto i:ask[x]){ for(;p<ch.size()&&ch[p].t<i.t;p++){ int v=ch[p].x; if(id[v]>=id[x]&&id[v]<id[x]+siz[x]){ int t1=id[x]-dep[x]+dep[v],t2=ch[p].v; back.push_back({t1,t2}); ins(t1,t2); } } ans[i.t]=query(id[x],id[x]+min(i.l,h[x]-1)); } for(auto i:back) ins(i.t,-i.l); back.clear(); } signed main(){ int x,y; scanf("%lld%lld",&n,&q);sz=2000; for(int i=1;i<=n;i++) scanf("%lld",&c[i]),tc[i]=c[i]; for(int i=1;i<n;i++) scanf("%lld%lld",&x,&y),add(x,y),add(y,x); dfs1(1,0,1);dfs2(1,0); for(int i=1;i<=q;i++)scanf("%lld%lld%lld",&op[i],&u[i],&v[i]); for(int i=1;i<=q;i+=sz){ int l=i,r=min(q,i+sz-1); ch.clear(); mem(sl); mem(s); for(int j=1;j<=n;j++) ask[j].clear(); for(int j=l;j<=r;j++) if(op[j]==1) ch.push_back({j,u[j],v[j]-tc[u[j]]}),tc[u[j]]=v[j]; else ask[u[j]].push_back({j,v[j]}); dfs(1,0); for(int j=l;j<=r;j++) if(op[j]==1) c[u[j]]=v[j]; } for(int i=1;i<=q;i++) if(op[i]==2) printf("%lld ",ans[i]); return 0; }