题目
点这里看题目。
分析
看到链的限制很奇葩, " 存在一个重要的选择 " ,于是就不难想到容斥。
首先定义 ((u_1,v_1)cup (u_2,v_2)) 表示求两条路径的边的并集。
显然容斥式子长这个样子:
24 pts
考虑暴力选出容斥式子中的 (S) ,然后用一番树上差分求出 (|igcup_{pin S}p|) 。
时间复杂度是 (O(2^mn)) 。
32 pts
同样是暴力选出容斥式子中的 (S) ,只不过我们这一次利用虚树动态维护 (|igcup_{pin S} p|) 。
简单讲一讲做法:将所有链底拿出来,按照 DFN 排序为 (a_1,a_2,...,a_p) ,每次覆盖从 (a_i) 到 (operatorname{LCA}(a_i,a_{i-1})) 和 该链链顶 中的较深的点 的那条链,做一下树上差分,最后回收标记。
现在的时间是 (O(2^mmlog_2n+n)) 。
另一个方法是,使用树链剖分维护链并,并且按照格雷码的方式来枚举集合。时间是 (O(2^mlog_2^2n)) 。
40 pts
研究 ing 。
64 pts
直接暴力枚举 (S) 的方法已经行不通了,我们需要换种方法统计后面这一大堆。
不难想到要用树形 DP 。这里的状态设计比较奇妙,可以先自己思考一会儿。
猜到了吗?最暴力的状态还是:
(f(u,S)):以 (u) 为根的子树内,选中了 (S) 集合中的链的容斥方案。
然后你就用另一种方式获得暴力容斥的分数。
注意到我们选取 (S) 集合,实际上是想要知道,从 (u) 到祖先上,哪些边会被限制到。
但是有一点是:如果有 (u) 是 (v) 的祖先,且 (u) 的父边已经被限制,那么 (v) 的父边一定被限制。这就是
这意味着,我们只需要记录最浅的被限制的边的深度(一条边的深度是较浅的那个端点的深度),于是状态可以改为:
(f(u,d,k)):只考虑以 (u) 为根的子树内的边,且此时被限制的边的最小深度是 (d) ,选择了 (k) 条链的方案数。
想到一个老 trick,冷静分析一下,发现完全不需要记录 (k) ,直接在转移的时候乘 (-1) 就行了。于是状态就变成了两维。
(以下记 (dep_u) 为 (u) 的深度)
转移如下:
-
尝试合并一棵子树 (v) 。此时我们需要考虑 ((u,v)) 是否被限制。记 (g) 为合并了 ((u,v)) 之后 (v) 的贡献。
[g(d)=f(v,d) imes 2^{[dge dep_v]} ]然后合并到 (u) 上来就有:
[f(u,d)=sum_{min{i,j}=d}f(u,i) imes g(j) ]这个看似卷积又不太像卷积的东西( (min) 卷积),我们待会儿再来谈怎么优化它。
-
考虑加入一条链,链顶的深度为 (d) ,于是有转移:
[f(u,i)=f(u,i)-sum_{min{k,d}=i} f(u,k) ]冷静分析一下之后发现:
[f(u,k)=egin{cases}0&k<d\-sum_{i=d+1}^{n}f(u,i)&k=d\f(u,k)&k>dend{cases} ]
另外还有一个边界条件,就是最初的时候有 (f(u,n)=1) ,因为压根儿没有限制。
于是你就可以拿到 (O(n^2)) 的暴力 DP 分数。结合 40 pts 的某些做法就可以获得 64 pts 的高分。
100 pts
这里需要用到一个叫做 " 整体 DP " 的技巧。
说得这么玄乎,其实这就是使用数据结构维护 DP 数组,方便进行结构性操作和区间操作。
在这个问题中,由于转移过程中 DP 数组并不会变长或者变短,因此我们可以考虑使用线段树维护 DP 数组。
那么考虑不同的转移:
-
合并子树。此时求 (g) 就相当于直接做区间乘法。
考虑之后合并到 (u) 上来。我们首先记后缀和:
[egin{aligned}su(d)&=sum_{i=d}^nf(u,d)\sg(d)&=sum_{i=d}^ng(d)end{aligned} ]那么转移就是:
[f(u,d)=f(u,d) imes sg(d+1)+g(d) imes su(d+1)+f(u,d) imes g(d) ]如果这是正常的对应项求积,我们就可以直接线段树合并,然后在叶节点乘起来。
现在我们还是可以线段树合并,不过我们需要求出 (sg) 和 (su) ,并且在叶节点加上额外贡献。
当然每次直接查询是不科学的,我们应该利用 cdq 分治的技巧。我们在线段树合并的时候,如果当前区间是 ([l,r]) ,我们就需要维护 (su(r+1)) 和 (sg(r+1)) 。
我们可以首先递归左子树,并且将右子树的 (su) 和 (sg) 贡献给左子树。之后再贡献给右子树。
于是就可以在线段树合并的过程中处理掉这个 (min) 卷积。
-
新增一条链,可以发现加入一条链做了三件事情:
- 区间清空。
- 区间求和。
- 单点赋值。
这三个东西都可以用线段树快速处理。第一个可以直接区间乘法。
一点优化:如果当前点 (u) 是多条链的链尾,我们可以只处理链顶深度最大的那条链。
这是由于,如果有两条链,链顶深度分别为 (d) 和 (d') ,且 (d'<d) ,我们可以说明只有 (d) 会影响结果。
如果我们让 (d) 在 (d') 之后加入 DP ,那么 (d') 的影响就被清除了。
否则就是 (d') 在 (d) 之后加入 DP 。首先对于 (i<d') 和 (i>d') , (f(u,i)) 不会受到影响。考虑 (f(u,d')) ,简单算一下 (sum_{k=d'+1}^nf(u,k)) 可以发现它也是 0 。所以 (d') 对答案仍然没有影响。
于是这道题就做完了,时间是 (O(nlog_2n)) 。实际上代码也不怎么难写,基本上都是常规的算法。
我不是指思维的常规。不过仔细想想这道题的思路和技巧也并不是特别难。
顺便,它有亿点点卡常。
本题一些有价值的点:
- 动态链并的方法,格雷码的应用。
- 状态设计(可以理解为优化状态),基于题目中链的性质。
- 转移中有一点点技巧:考虑边的贡献的时候,先考虑子树内,然后在合并到父亲的时候,把父边的贡献加进来。
- 整体 DP 的优化方法。
- cdq 和线段树合并实现 (min-max) 卷积。
代码
#include <cstdio>
typedef long long LL;
const int mod = 998244353;
const int MAXN = 5e5 + 5, MAXM = 5e5 + 5, MAXLOG = 40;
const int MAXS = MAXN * MAXLOG;
inline void read( int &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
inline void write( int x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
inline int MAX( const int a, const int b )
{
return a > b ? a : b;
}
struct edge
{
int to, nxt;
}Graph[MAXN << 1];
int ttag[MAXS];
int su[MAXS], lch[MAXS], rch[MAXS];
int siz;
int head[MAXN], dep[MAXN], rt[MAXN], mx[MAXN];
int N, M, cnt;
inline int newNode() { int cur = ++ siz; ttag[cur] = 1; return cur; }
inline void upt( const int x ) { su[x] = ( su[lch[x]] + su[rch[x]] ) % mod; }
inline void sub( int &x, const int v ) { x = ( x < v ? x - v + mod : x - v ); }
inline int mul( LL x, const int v ) { x *= v; if( x >= mod ) x %= mod; return x; }
inline void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
inline void mult( int x, const int p ) { if( ! x ) return ; su[x] = mul( su[x], p ), ttag[x] = mul( ttag[x], p ); }
inline void normalize( const int x ) { if( ! x || ttag[x] == 1 ) return; mult( lch[x], ttag[x] ), mult( rch[x], ttag[x] ), ttag[x] = 1; }
inline void addEdge( const int from, const int to )
{
Graph[++ cnt].to = to, Graph[cnt].nxt = head[from];
head[from] = cnt;
}
void multiply( int &x, const int l, const int r, const int segL, const int segR, const int delt )
{
if( ! x ) return ;
if( segL <= l && r <= segR ) { mult( x, delt ); return ; }
int mid = l + r >> 1; normalize( x );
if( segL <= mid ) multiply( lch[x], l, mid, segL, segR, delt );
if( mid < segR ) multiply( rch[x], mid + 1, r, segL, segR, delt );
upt( x );
}
void setVal( int &x, const int l, const int r, const int p, const int v )
{
x = x ? x : newNode();
if( l == r ) { su[x] = v; return ; }
int mid = l + r >> 1; normalize( x );
if( p <= mid ) setVal( lch[x], l, mid, p, v );
else setVal( rch[x], mid + 1, r, p, v );
upt( x );
}
int query( const int x, const int l, const int r, const int segL, const int segR )
{
if( ! x ) return 0;
if( segL <= l && r <= segR ) return su[x];
int mid = l + r >> 1, ret = 0; normalize( x );
if( segL <= mid ) add( ret, query( lch[x], l, mid, segL, segR ) );
if( mid < segR ) add( ret, query( rch[x], mid + 1, r, segL, segR ) );
return ret;
}
int merg( const int x, const int y, const int l, const int r, const int sufX, const int sufY )
{
if( ! x ) { mult( y, sufX ); return y; }
if( ! y ) { mult( x, sufY ); return x; }
if( l == r )
{
su[x] = ( ( LL ) mul( su[x], sufY ) + mul( su[y], sufX ) + mul( su[x], su[y] ) ) % mod;
return x;
}
int mid = l + r >> 1; normalize( x ), normalize( y );
lch[x] = merg( lch[x], lch[y], l, mid, ( sufX + su[rch[x]] ) % mod, ( sufY + su[rch[y]] ) % mod );
rch[x] = merg( rch[x], rch[y], mid + 1, r, sufX, sufY );
upt( x );
return x;
}
void DFS( const int u, const int fa )
{
dep[u] = dep[fa] + 1;
setVal( rt[u], 1, N, N, 1 );
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ( v = Graph[i].to ) ^ fa )
{
DFS( v, u );
multiply( rt[v], 1, N, dep[v], N, 2 );
rt[u] = merg( rt[u], rt[v], 1, N, 0, 0 );
}
if( mx[u] )
{
multiply( rt[u], 1, N, 1, mx[u], 0 );
int tVal = query( rt[u], 1, N, mx[u] + 1, N );
setVal( rt[u], 1, N, mx[u], mod - tVal );
}
}
void init( const int u, const int fa )
{
dep[u] = dep[fa] + 1;
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ( v = Graph[i].to ) ^ fa ) init( v, u );
}
int main()
{
read( N ); for( int i = 1, a, b ; i < N ; i ++ ) read( a ), read( b ), addEdge( a, b ), addEdge( b, a );
init( 1, 0 ), read( M ); for( int i = 1, fr, to ; i <= M ; i ++ ) read( fr ), read( to ), mx[to] = MAX( mx[to], dep[fr] );
DFS( 1, 0 ); write( su[rt[1]] ), putchar( '
' );
return 0;
}