洛谷P6773 [NOI2020]命运(整体 dp)
题目大意
形式化的:给定一棵树 (T = (V, E)) 和点对集合 (mathcal Q subseteq V imes V) ,满足对于所有 ((u, v) in mathcal Q),都有 (u eq v),并且 (u) 是 (v) 在树 (T) 上的祖先。其中 (V) 和 (E) 分别代表树 (T) 的结点集和边集。求有多少个不同的函数 (f) : (E o {0, 1})(将每条边 (e in E) 的 (f(e)) 值置为 (0) 或 (1)),满足对于任何 ((u, v) in mathcal Q),都存在 (u) 到 (v) 路径上的一条边 (e) 使得 (f(e) = 1)。由于答案可能非常大,你只需要输出结果对 (998,244,353)(一个素数)取模的结果。
数据范围
全部数据满足:(n leq 5 imes 10^5),(m leq 5 imes 10^5)。输入构成一棵树,并且对于 (1 leq i leq m),(u_i) 始终为 (v_i) 的祖先结点。
解题思路
首先题目大意就是每个边可以是 0 或 1,有 m 个条件,要求树上的一条直上直下的链中至少有一个 1,求方案数。
考虑 dp,(dp[x][y]) 表示 x 子树内的边状态已经确定,下端点在子树内且没有被满足的条件中,上端点最深的深度是 y,特殊地,如果子树内的条件都满足那么 y = 0。记录最深是因为深得满足浅的也就会跟着满足。
考虑如何转移,对于 x 的儿子点 y,有
简单来说就是看这条边是 1 还是 0,前缀和优化一下就有 64 pt,其中如果建出来虚树就能再拿 8 pt。
现在是正解时间,重新审查我们的 dp
考虑线段树合并,(sum[y][i]) 可以先从线段树上查一下,剩下的都是和下标相关的,我们先合并左子树然后合并右子树,走右子树的时候就可以把左边的 sum 加上了,整体乘一个数可以打个标记即可,细节如下
// s1 -> (sum[y][dep[x]]+sum[y][i]), s2 -> sum[x][i-1]
int merge(int x, int y, int l, int r, ll &s1, ll &s2) {
if (!x && !y) return 0;
if (!x || !y) {
if (!x) {
add(s1, sum(y));
mul(y) = mul(y) * s2 % P;
sum(y) = sum(y) * s2 % P;
return y;
}
add(s2, sum(x));
sum(x) = sum(x) * s1 % P, mul(x) = mul(x) * s1 % P;
return x;
}
if (l == r) {
ll tx = sum(x), ty = sum(y); add(s1, ty);
sum(x) = (sum(x) * s1 + sum(y) * s2) % P;
add(s2, tx);
return x;
}
spread(x), spread(y); //下放标记
int mid = (l + r) >> 1;
ls(x) = merge(ls(x), ls(y), l, mid, s1, s2);
rs(x) = merge(rs(x), rs(y), mid + 1, r, s1, s2);
sum(x) = (sum(ls(x)) + sum(rs(x))) % P;
return x;
}
/*
/> フ
| _ _|
/`ミ _x 彡
/ |
/ ヽ ?
/ ̄| | | |
| ( ̄ヽ__ヽ_)_)
\二つ
*/
#include <queue>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MP make_pair
#define ll long long
#define fi first
#define se second
using namespace std;
template <typename T>
void read(T &x) {
x = 0; bool f = 0;
char c = getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
for (;isdigit(c);c=getchar()) x=x*10+(c^48);
if (f) x=-x;
}
template<typename F>
inline void write(F x, char ed = '
') {
static short st[30];short tp=0;
if(x<0) putchar('-'),x=-x;
do st[++tp]=x%10,x/=10; while(x);
while(tp) putchar('0'|st[tp--]);
putchar(ed);
}
template <typename T>
inline void Mx(T &x, T y) { x < y && (x = y); }
template <typename T>
inline void Mn(T &x, T y) { x > y && (x = y); }
const int P = 998244353;
const int N = 500050;
vector<int> v[N];
int h[N], ne[N<<1], to[N<<1], dep[N], tot, m, n;
inline void adde(int x, int y) {
ne[++tot] = h[x], to[h[x] = tot] = y;
}
inline void add(ll &x, ll y) { x += y, (x >= P) && (x -= P); }
struct node {
int ls, rs;
ll sum, mul;
#define mul(p) t[p].mul
#define ls(p) t[p].ls
#define rs(p) t[p].rs
#define sum(p) t[p].sum
}t[N<<5];
void spread(int p) {
if (ls(p)) {
sum(ls(p)) = sum(ls(p)) * mul(p) % P;
mul(ls(p)) = mul(ls(p)) * mul(p) % P;
}
if (rs(p)) {
sum(rs(p)) = sum(rs(p)) * mul(p) % P;
mul(rs(p)) = mul(rs(p)) * mul(p) % P;
}
mul(p) = 1;
}
ll query(int rt, int l, int r, int L) {
if (!rt || r <= L) return sum(rt);
int mid = (l + r) >> 1; ll sum = 0;
spread(rt);
if (mid < L) add(sum, query(rs(rt), mid + 1, r, L));
add(sum, query(ls(rt), l, mid, L));
return sum;
}
int cnt;
void change(int &p, int l, int r, int x) {
p = ++cnt, sum(p) = mul(p) = 1;
if (l == r) return;
int mid = (l + r) >> 1;
if (x <= mid) change(ls(p), l, mid, x);
else change(rs(p), mid + 1, r, x);
}
int merge(int x, int y, int l, int r, ll &s1, ll &s2) {
if (!x && !y) return 0;
if (!x || !y) {
if (!x) {
add(s1, sum(y));
mul(y) = mul(y) * s2 % P;
sum(y) = sum(y) * s2 % P;
return y;
}
add(s2, sum(x));
sum(x) = sum(x) * s1 % P, mul(x) = mul(x) * s1 % P;
return x;
}
if (l == r) {
ll tx = sum(x), ty = sum(y); add(s1, ty);
sum(x) = (sum(x) * s1 + sum(y) * s2) % P;
add(s2, tx);
return x;
}
spread(x), spread(y);
int mid = (l + r) >> 1;
ls(x) = merge(ls(x), ls(y), l, mid, s1, s2);
rs(x) = merge(rs(x), rs(y), mid + 1, r, s1, s2);
sum(x) = (sum(ls(x)) + sum(rs(x))) % P;
return x;
}
int T[N];
void dfs(int x, int fa) {
dep[x] = dep[fa] + 1; int mx = 0;
for (auto t: v[x]) Mx(mx, dep[t]);
change(T[x], 0, n, mx);
for (int i = h[x]; i; i = ne[i]) {
int y = to[i]; if (y == fa) continue;
dfs(y, x);
ll S = query(T[y], 0, n, dep[x]), SS = 0;
T[x] = merge(T[x], T[y], 0, n, S, SS);
}
}
int main() {
// freopen ("destiny.in","r",stdin);
// freopen ("destiny.out","w",stdout);
read(n);
for (int i = 1, x, y;i < n; i++)
read(x), read(y), adde(x, y), adde(y, x);
read(m);
for (int i = 1, x, y;i <= m; i++)
read(x), read(y), v[y].push_back(x);
dfs(1, 0), write(query(T[1], 0, n, 0));
return 0;
}