【题目链接】
【算法】
树链剖分
将每条边的边权存放在其深度较大的端点(在树链剖分序中的位置)上,因此路径操作时需要跳过 LCA 自身对应的位置。对于线段树的每个节点,记录这段区间的最小值、最小值出现的次数以及值为 0 的位置个数;此外,还要维护两个懒惰标记:区间赋值标记和区间加法标记。
【代码】
本题细节很多,写程序时要认真严谨!
#include <algorithm>
#include <cctype>
#include <cstdio>

// Tree problem solved with heavy-light decomposition + a lazy segment tree.
//   opt 1 u v c : set every edge on the path u-v to c
//   opt 2 u v c : add c to every edge on the path u-v; if that would push the
//                 path minimum below 0, c is raised so the minimum becomes 0
// After every operation the number of zero-weight edges is printed.
// Edge (fa[v], v) is stored at segment-tree position dfn[v]; the root's slot
// holds no edge and is never modified, so it always contributes one extra 0.
// NOTE(review): the zero-counting logic assumes weights stay >= 0 (opt 2 is
// clamped; opt 1 values are presumably non-negative -- confirm with the task).

const int MAXN = 100010;
const int INF = 1000000000;  // "empty path" sentinel returned by min queries

int n, m, tot, dfs_clock;
// siz (not `size`): a global named `size` is ambiguous with std::size in C++17.
int head[MAXN], dep[MAXN], fa[MAXN], siz[MAXN], son[MAXN], top[MAXN], dfn[MAXN];

struct Edge {
    int to, nxt;
} e[MAXN << 1];

struct SegmentTree {
    struct Node {
        int l, r;
        int sum;   // number of positions whose value equals Min
        int cnt;   // number of positions whose value is 0
        int Min;   // minimum value in [l, r]
        int taga;  // pending assignment, -1 when absent
        int tagb;  // pending addition, 0 when absent (only nonzero while taga == -1)
    } Tree[MAXN << 2];

    // Assign `val` to every position covered by node idx and record the tag.
    void apply_assign(int idx, int val) {
        Node &t = Tree[idx];
        t.Min = val;
        t.sum = t.r - t.l + 1;
        t.cnt = (val == 0) ? t.sum : 0;
        t.taga = val;
        t.tagb = 0;
    }

    // Add `val` to every position covered by node idx. A pending assignment
    // simply absorbs the addition, so the two tags never coexist.
    void apply_add(int idx, int val) {
        Node &t = Tree[idx];
        t.Min += val;
        // Zeros exist only if the (non-negative) minimum is 0.
        t.cnt = (t.Min == 0) ? t.sum : 0;
        if (t.taga != -1)
            t.taga += val;
        else
            t.tagb += val;
    }

    // Push pending tags of idx down to its two children.
    void pushdown(int idx) {
        if (Tree[idx].taga != -1) {
            apply_assign(idx << 1, Tree[idx].taga);
            apply_assign(idx << 1 | 1, Tree[idx].taga);
            Tree[idx].taga = -1;
        }
        if (Tree[idx].tagb) {
            apply_add(idx << 1, Tree[idx].tagb);
            apply_add(idx << 1 | 1, Tree[idx].tagb);
            Tree[idx].tagb = 0;
        }
    }

    // Recompute idx's aggregates from its children.
    void update(int idx) {
        const Node &L = Tree[idx << 1];
        const Node &R = Tree[idx << 1 | 1];
        Node &t = Tree[idx];
        t.Min = std::min(L.Min, R.Min);
        t.cnt = L.cnt + R.cnt;
        t.sum = (L.Min == t.Min ? L.sum : 0) + (R.Min == t.Min ? R.sum : 0);
    }

    // Every position starts at 0, with no pending tags.
    void build(int idx, int l, int r) {
        Node &t = Tree[idx];
        t.l = l;
        t.r = r;
        t.Min = 0;
        t.sum = t.cnt = r - l + 1;
        t.taga = -1;
        t.tagb = 0;
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(idx << 1, l, mid);
        build(idx << 1 | 1, mid + 1, r);
    }

    // Shared descent for range updates: calls apply(k) on the canonical nodes
    // covering [l, r] (which must lie inside node idx); no-op when l > r.
    template <typename Apply>
    void range_apply(int idx, int l, int r, Apply apply) {
        if (l > r) return;
        if (Tree[idx].l == l && Tree[idx].r == r) {
            apply(idx);
            return;
        }
        pushdown(idx);
        int mid = (Tree[idx].l + Tree[idx].r) >> 1;
        if (r <= mid) {
            range_apply(idx << 1, l, r, apply);
        } else if (l > mid) {
            range_apply(idx << 1 | 1, l, r, apply);
        } else {
            range_apply(idx << 1, l, mid, apply);
            range_apply(idx << 1 | 1, mid + 1, r, apply);
        }
        update(idx);
    }

