题目大意:
https://www.luogu.org/problemnew/show/P3384
树链剖分的讲解 两个dfs() 修改 查询
很详细很好理解 https://www.cnblogs.com/George1994/p/7821357.html
不过上面的讲解没有完整的代码
没有说到 操作3和操作4 应该处理的区间
某棵子树在线段树中的对应区间 可以从 根节点的dfsID p[i] 及 根节点的儿子数量 num[i] 得到
即若子树的根节点为 i 那么对应的区间应该是 ( p[i],p[i]+num[i]-1 )
BTW 其实这道题测试数据很弱 不一定能检验出模板的错误
#include <bits/stdc++.h> using namespace std; #define mem(i,j) memset(i,j,sizeof(i)) #define LL long long #define lson l , m , rt << 1 #define rson m + 1 , r , rt << 1 | 1 const int maxn = 2e5 + 10; const int maxnode = maxn<<2; const int maxedge = maxn<<2; LL head[maxn], tot, pos; LL fa[maxn], son[maxn], dep[maxn], num[maxn]; // i的父亲、i的重结点、i的深度、i的儿子个数 LL top[maxn], p[maxn], fp[maxn]; // i所在链的顶端、ID->dfsID、dfsID->ID LL n,m,r,mod; LL val[maxn]; struct Edge { int to,ne; }e[maxedge]; void init() { tot=1; pos=0; mem(head,0); mem(son,0); } void add(int u,int v) { e[tot].to = v; e[tot].ne = head[u]; head[u] = tot++; } struct IntervalTree { LL _sum, _min, _max; LL sumv[maxnode], minv[maxnode], maxv[maxnode], setv[maxnode], addv[maxnode]; void init() { mem(sumv,0); mem(setv,0); mem(addv,0); } void maintain(int L, int R, int rt) { int lc = rt<<1, rc = rt<<1|1; if(R > L) { sumv[rt] = (sumv[lc] + sumv[rc])%mod; minv[rt] = min(minv[lc], minv[rc]); maxv[rt] = max(maxv[lc], maxv[rc]); } if(setv[rt] >= 0) { minv[rt] = maxv[rt] = setv[rt]; sumv[rt] = setv[rt] * (R-L+1)%mod; } if(addv[rt]) { minv[rt] += addv[rt]; maxv[rt] += addv[rt]; sumv[rt] = (sumv[rt]+addv[rt] * (R-L+1)%mod)%mod; } } void pushdown(int rt) { int lc = rt*2, rc = rt*2+1; if(setv[rt] >= 0) { setv[lc] = setv[rc] = setv[rt]; addv[lc] = addv[rc] = 0; setv[rt] = -1; } if(addv[rt]) { addv[lc] += addv[rt]; addv[rc] += addv[rt]; addv[rt] = 0; } } ///update(更新区间左右端点、更新值、更新选项 op = 1 为加减 op != 1 为置值、当前区间左右端点、根) void update(int L, int R, LL v, int op, int l, int r, int rt){ //int lc = rt<<1, rc = rt<<1|1; if(L <= l && R >= r) { if(op == 1) addv[rt] += v; else { setv[rt] = v; addv[rt] = 0; } } else { pushdown(rt); int m = l + (r-l)/2; if(L <= m) update(L, R, v, op, lson); else maintain(lson); if(R > m) update(L, R, v, op, rson); else maintain(rson); } maintain(l, r, rt); } ///query(问询的左右端点、累加lazy_tag的累加量、当前区间左右端点、根) void query(int L, int R, LL add, int l, int r, int rt) { if(setv[rt] >= 0) { LL v = setv[rt] + add + addv[rt]; _sum += v * (LL)(min(r,R)-max(l,L)+1)%mod; _min = min(_min, v); _max = max(_max, v); } else if(L <= l && R >= r) { _sum += (sumv[rt] + add * (LL)(r-l+1)%mod)%mod; _min = min(_min, minv[rt] + add); _max = max(_max, maxv[rt] + add); } else { int m = l + (r-l)/2; if(L <= m) query(L, R, add+addv[rt], lson); if(R > m) query(L, R, add+addv[rt], rson); } } }T; /** -----树链剖分----- */ void dfs1(int u,int pre,int d) { dep[u]=d; fa[u]=pre; num[u]=1; for(int i=head[u];i;i=e[i].ne) { int v=e[i].to; if(v!=fa[u]) { dfs1(v,u,d+1); num[u]+=num[v]; if(!son[u] || num[v]>num[son[u]]) son[u]=v; } } } void dfs2(int u,int sp) { top[u]=sp; p[u]=++pos; fp[p[u]]=u; if(!son[u]) return; dfs2(son[u],sp); for(int i=head[u];i;i=e[i].ne) { int v=e[i].to; if(v!=son[u] && v!=fa[u]) dfs2(v,v); } } // 查询树上x到y的总和 LL queryPath(int x,int y) { LL ans=0LL; int fx=top[x], fy=top[y]; // fx==fy 说明到了LCA while(fx!=fy) { // x y不在同一条重链上 if(dep[fx]>=dep[fy]) { T._sum=0LL; T.query(p[fx],p[x],0,1,pos,1); ans=(ans+T._sum)%mod; x=fa[fx]; } else { T._sum=0LL; T.query(p[fy],p[y],0,1,pos,1); ans=(ans+T._sum)%mod; y=fa[fy]; } // 先加离LCA更远的 且只加到父亲节点的一段 一步步移 fx=top[x], fy=top[y]; } // 直到两点在同一条重链上跳出 此时节点必连续 // 将最后到达LCA的一段连续的区间加上 if(p[x]>p[y]) swap(x,y); T._sum=0LL; T.query(p[x],p[y],0,1,n,1); return (ans+T._sum)%mod; } // 将树上x到y都加上z (和queryPath()差不多) void updatePath(int x,int y,int z) { int fx=top[x], fy=top[y]; while(fx!=fy) { if(dep[fx]>=dep[fy]) { T.update(p[fx],p[x],(LL)z,1,1,n,1); x=fa[fx]; } else { T.update(p[fy],p[y],(LL)z,1,1,n,1); y=fa[fy]; } fx=top[x], fy=top[y]; } if(p[x]>p[y]) swap(x,y); T.update(p[x],p[y],(LL)z,1,1,n,1); } /* ---------------- */ int main() { while(~scanf("%lld%lld%lld%lld",&n,&m,&r,&mod)) { init(); for(int i=1;i<=n;i++) scanf("%lld",&val[i]); for(int i=1;i<n;i++) { int a,b; scanf("%d%d",&a,&b); add(a,b); add(b,a); } dfs1(r,0,1); // 根节点 前驱节点 深度 dfs2(r,r); // 当前节点 起始重结点 T.init(); for(int i=1;i<=n;i++) T.update(p[i],p[i],val[fp[p[i]]],1,1,n,1); while(m--) { int op,x,y,z; scanf("%d",&op); //printf("op%d ",op); if(op==1) { scanf("%d%d%d",&x,&y,&z); updatePath(x,y,z); } else if(op==2) { scanf("%d%d",&x,&y); printf("%lld ",queryPath(x,y)%mod); } else if(op==3) { scanf("%d%d",&x,&z); T.update(p[x],p[x]+num[x]-1,(LL)z,1,1,n,1); } else { scanf("%d",&x); T._sum=0; T.query(p[x],p[x]+num[x]-1,0LL,1,n,1); printf("%lld ",T._sum%mod); } } } return 0; }