易错点
注意：两次 DFS 中记录的信息（fa、deep、siz、son、top、id、val）都要逐一核对是否记录正确。
其次，线段树里有很多处是 `+=`，不要误写成 `=`；另外 sum 和 tag 都要及时对 p 取模，并用 long long 计算 `tag * 区间长度`，防止溢出。
Code
// Heavy-light decomposition + lazy segment tree over the DFS order.
// Supports path add / path sum / subtree add / subtree sum, all modulo p.
//
// Input : n m root p, then the n node weights, then n-1 undirected edges,
//         then m operations:
//   1 x y z : add z to every node on the path between x and y
//   2 x y   : print the sum of the node values on the path between x and y
//   3 x z   : add z to every node in the subtree rooted at x
//   4 x     : print the sum of the node values in the subtree rooted at x
// Every printed sum is reduced modulo p and written on its own line.
//
// All segment-tree sums and lazy tags are kept reduced into [0, p) and are
// computed in 64-bit arithmetic: p is read as a 32-bit int, so it can be as
// large as INT_MAX, and `tag * length` or the sum of two partial results
// would overflow a 32-bit int.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>

using i64 = long long;

const int MAXN = 1e5 + 5;

int n, m, root, tot;
i64 p;

// Adjacency list: head[x] is the first edge index of x, l[e].nxt the next one.
struct Edge {
    int nxt;
    int to;
} l[MAXN << 1];

// Segment-tree node.
struct Tree {
    i64 sum;  // sum of the covered segment, kept in [0, p)
    i64 tag;  // addition still owed to both children, kept in [0, p)
} t[MAXN << 2];

int deep[MAXN], fa[MAXN], siz[MAXN], son[MAXN];
int head[MAXN], cnt;
int id[MAXN], top[MAXN];
i64 val[MAXN], w[MAXN];  // val[i]: weight of the node whose DFS id is i

// Reduce v into [0, p); also correct for negative v.
i64 mod_norm(i64 v) {
    return (v % p + p) % p;
}

// Append the directed edge x -> y to the adjacency list.
void add(int x, int y) {
    cnt++;
    l[cnt].nxt = head[x];
    l[cnt].to = y;
    head[x] = cnt;
}

// First DFS: depth, parent, subtree size and heavy child (largest subtree).
// Recursive: the recursion depth equals the tree height (up to n).
void dfs1(int x, int from) {
    deep[x] = deep[from] + 1;
    fa[x] = from;
    siz[x] = 1;
    int maxsiz = 0;  // every child subtree has size >= 1
    for (int i = head[x]; i; i = l[i].nxt) {
        int v = l[i].to;
        if (v == from) continue;
        dfs1(v, x);
        siz[x] += siz[v];
        if (siz[v] > maxsiz) {
            maxsiz = siz[v];
            son[x] = v;
        }
    }
}

// Second DFS: assign DFS ids visiting the heavy child first, so every heavy
// chain and every subtree occupies a contiguous id range; y is the chain top.
void dfs2(int x, int y, int from) {
    id[x] = ++tot;
    val[tot] = w[x];
    top[x] = y;
    if (!son[x]) return;
    dfs2(son[x], y, x);
    for (int i = head[x]; i; i = l[i].nxt) {
        int v = l[i].to;
        if (v == from || v == son[x]) continue;
        dfs2(v, v, x);  // a light child starts its own chain
    }
}

// Recompute a node's sum from its two children.
void update(int pos) {
    t[pos].sum = (t[pos << 1].sum + t[pos << 1 | 1].sum) % p;
}

// Add v (already in [0, p)) to every element of node pos covering [L, R].
void apply_add(int L, int R, int pos, i64 v) {
    // v < p <= INT_MAX and the length <= n, so the product fits in 64 bits.
    t[pos].sum = (t[pos].sum + v * (R - L + 1)) % p;
    t[pos].tag = (t[pos].tag + v) % p;
}

// Push the pending tag of node pos (covering [L, R]) down to its children.
void pushdown(int L, int R, int pos) {
    if (!t[pos].tag) return;
    int mid = (L + R) >> 1;
    apply_add(L, mid, pos << 1, t[pos].tag);
    apply_add(mid + 1, R, pos << 1 | 1, t[pos].tag);
    t[pos].tag = 0;
}

// Build the tree over val[L..R].
void build(int L, int R, int pos) {
    t[pos].tag = 0;
    if (L == R) {
        t[pos].sum = mod_norm(val[L]);
        return;
    }
    int mid = (L + R) >> 1;
    build(L, mid, pos << 1);
    build(mid + 1, R, pos << 1 | 1);
    update(pos);
}

// Add v (in [0, p)) to every position in [ql, qr].
void modify(int L, int R, int ql, int qr, int pos, i64 v) {
    if (qr < L || R < ql) return;
    if (ql <= L && R <= qr) {
        apply_add(L, R, pos, v);
        return;
    }
    pushdown(L, R, pos);
    int mid = (L + R) >> 1;
    modify(L, mid, ql, qr, pos << 1, v);
    modify(mid + 1, R, ql, qr, pos << 1 | 1, v);
    update(pos);
}

// Sum of positions [ql, qr], modulo p.
i64 query(int L, int R, int ql, int qr, int pos) {
    if (qr < L || R < ql) return 0;
    if (ql <= L && R <= qr) return t[pos].sum;
    pushdown(L, R, pos);
    int mid = (L + R) >> 1;
    return (query(L, mid, ql, qr, pos << 1) +
            query(mid + 1, R, ql, qr, pos << 1 | 1)) % p;
}

// Add z to every node on the path x..y: jump the deeper chain top each step.
void way_add(int x, int y, i64 z) {
    z = mod_norm(z);
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) std::swap(x, y);
        modify(1, n, id[top[x]], id[x], 1, z);
        x = fa[top[x]];
    }
    if (deep[x] > deep[y]) std::swap(x, y);
    modify(1, n, id[x], id[y], 1, z);
}

// Sum of node values on the path x..y, modulo p.
i64 way_ask(int x, int y) {
    i64 ans = 0;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) std::swap(x, y);
        ans = (ans + query(1, n, id[top[x]], id[x], 1)) % p;
        x = fa[top[x]];
    }
    if (deep[x] > deep[y]) std::swap(x, y);
    return (ans + query(1, n, id[x], id[y], 1)) % p;
}

// Add z to every node in the subtree of x (ids id[x] .. id[x] + siz[x] - 1).
void son_add(int x, i64 z) {
    modify(1, n, id[x], id[x] + siz[x] - 1, 1, mod_norm(z));
}

// Sum of node values in the subtree of x, modulo p.
i64 son_ask(int x) {
    return query(1, n, id[x], id[x] + siz[x] - 1, 1);
}

int main() {
    if (scanf("%d%d%d%lld", &n, &m, &root, &p) != 4) return 0;
    if (n <= 0) return 0;
    for (int i = 1; i <= n; i++) scanf("%lld", &w[i]);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs1(root, 0);
    dfs2(root, root, 0);
    build(1, n, 1);

    int opt;
    while (m-- > 0 && scanf("%d", &opt) == 1) {
        int x, y;
        i64 z;
        if (opt == 1) {
            scanf("%d%d%lld", &x, &y, &z);
            way_add(x, y, z);
        } else if (opt == 2) {
            scanf("%d%d", &x, &y);
            printf("%lld\n", way_ask(x, y));
        } else if (opt == 3) {
            scanf("%d%lld", &x, &z);
            son_add(x, z);
        } else if (opt == 4) {
            scanf("%d", &x);
            printf("%lld\n", son_ask(x));
        }
    }
    return 0;
}