- 给定一棵 \(n\) 个点、以 \(1\) 为根的有根树,给每条边染黑、白两色之一,使得给出的所有 \(m\) 条祖先-后代路径上都至少存在一条黑色边。
- 求染色方案数对 \(998244353\) 取模的结果。
- \(n, m \leq 5 \times 10^5\)。
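设 \(f[x][d]\) 表示满足了 \(x\) 子树内的所有路径,且从 \(x\) 子树内伸出来的未满足路径中“祖”端点的最大深度为 \(d\) 的方案数(代码中根的深度为 \(0\),因此实际以“深度 \(+1\)”作为下标,下标 \(0\) 表示不存在未满足的路径)。
合并子树 \(y\)(其 \(f[y]\) 已计入边 \((x,y)\) 的贡献)时,转移是 max 卷积的形式:\(f'[x][d] = f[x][d]\bigl(\sum_{e<d} f[y][e] + f[y][d]\bigr) + f[y][d]\sum_{e<d} f[x][e]\),可以用线段树合并优化。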
#include <bits/stdc++.h>
// Debug print: writes a printf-style message to stderr, wrapped in ANSI SGR
// escape sequences (ESC[32;1m = bold green on, ESC[0m = reset attributes).
#define dbg(...) std::cerr << "\033[32;1m", fprintf(stderr, __VA_ARGS__), std::cerr << "\033[0m"
// Lowers x to y when y is strictly smaller; reports whether x changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
// Raises x to y when y is strictly larger; reports whether x changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}
// Integer shorthands.
using LL = long long;
using PII = std::pair<int, int>;

// N: array bound (the vertex count is at most 5e5, plus a little slack).
// P: the prime modulus of the answer.
constexpr int N = 500005;
constexpr int P = 998244353;

// Vertex count and number of constrained paths, read from input.
int n;
int m;
// Modular add-assign: x becomes (x + y) mod P.
// Assumes both operands are already reduced into [0, P).
inline void inc(int &x, int y) {
  x += y;
  if (x >= P) x -= P;
}
// Modular addition: returns (x + y) mod P.
// Assumes both operands are already reduced into [0, P).
inline int sum(int x, int y) {
  const int total = x + y;
  return total < P ? total : total - P;
}
// Segment-tree node holding one DP row: the leaves are indexed by d, and each
// node stores the sum of its range together with a lazy multiplier for its children.
struct Node {
  Node *ls, *rs;  // children; an empty subtree is represented by `null`
  int sum, tag;   // range sum (mod P); pending multiplier not yet applied to ls/rs
  void times(int);  // scales this whole subtree by a factor (defined out of line)
  void pushup() { sum = ::sum(ls->sum, rs->sum); }
  // Applies the pending multiplier to both children, then clears it.
  void pushdown() {
    if (tag != 1) {
      ls->times(tag), rs->times(tag);
      tag = 1;
    }
  }
}
// Node pool. t[0] is the shared empty sentinel `null`.
// Only ins() takes nodes from the pool, and dfs() calls it at most twice per
// vertex. Each call creates at most ceil(log2(5e5)) + 1 = 20 nodes, because the
// index range [0, max_dep] has at most 5e5 entries. Hence 40 slots per vertex
// are enough; the previous 60 wasted about a third of a very large array.
t[N * 40], *null = t;
// Multiplies every value in this subtree by x (mod P). The node's own sum is
// updated immediately, and x is folded into the tag still owed to the
// children. This is a no-op on the `null` sentinel, so the sentinel stays all-zero.
inline void Node::times(int x) {
  if (this == null) return;
  sum = static_cast<int>(1LL * sum * x % P);
  tag = static_cast<int>(1LL * tag * x % P);
}
// Point update: adds y (mod P) at index x of the tree rooted at o, which spans
// [l, r]. Missing nodes along the path are taken from the pool `t`. Pending
// tags are pushed down on the way, so every sum on the path stays exact.
void ins(Node *&o, int l, int r, int x, int y) {
  static Node *pool_top = t + 1;  // next unused pool slot; t[0] is `null`
  if (o == null) {
    o = pool_top++;
    o->ls = null;
    o->rs = null;
    o->sum = 0;
    o->tag = 1;
  }
  o->pushdown();
  inc(o->sum, y);
  if (l == r) return;
  const int mid = (l + r) >> 1;
  if (x <= mid) {
    ins(o->ls, l, mid, x, y);
  } else {
    ins(o->rs, mid + 1, r, x, y);
  }
}
// Range erase: clears every index in [x, y] from the tree rooted at o, which
// spans [l, r]. Fully covered subtrees are detached by pointing them at `null`;
// their pool slots are not reclaimed.
void del(Node *&o, int l, int r, int x, int y) {
  // An empty subtree has nothing to erase. Returning here also avoids calling
  // pushdown()/pushup() on the shared `null` sentinel, which previously wrote
  // to its fields and descended pointlessly through empty levels.
  if (o == null) return;
  if (x <= l && r <= y) {
    o = null;
    return;
  }
  o->pushdown();
  const int mid = (l + r) >> 1;  // named `mid` so it does not shadow the global m
  if (x <= mid) del(o->ls, l, mid, x, y);
  if (y > mid) del(o->rs, mid + 1, r, x, y);
  o->pushup();
}
// Max-convolution merge of two DP rows stored as segment trees over [l, r].
// The merged row is
//   g[d] = fx[d] * (Sy(<d) + fy[d]) + fy[d] * Sx(<d),
// where S(<d) denotes a prefix sum. xl and yl are the prefix sums of x's and
// y's rows over all indices before l. The result is built in x's nodes, or in
// y's nodes where x is empty. y's nodes are modified or absorbed, so y must not
// be used afterwards.
Node* merge(Node *x, Node *y, int l, int r, int xl, int yl) {
  // When one row is empty on [l, r], its prefix sum equals the constant xl
  // (or yl) across the whole range, so the other row is simply scaled by it.
  if (x == null) {
    y->times(xl);
    return y;
  }
  if (y == null) {
    x->times(yl);
    return x;
  }
  if (l == r) {
    const long long fx = x->sum;
    const long long fy = y->sum;
    // fx*(yl + fy) + fy*xl equals fx*yl + fy*xl + fx*fy; every term fits in 64 bits.
    x->sum = static_cast<int>((fx * (yl + fy) + fy * xl) % P);
    return x;
  }
  x->pushdown();
  y->pushdown();
  // Capture the left halves' totals before either recursive call rewrites
  // them; the right half needs them as its prefix offsets.
  const int x_left_total = x->ls->sum;
  const int y_left_total = y->ls->sum;
  const int mid = (l + r) >> 1;
  x->ls = merge(x->ls, y->ls, l, mid, xl, yl);
  x->rs = merge(x->rs, y->rs, mid + 1, r, sum(xl, x_left_total), sum(yl, y_left_total));
  x->pushup();
  return x;
}
// Adjacency lists of the (undirected) input tree.
std::vector<int> g[N];
// dep[v]: depth of v, with the root at depth 0 (dep[] is zero-initialised).
// max_dep: the largest depth in the tree.
// dn[v]: 1 + the largest depth of the upper endpoint among constrained paths
//        whose lower endpoint is v, or 0 if no path ends at v (filled by main).
int dep[N], max_dep, dn[N];
// Fills dep[] for every vertex below x (dep[x] must already be set; fa is x's
// parent, or 0 for none) and raises max_dep to the deepest depth seen.
// Uses an explicit work stack instead of recursion.
void go(int x, int fa) {
  std::vector<std::pair<int, int>> todo;  // (vertex, parent) pairs still to expand
  todo.emplace_back(x, fa);
  while (!todo.empty()) {
    const int u = todo.back().first;
    const int parent = todo.back().second;
    todo.pop_back();
    smax(max_dep, dep[u]);
    for (int v : g[u]) {
      if (v != parent) {
        dep[v] = dep[u] + 1;
        todo.emplace_back(v, u);
      }
    }
  }
}
// root[v]: segment tree over d in [0, max_dep] holding the DP row of vertex v.
Node *root[N];
// Post-order DP over the subtree of x (fa = parent, 0 for the root).
// Index d is 0 when no constrained path is pending. Otherwise it is 1 + the
// depth of the deepest upper endpoint among the still-unsatisfied paths.
void dfs(int x, int fa) {
  // Start from a single 1 at dn[x]: only the paths that end at x are pending.
  root[x] = null;
  ins(root[x], 0, max_dep, dn[x], 1);
  // Fold in each child's row, which already includes the child's edge up to x.
  for (int child : g[x]) {
    if (child != fa) {
      dfs(child, x);
      root[x] = merge(root[x], root[child], 0, max_dep, 0, 0);
    }
  }
  // Entries with d > dep[x] have a pending path whose top is x or lies below x.
  // No edge above x can satisfy such a path any more, so they are discarded.
  if (max_dep > dep[x]) {
    del(root[x], 0, max_dep, dep[x] + 1, max_dep);
  }
  // Edge (fa, x): colouring it black satisfies everything pending, moving all
  // mass to d = 0; white leaves the row unchanged. So add the row total at d = 0.
  if (1 < x) {
    ins(root[x], 0, max_dep, 0, root[x]->sum);
  }
}
// Reads the tree and the path constraints, runs the DP from root 1, and prints
// the number of valid colourings mod P.
int main() {
  // Contest file I/O.
  freopen("destiny.in", "r", stdin);
  freopen("destiny.out", "w", stdout);
  // Make the sentinel point to itself, so walking into an empty subtree always
  // lands back on `null`.
  null->ls = null->rs = null;
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  std::cin >> n;
  // n - 1 undirected tree edges.
  for (int i = 1, x, y; i < n; i++) {
    std::cin >> x >> y;
    g[x].push_back(y);
    g[y].push_back(x);
  }
  go(1, 0);  // depths from root 1 (dep[1] == 0) and max_dep
  std::cin >> m;
  // Assumes each constraint (x, y) has x as an ancestor of y, per the problem
  // statement. Only the deepest upper endpoint per lower endpoint matters; it
  // is stored as depth + 1 so that 0 can mean "no constraint".
  while (m--) {
    int x, y;
    std::cin >> x >> y;
    smax(dn[y], dep[x] + 1);
  }
  dfs(1, 0);
  std::cout << root[1]->sum << '\n';
  return 0;
}