给定一棵树 (T = (V, E)) 和点对集合 (mathcal Q subseteq V imes V) ,满足对于所有 ((u, v) in mathcal Q),都有 (u eq v),并且 (u) 是 (v) 在树 (T) 上的祖先。其中 (V) 和 (E) 分别代表树 (T) 的结点集和边集。求有多少个不同的函数 (f) : (E o {0, 1})(将每条边 (e in E) 的 (f(e)) 值置为 (0) 或 (1)),满足对于任何 ((u, v) in mathcal Q),都存在 (u) 到 (v) 路径上的一条边 (e) 使得 (f(e) = 1)。由于答案可能非常大,你只需要输出结果对 (998,244,353)(一个素数)取模的结果。
首先考虑暴力dp,设(f_{u,i})表示以(u)为根的子树已经确定,最深的(f(e)=1)(深度为(e)的较深的儿子的深度)的方案数。
然后枚举儿子(v),考虑(u)和(v)之间的边设不设为(1),根据乘法原理可以写出转移方程:
[f_{u,i}=prod_{vin son(u)}(f_{v,i}+f_{v,dep_v})qquad mx_u < ile dep_u
]
(f_{v,i})表示边设为(0),(f_{v,dep_v})表示边设为(1),(mx_u)表示(u)向上最深的在询问的祖先的深度。
然后我们观察这个转移方程,每次是形如(abcde o (a+x)(b+x)(c+x)(d+x)(e+x)),这样子的东西,那么我们可以用线段树维护标记((a,b))表示这个点(f o af+b)。然后合并标记的时候,形如((a,b),(c,d))的标记,((af+b)c+d=acf+bc+d),所以可以合并成((ac,cb+d))。
然后用线段树合并可以解决这个问题。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
const int N = 5e5;
const int p = 998244353;
using namespace std;
int n,edge[N * 2 + 5],nxt[N * 2 + 5],head[N + 5],edge_cnt,m,f[N + 5][514],dep[N + 5],rt[N + 5],mx[N + 5];
vector <int> d[N + 5];
struct node
{
int a,b;
};
node operator *(const node &x,const node &y)
{
return (node){1ll * x.a * y.a % p,(1ll * x.b * y.a % p + y.b) % p};
}
struct Seg
{
int lc[N * 100 + 5],rc[N * 100 + 5],node_cnt;
node tag[N * 100 + 5];
void add(int k,node z)
{
tag[k] = tag[k] * z;
}
void pushdown(int k)
{
if (tag[k].a != 1 || tag[k].b != 0)
{
if (!lc[k])
{
lc[k] = ++node_cnt;
tag[lc[k]] = tag[k];
}
else
add(lc[k],tag[k]);
if (!rc[k])
{
rc[k] = ++node_cnt;
tag[rc[k]] = tag[k];
}
else
add(rc[k],tag[k]);
tag[k] = (node){1,0};
}
}
void modify(int k,int l,int r,int x,int y,node z)
{
if (l >= x && r <= y)
{
add(k,z);
return;
}
int mid = l + r >> 1;
pushdown(k);
if (x <= mid)
modify(lc[k],l,mid,x,y,z);
if (y > mid)
modify(rc[k],mid + 1,r,x,y,z);
}
int query(int k,int l,int r,int x)
{
if (l == r)
return tag[k].b;
int mid = l + r >> 1;
pushdown(k);
if (x <= mid)
return query(lc[k],l,mid,x);
else
return query(rc[k],mid + 1,r,x);
}
int merge(int &x,int &y)
{
if (!rc[x] && !lc[x])
swap(x,y);
if (!lc[y] && !rc[y])
{
tag[x] = tag[x] * (node){tag[y].b,0};
return x;
}
pushdown(x);
pushdown(y);
lc[x] = merge(lc[x],lc[y]);
rc[x] = merge(rc[x],rc[y]);
return x;
}
}tree;
void add_edge(int u,int v)
{
edge[++edge_cnt] = v;
nxt[edge_cnt] = head[u];
head[u] = edge_cnt;
}
void dfs(int u,int fa)
{
dep[u] = dep[fa] + 1;
for (int i = 0;i < d[u].size();i++)
mx[u] = max(mx[u],dep[d[u][i]]);
mx[u] = max(mx[u],mx[fa]);
rt[u] = ++tree.node_cnt;
tree.modify(rt[u],1,n,mx[u] + 1,dep[u],(node){0,1});
for (int i = head[u];i;i = nxt[i])
{
int v = edge[i];
if (v == fa)
continue;
dfs(v,u);
rt[u] = tree.merge(rt[u],rt[v]);
}
if (u != 1)
tree.modify(rt[u],1,n,1,dep[u],(node){1,tree.query(rt[u],1,n,dep[u])});
}
int main()
{
scanf("%d",&n);
int u,v;
for (int i = 1;i < n;i++)
{
scanf("%d%d",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
scanf("%d",&m);
for (int i = 1;i <= m;i++)
{
scanf("%d%d",&u,&v);
d[v].push_back(u);
}
dfs(1,0);
cout<<tree.query(rt[1],1,n,1)<<endl;
return 0;
}