I.IV.[NOI2020]命运
半年前水了份 \(n^2\) 暴力,没想到过了出题人用脚造的数据。这里是正解。
考虑DP。因为若两条路径呈包含关系,则更长的那条显然没用,于是设 \(f_{i,j}\) 表示所有下端在 \(i\) 子树内且未被满足的路径中,上端最深的那条的深度。明显,要且仅要满足这条路径的限制,子树中所有现行的路径都可以被满足。
考虑转移。假设我们现在想要合并某父亲 \(x\) 与儿子 \(y\)。
- 边 \((x,y)\) 选择。
则所有 \(x\) 中满足 \(i\leq dep_x\) 的 \(f_{x,i}\),都可以从 \(y\) 中所有满足 \(j\leq dep_y\) 的 \(f_{y,j}\) 转移而来。
于是有 \(f_{x,i}\leftarrow f_{x,i}\times\sum\limits_{j=0}^{dep_y}f_{y,j}\)。
- 边 \((x,y)\) 不选。
则最深的那条可以来自于 \(x\),亦可来自于 \(y\)。
于是有 \(f_{x,\max\{i,j\}}\leftarrow f_{x,i}\times f_{y,j}\)。
变换形式,得到
\(f_{x,i}\leftarrow\Big(f_{x,i}\sum\limits_{j=0}^{i-1}f_{y,j}\Big)+\Big(f_{y,i}\sum\limits_{j=0}^{i-1}f_{x,j}\Big)+f_{x,i}f_{y,i}\)
最后,两坨东西怼一块,得到最终转移式
\(f_{x,i}\leftarrow f_{x,i}\Big(\sum\limits_{j=0}^{i-1}f_{y,j}+\sum\limits_{j=0}^{dep_y}f_{y,j}\Big)+\Big(f_{y,i}\sum\limits_{j=0}^{i-1}f_{x,j}\Big)+f_{x,i}f_{y,i}\)
考虑用线段树合并维护。那三大坨前缀和,上界是 \(dep_y\) 的那坨对于不同的 \(i\) 是相同的,可以直接在 \(y\) 的线段树内维护;剩下的两坨,在合并的时候顺便维护一下即可。
时间复杂度 \(O(n\log n)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int n,m,dep[500100],DP,cnt,rt[500100];
vector<int>v[500100],u[500100];
void dfs1(int x,int fa){
dep[x]=dep[fa]+1,DP=max(DP,dep[x]);
for(auto y:v[x])if(y!=fa)dfs1(y,x);
}
#define mid ((l+r)>>1)
struct SegTree{
int lson,rson,sum,tag;
}seg[20010000];
void MUL(int x,int y){seg[x].sum=1ll*seg[x].sum*y%mod,seg[x].tag=1ll*seg[x].tag*y%mod;}
void pushdown(int x){MUL(seg[x].lson,seg[x].tag),MUL(seg[x].rson,seg[x].tag),seg[x].tag=1;}
void pushup(int x){seg[x].sum=(seg[seg[x].lson].sum+seg[seg[x].rson].sum)%mod;}
void setbound(int &x,int l,int r,int P,int sum){
if(l>P)return;
if(r<P){x=0;return;}
if(!x)x=++cnt,seg[x].tag=1;
if(l==r){(seg[x].sum+=sum)%=mod;return;}
pushdown(x),setbound(seg[x].rson,mid+1,r,P,(sum+seg[seg[x].lson].sum)%mod),setbound(seg[x].lson,l,mid,P,sum),pushup(x);
}
void setempty(int &x,int l,int r){
if(!x)x=++cnt,seg[x].tag=1;seg[x].sum++;
if(l!=r)setempty(seg[x].lson,l,mid);
}
int query(int x,int l,int r,int P){
if(l>P||!x)return 0;
if(r<=P)return seg[x].sum;
pushdown(x);return (query(seg[x].lson,l,mid,P)+query(seg[x].rson,mid+1,r,P))%mod;
}
void merge(int &x,int y,int l,int r,int sumx,int sumy){
if(!x){MUL(y,sumx),x=y;return;}
if(!y){MUL(x,sumy);return;}
if(l==r){seg[x].sum=(1ll*seg[x].sum*sumy%mod+1ll*seg[y].sum*sumx%mod+1ll*seg[x].sum*seg[y].sum%mod)%mod;return;}
pushdown(x),pushdown(y);
merge(seg[x].rson,seg[y].rson,mid+1,r,(sumx+seg[seg[x].lson].sum)%mod,(sumy+seg[seg[y].lson].sum)%mod);
merge(seg[x].lson,seg[y].lson,l,mid,sumx,sumy);
pushup(x);
}
void iterate(int x,int l,int r){
if(!x)return;
printf("%d:[%d,%d]:%d\n",x,l,r,seg[x].sum);
iterate(seg[x].lson,l,mid),iterate(seg[x].rson,mid+1,r);
}
void dfs2(int x,int fa){
setempty(rt[x],0,DP);
for(auto y:v[x])if(y!=fa)dfs2(y,x),merge(rt[x],rt[y],0,DP,0,query(rt[y],0,DP,dep[x]));
for(auto i:u[x])setbound(rt[x],0,DP,i,0);
// printf("%d:\n",x);
// iterate(rt[x],0,DP);
}
int main(){
scanf("%d",&n);
for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),v[x].push_back(y),v[y].push_back(x);
dfs1(1,0);
scanf("%d",&m);
for(int i=1,x,y;i<=m;i++)scanf("%d%d",&x,&y),u[y].push_back(dep[x]);
dfs2(1,0);
printf("%d\n",query(rt[1],0,DP,0));
return 0;
}