题目描述
定义(dist(i, j))为树上(i, j)两点的距离
给出节点编号(1)到(n)的树,一个(1)到(n)的排列(a),(q)次询问,每次给出(k),求(sum_{l = 1}^{k} sum_{r = l}^{k} sum_{i = l}^{r} sum_{j = i}^{r} dist(a_i, a_j) mod 998244353)
(n,q le 1e5)
解析
设询问(k)答案为(f[k])
由于询问都是从(1)开始,一个自然的想法便是从(f[k - 1])推向(f[k])
考虑新加入(a_k)后答案的增量(g[k])
我们先把(dist(a_i, a_j))拆成(dep(a_i) + dep(a_j) - 2 cdot dep(lca(a_i, a_j)))
加入(k)位置后增加的区间是以(k)位置结尾的区间,对每个(i < k),区间([i, k])会求(k - i + 1)次与(a_k)有关的(lca),所以(g[k])包含(k cdot (k - 1) / 2 cdot dep[a_k]),同时(a_i)会和(a_k)求(i)次(dist),所以加上(dep[a_i] cdot i),然后还要加上上一次的增量(g[k - 1]),因为每个上次增量计算过的区间这一次也会多计算一次
还剩下的就是(lca)的部分,因为(a_i)会和(a_k)求(i)次(dist),所以减去的就是(2 sum_{i} i cdot dep(lca(a_i, a_k)))
这个东西据说是套路树链剖分然而蒟蒻我见都没见过qwq,具体做法是每插入一个点(a_i),把这个点到根的路径上每个点点权加(i),然后你就发现(a_k)到根的路径上的点权和就神奇地变成了这个东西……
总的来讲就是
然后(f[i] = f[i - 1] + g[i]),顺次推一遍就行了
复杂度(O(n log^2 n)),因为有个树链剖分
代码
PS.先是树剖的时候没有把size统计到父亲T飞,再是线段树询问没有push_down结果WA完……我好菜啊qwq
#include <cstdio>
#include <cstring>
#include <iostream>
#define MAXN 100005
#define REG register
typedef long long LL;
const LL mod = 998244353ll;
struct Edge {
int v, next;
Edge(int _v = 0, int _n = 0):v(_v), next(_n) {}
} edge[MAXN << 1];
int head[MAXN], fa[MAXN], dep[MAXN], top[MAXN], dfn[MAXN], idx, heavy[MAXN], size[MAXN];
int N, Q, f[MAXN], upd1, upd2;
struct SegmentTree {
int sum[MAXN << 2], add[MAXN << 2];
void push_up(int);
void push_down(int, int, int);
void update(int, int, int, int, int, int);
int query(int, int, int, int, int);
} tr;
inline void add_edge(int u, int v) { static int cnt; edge[cnt] = Edge(v, head[u]); head[u] = cnt++; }
inline void insert(int u, int v) { add_edge(u, v); add_edge(v, u); }
void dfs(int);
void dfs2(int);
inline void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
inline void dec(int &x, int y) { x -= y; if (x < 0) x += mod; }
inline int add(int x, int y) { int res = x + y; return res >= mod ? res - mod : res; }
inline int less(int x, int y) { int res = x - y; return res < 0 ? res + mod : res; }
int main() {
freopen("sumsumsum.in", "r", stdin);
freopen("sumsumsum.out", "w", stdout);
memset(head, -1, sizeof head);
scanf("%d%d", &N, &Q);
for (int i = 1; i < N; ++i) {
int u, v;
scanf("%d%d", &u, &v);
insert(u, v);
}
top[1] = dep[1] = 1;
dfs(1);
dfs2(1);
for (int i = 1; i <= N; ++i) {
int a; scanf("%d", &a);
inc(upd2, add(upd1, (LL)i * (i - 1) / 2 * dep[a] % mod));
inc(upd1,(LL)dep[a] * i % mod);
int cur = a;
while (cur) {
int tp = top[cur];
dec(upd2, tr.query(1, 1, N, dfn[tp], dfn[cur]) * 2 % mod);
tr.update(1, 1, N, dfn[tp], dfn[cur], i);
cur = fa[tp];
//debug
//printf("%d %d
", upd1, upd2);
}
f[i] = add(f[i - 1], upd2);
}
while (Q--) {
int k; scanf("%d", &k);
printf("%d
", f[k]);
}
return 0;
}
void dfs(int u) {
dep[u] = dep[fa[u]] + 1;
size[u] = 1;
for (int i = head[u]; ~i; i = edge[i].next)
if (edge[i].v ^ fa[u]) {
fa[edge[i].v] = u;
dfs(edge[i].v);
size[u] += size[edge[i].v];
if (!heavy[u] || size[edge[i].v] > size[heavy[u]]) heavy[u] = edge[i].v;
}
}
void dfs2(int u) {
dfn[u] = ++idx;
if (heavy[u]) {
top[heavy[u]] = top[u];
dfs2(heavy[u]);
}
for (int i = head[u]; ~i; i = edge[i].next)
if ((edge[i].v ^ fa[u]) && (edge[i].v ^ heavy[u])) { top[edge[i].v] = edge[i].v; dfs2(edge[i].v); }
}
void SegmentTree::push_down(int rt, int L, int R) {
if (add[rt]) {
int mid = (L + R) >> 1;
(add[rt << 1] += add[rt]) %= mod;
(add[rt << 1 | 1] += add[rt]) %= mod;
sum[rt << 1] = (sum[rt << 1] + add[rt] * (LL)(mid - L + 1) % mod) % mod;
sum[rt << 1 | 1] = (sum[rt << 1 | 1] + add[rt] * (LL)(R - mid) % mod) % mod;
add[rt] = 0;
}
}
inline void SegmentTree::push_up(int rt) {
sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}
void SegmentTree::update(int rt, int L, int R, int l, int r, int v) {
if (L >= l && R <= r) {
inc(add[rt], v);
inc(sum[rt], v * (LL)(R - L + 1) % mod);
} else {
push_down(rt, L, R);
int mid = (L + R) >> 1;
if (l <= mid) update(rt << 1, L, mid, l, r, v);
if (r > mid) update(rt << 1 | 1, mid + 1, R, l, r, v);
push_up(rt);
}
}
int SegmentTree::query(int rt, int L, int R, int l, int r) {
if (L >= l && R <= r) return sum[rt];
push_down(rt, L, R);
int mid = (L + R) >> 1, res = 0;
if (l <= mid) inc(res, query(rt << 1, L, mid, l, r));
if (r > mid) inc(res, query(rt << 1 | 1, mid + 1, R, l, r));
return res;
}
//Rhein_E