Description
别忘了这是一棵动态树, 每时每刻都是动态的. 小明要求你在这棵树上维护两种事件
事件 (0) : 这棵树长出了一些果子, 即某个子树中的每个节点都会长出 (K) 个果子.
事件 (1) : 小明希望你求出几条树枝上的果子数. 一条树枝其实就是一个从某个节点到根的路径的一段. 每次小明会选定一些树枝, 让你求出在这些树枝上的节点的果子数的和. 注意, 树枝之间可能会重合, 这时重合的部分的节点的果子只要算一次.
Input
第一行一个整数 (n(1le nle 200000)) , 即节点数.
接下来 (n-1) 行, 每行两个数字 (u, v) . 表示果子 (u) 和果子 (v) 之间有一条直接的边. 节点从 (1) 开始编号.
在接下来一个整数 (nQ(1le nQle 200000)) , 表示事件.
最后 (nQ) 行, 每行开头要么是 (0) , 要么是 (1) .
如果是 (0) , 表示这个事件是事件 (0) . 这行接下来的 (2) 个整数 (u, delta) 表示以 (u) 为根的子树中的每个节点长出了 (delta) 个果子.
如果是 (1) , 表示这个事件是事件 (1) . 这行接下来一个整数 (K(1le Kle 5)) , 表示这次询问涉及 (K) 个树枝. 接下来K对整数 (u_k, v_k) , 每个树枝从节点 (u_k) 到节点 (v_k) . 由于果子数可能非常多, 请输出这个数模 (2^{31}) 的结果.
Output
对于每个事件 (1) , 输出询问的果子数.
Sample Input
5
1 2
2 3
2 4
1 5
3
0 1 1
0 2 3
1 2 3 1 1 4
Sample Output
13
HINT
(1 le n le 200,000, 1 <= nQ le 200000, K le 5.)
生成每个树枝的过程是这样的:先在树中随机找一个节点, 然后在这个节点到根的路径上随机选一个节点, 这两个节点就作为树枝的两端.
Solution
操作 (0) 就直接链剖了以后用线段树维护就好了,这个不用讲吧。
操作 (1) 的话, HINT 里面讲了一下数据构造方式,那么我们可以很容易的脑补出一个树枝是不会跨过根的,而且 (K) 又很小。那么我们可以同样用链剖 + 线段树做,用容斥来维护一下。就是单条链价值 - 两条链相交价值 + 三 - 四 + 五。
#include<bits/stdc++.h>
using namespace std;
#define N 200001
#define rep(i, a, b) for (int i = a; i <= b; i++)
#define ll long long
inline int read() {
int x = 0, flag = 1; char ch = getchar(); while (!isdigit(ch)) { if (!(ch ^ '-')) flag = -1; ch = getchar(); }
while (isdigit(ch)) x = (x << 1) + (x << 3) + ch - '0', ch = getchar(); return x * flag;
}
#define ls rt << 1
#define rs ls | 1
#define mid (l + r >> 1)
const ll mod = 2147483648ll;
int n;
int siz[N], son[N], fa[N], dep[N], top[N], into[N], outo[N], ind;
ll sum[N << 2], tag[N << 2];
int qu[6], qv[6];
ll w[N];
bool _vis[N];
struct edge{ int v, next; }e[N << 1];
int head[N], tot = 1;
void inline insert(int u, int v) { e[++tot].v = v, e[tot].next = head[u], head[u] = tot; }
void inline add(int u, int v) { insert(u, v), insert(v, u); }
void dfs(int u) {
siz[u] = 1;
for(int i = head[u], v; i; i = e[i].next) if((v = e[i].v) ^ fa[u]) {
fa[v] = u, dep[v] = dep[u] + 1, dfs(v), siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
void dfs(int u, int tp) {
into[u] = ++ind, top[u] = tp;
if(son[u]) dfs(son[u], tp);
for(int i = head[u], v; i; i = e[i].next) if(((v = e[i].v) ^ fa[u]) && (son[u] ^ v)) dfs(v, v);
outo[u] = ind;
}
inline void downTag(int rt, int l, int r) {
sum[ls] += (mid - l + 1) * tag[rt], tag[ls] += tag[rt];
sum[rs] += (r - mid) * tag[rt], tag[rs] += tag[rt];
tag[rt] = 0;
}
void modify(int rt, int l, int r, int L, int R, ll val) {
if(l >= L && r <= R) { (sum[rt] += (r - l + 1) * val) %= mod, (tag[rt] += val) %= mod; return; }
downTag(rt, l, r);
if(L <= mid) modify(ls, l, mid, L, R, val);
if(R > mid) modify(rs, mid + 1, r, L, R, val);
sum[rt] = sum[ls] + sum[rs];
}
ll query(int rt, int l, int r, int L, int R) {
if(l >= L && r <= R) return sum[rt];
downTag(rt, l, r); ll ret = 0;
if(L <= mid) (ret += query(ls, l, mid, L, R)) %= mod;
if(R > mid) (ret += query(rs, mid + 1, r, L, R)) %= mod;
return ret;
}
ll ask(int x, int y) {
ll ret = 0;
while(top[x] ^ top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
(ret += query(1, 1, n, into[top[x]], into[x])) %= mod, x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
return (ret + query(1, 1, n, into[x], into[y])) % mod;
}
int lca(int x, int y) {
while(top[x] ^ top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x, y); x = fa[top[x]]; }
return dep[x] < dep[y] ? x : y;
}
void update(int u, ll val) { w[u] += val; for (int i = head[u], v; i; i = e[i].next) if((v = e[i].v) ^ fa[u]) update(v, val); }
int main() {
cin >> n;
rep(i, 2, n) add(read(), read());
dfs(1), dfs(1, 1);
int Q = read();
while(Q--)
if(!read()) {
int x = read(); ll val = read();
modify(1, 1, n, into[x], outo[x], val);
}
else {
int k = read(); ll ret = 0;
rep(i, 1, k) {
qu[i] = read(), qv[i] = read(), (ret += ask(qu[i], qv[i])) %= mod;
if(dep[qu[i]] < dep[qv[i]]) swap(qu[i], qv[i]);
}
if(k >= 2) rep(i, 1, k) rep(j, i + 1, k) {
int x = lca(qu[i], qu[j]), y = dep[qv[i]] > dep[qv[j]] ? qv[i] : qv[j];
if(dep[y] > dep[x]) continue;
((ret += mod) -= ask(x, y)) %= mod;
}
if(k >= 3) rep(i1, 1, k) rep(i2, i1 + 1, k) rep(i3, i2 + 1, k) {
int x = lca(qu[i1], lca(qu[i2], qu[i3])), y;
if(dep[qv[i1]] > dep[qv[i2]]) y = qv[i1]; else y = qv[i2];
if(dep[qv[i3]] > dep[y]) y = qv[i3];
if(dep[y] > dep[x]) continue;
(ret += ask(x, y)) %= mod;
}
if(k >= 4) rep(i1, 1, k) rep(i2, i1 + 1, k) rep(i3, i2 + 1, k) rep(i4, i3 + 1, k) {
int x = lca(qu[i1], lca(qu[i2], lca(qu[i3], qu[i4]))), y;
if(dep[qv[i1]] > dep[qv[i2]]) y = qv[i1]; else y = qv[i2];
if(dep[qv[i3]] > dep[y]) y = qv[i3];
if(dep[qv[i4]] > dep[y]) y = qv[i4];
if(dep[y] > dep[x]) continue;
((ret += mod) -= ask(x, y)) %= mod;
}
if(k >= 5) rep(i1, 1, k) rep(i2, i1 + 1, k) rep(i3, i2 + 1, k) rep(i4, i3 + 1, k) rep(i5, i4 + 1, k) {
int x = lca(qu[i1], lca(qu[i2], lca(qu[i3], lca(qu[i4], qu[i5])))), y;
if(dep[qv[i1]] > dep[qv[i2]]) y = qv[i1]; else y = qv[i2];
if(dep[qv[i3]] > dep[y]) y = qv[i3];
if(dep[qv[i4]] > dep[y]) y = qv[i4];
if(dep[qv[i5]] > dep[y]) y = qv[i5];
if(dep[y] > dep[x]) continue;
(ret += ask(x, y)) %= mod;
}
printf("%lld
", ret);
}
return 0;
}