【UOJ#388】【UNR#3】配对树(线段树,dsu on tree)
题面
题解
考虑一个固定区间怎么计算答案,把这些点搞下来建树,然后(dp),不难发现一个点如果子树内能够匹配的话就一定会匹配完,所以(dp)可以做到线性。
那么根据上面的(dp)方式,一条边会被匹配到,当且仅当把这条边删掉之后,两个连通块内分别有奇数个目标点。那么如果我们考虑枚举每一条边,然后把子树内的点给标记一下,于是变成了在原序列上求有多少个偶数区间满足有偶数个点被标记,这个问题可以做一个前缀和,把奇偶位置拆开,于是答案就是两个位置中的奇数个数乘上偶数个数。这个过程可以拿线段树维护。
接下来就变成每次要标记子树中的所有点,每次暴力显然不合理。
改成(dsu)这个复杂度就很合理了。这个复杂度似乎是两个(log)的。
#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
#define MAX 100100
#define MOD 998244353
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
int n,m,ans;
struct Line{int v,next,w;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v,int w){e[cnt]=(Line){v,h[u],w};h[u]=cnt++;}
int wf[MAX],sz[MAX],hson[MAX];
void dfs1(int u,int ff)
{
sz[u]=1;
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(v==ff)continue;
wf[v]=e[i].w;dfs1(v,u);sz[u]+=sz[v];
if(sz[v]>sz[hson[u]])hson[u]=v;
}
}
#define lson (now<<1)
#define rson (now<<1|1)
struct Node{int s[2][2],r;}t[MAX<<2];
void pushup(int now)
{
t[now].s[0][0]=(t[lson].s[0][0]+t[rson].s[0][0])%MOD;
t[now].s[0][1]=(t[lson].s[0][1]+t[rson].s[0][1])%MOD;
t[now].s[1][0]=(t[lson].s[1][0]+t[rson].s[1][0])%MOD;
t[now].s[1][1]=(t[lson].s[1][1]+t[rson].s[1][1])%MOD;
}
void Build(int now,int l,int r)
{
if(l==r){t[now].s[l&1][0]+=1;return;}
int mid=(l+r)>>1;
Build(lson,l,mid);
Build(rson,mid+1,r);
pushup(now);
}
void putrev(int now)
{
swap(t[now].s[0][0],t[now].s[0][1]);
swap(t[now].s[1][0],t[now].s[1][1]);
t[now].r^=1;
}
void pushdown(int now)
{
if(!t[now].r)return;
putrev(lson);putrev(rson);
t[now].r^=1;
}
void Modify(int now,int l,int r,int L,int R)
{
if(L<=l&&r<=R){putrev(now);return;}
int mid=(l+r)>>1;pushdown(now);
if(L<=mid)Modify(lson,l,mid,L,R);
if(R>mid)Modify(rson,mid+1,r,L,R);
pushup(now);
}
vector<int> V[MAX];
bool vis[MAX];
void upd(int u,int ff)
{
for(int v:V[u])Modify(1,0,m,v,m);
for(int i=h[u];i;i=e[i].next)
if(e[i].v!=ff&&!vis[e[i].v])upd(e[i].v,u);
}
void dfs(int u,int ff,int tp)
{
for(int i=h[u];i;i=e[i].next)
if(e[i].v!=ff&&e[i].v!=hson[u])
dfs(e[i].v,u,0);
if(hson[u])dfs(hson[u],u,1),vis[hson[u]]=true;
upd(u,ff);
int cnt=(1ll*t[1].s[0][0]*t[1].s[0][1]+1ll*t[1].s[1][0]*t[1].s[1][1])%MOD;
ans=(ans+1ll*cnt*wf[u])%MOD;
vis[hson[u]]=false;
if(!tp)upd(u,ff);
}
int main()
{
n=read();m=read();
for(int i=1;i<n;++i)
{
int u=read(),v=read(),w=read();
Add(u,v,w);Add(v,u,w);
}
for(int i=1;i<=m;++i)V[read()].push_back(i);
Build(1,0,m);dfs1(1,0);dfs(1,0,0);
printf("%d
",ans);
return 0;
}