    // Range assignment on [l, r].
    void modify(int idx, int l, int r, int val) {
        range_apply(idx, l, r, [this, val](int k) { apply_assign(k, val); });
    }

    // Range addition on [l, r].
    void add(int idx, int l, int r, int val) {
        range_apply(idx, l, r, [this, val](int k) { apply_add(k, val); });
    }

    // Minimum over [l, r]; INF for an empty range (l > r).
    int query_min(int idx, int l, int r) {
        if (l > r) return INF;
        if (Tree[idx].l == l && Tree[idx].r == r) return Tree[idx].Min;
        pushdown(idx);
        int mid = (Tree[idx].l + Tree[idx].r) >> 1;
        if (r <= mid) return query_min(idx << 1, l, r);
        if (l > mid) return query_min(idx << 1 | 1, l, r);
        return std::min(query_min(idx << 1, l, mid), query_min(idx << 1 | 1, mid + 1, r));
    }

    // Number of zero-weight edges: all zeros minus the root's edge-less slot.
    int query() const { return Tree[1].cnt - 1; }
} T;

void add_edge(int u, int v) {
    e[++tot] = Edge{v, head[u]};
    head[u] = tot;
}

// First HLD pass: depth, parent, subtree size and heavy child.
// NOTE(review): recursive; a path-shaped tree of ~1e5 nodes needs ~1e5 frames.
void dfs1(int u) {
    siz[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs1(v);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;  // siz[0] == 0 when son[u] is unset
    }
}

// Second HLD pass: DFS order (heavy child first) and chain tops.
void dfs2(int u, int tp) {
    dfn[u] = ++dfs_clock;
    top[u] = tp;
    if (son[u]) dfs2(son[u], tp);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v != fa[u] && v != son[u]) dfs2(v, v);
    }
}

// Lowest common ancestor via the heavy-light chains.
int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

// Calls f(l, r) for each dfn interval covering the edges on the path from
// ancestor `anc` down to v. anc's own slot is skipped: it stores the edge
// above anc, which is not on the path. The final interval may be empty.
template <typename F>
void for_each_path_range(int anc, int v, F f) {
    while (top[anc] != top[v]) {
        f(dfn[top[v]], dfn[v]);
        v = fa[top[v]];
    }
    f(dfn[anc] + 1, dfn[v]);
}

void path_assign(int anc, int v, int c) {
    for_each_path_range(anc, v, [c](int l, int r) { T.modify(1, l, r, c); });
}

void path_add(int anc, int v, int c) {
    for_each_path_range(anc, v, [c](int l, int r) { T.add(1, l, r, c); });
}

// Minimum edge weight on the path anc -> v; INF when the path has no edges.
int path_min(int anc, int v) {
    int res = INF;
    for_each_path_range(anc, v, [&res](int l, int r) { res = std::min(res, T.query_min(1, l, r)); });
    return res;
}

// Reads a (possibly negative) integer. Uses an int for the character so that
// isdigit never sees a negative char, and stops at EOF instead of spinning.
template <typename Int>
void read(Int &x) {
    int f = 1;
    int c = getchar();
    x = 0;
    for (; !isdigit(c); c = getchar()) {
        if (c == EOF) return;
        if (c == '-') f = -f;
    }
    for (; isdigit(c); c = getchar()) x = x * 10 + (c - '0');
    x *= f;
}

int main() {
    read(n);
    read(m);
    for (int i = 1; i < n; ++i) {
        int x, y;
        read(x);
        read(y);
        add_edge(x, y);
        add_edge(y, x);
    }
    dfs1(1);
    dfs2(1, 1);
    T.build(1, 1, dfs_clock);
    while (m--) {
        int opt, u, v, c;
        read(opt);
        read(u);
        read(v);
        read(c);
        int w = lca(u, v);
        if (opt == 1) {
            path_assign(w, u, c);
            path_assign(w, v, c);
        } else {
            int lo = std::min(path_min(w, u), path_min(w, v));
            // Widen before adding: lo is INF when u == v, and INF + c may overflow int.
            if (static_cast<long long>(lo) + c < 0) c = -lo;
            path_add(w, u, c);
            path_add(w, v, c);
        }
        printf("%d\n", T.query());
    }
    return 0;
}