description
给你一棵(n)个节点的树以及一个长为(m)的序列,序列每个位置上的值(in[1,n]),你需要求出把序列中所有长度为偶数的区间内所有数拿出来在树上以最小代价匹配的代价之和模(998244353)。
sol
首先拿出偶数个点在树上匹配这个问题,根据贪心,我们一定会让这些点在尽可能深的位置匹配。换句话说,每棵子树中未匹配的点至多只有一个。
那么我们考虑每一条边,这一条边会被计算贡献当且仅当这条边连接的子树里有奇数个选出的点。
那么我们相当于是要求对于每一条边,有多少个长度为偶数的区间使得这条边连接的子树里有奇数个区间内的点。
考虑把这棵子树里的点在原序列中标记为(1),其他点标记为(0),然后对这个序列做一遍前缀和。
那么我们要求的东西实际上就是,满足(i equiv j pmod 2,s_j-s_iequiv1pmod 2)的((i,j))对数,这相当于是区间([i+1,j])的和为奇数。注意这里的(i,j)可以取(0)。
然后我们就只需要对每一棵子树维护出这个东西的数量就好啦。
暴力的话就是暴(for)子树里的每一个点,可以修改原序列的单点,也就是给前缀和的一段后缀(+1),用线段树维护每个区间有多少个下标是奇数/偶数的点的前缀和是奇数,修改就相当于区间翻转,复杂度(O(n^2log m))。
然后就(dsu on tree)一波,复杂度变成了(O(nlog nlog m))。
其实直接线段树合并就(O(nlog m))了,但是懒得写了qaq。
code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi(){
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 1e5+5;
const int mod = 998244353;
int n,m,to[N<<1],nxt[N<<1],ww[N<<1],head[N],cnt,pre[N],fst[N];
int dis[N],sz[N],son[N],rev[N<<2],ans;
struct data{
int od,en;
data operator + (const data &b) const{
return (data){od+b.od,en+b.en};
}
}t[N<<2];
void link(int u,int v,int w){
to[++cnt]=v;nxt[cnt]=head[u];ww[cnt]=w;head[u]=cnt;
}
void reverse(int x,int l,int r){
t[x].od=(r+1)/2-l/2-t[x].od;
t[x].en=r/2-(l-1)/2-t[x].en;
rev[x]^=1;
}
void pushdown(int x,int l,int r){
if (!rev[x]) return;int mid=l+r>>1;
reverse(x<<1,l,mid);reverse(x<<1|1,mid+1,r);
rev[x]=0;
}
void modify(int x,int l,int r,int ql,int qr){
if (l>=ql&&r<=qr) {reverse(x,l,r);return;}
pushdown(x,l,r);int mid=l+r>>1;
if (ql<=mid) modify(x<<1,l,mid,ql,qr);
if (qr>mid) modify(x<<1|1,mid+1,r,ql,qr);
t[x]=t[x<<1]+t[x<<1|1];
}
void dfs_pre(int u,int f){
sz[u]=1;
for (int e=head[u];e;e=nxt[e])
if (to[e]!=f){
dis[to[e]]=ww[e],dfs_pre(to[e],u),sz[u]+=sz[to[e]];
if (sz[to[e]]>sz[son[u]]) son[u]=to[e];
}
}
void add(int u){
for (int i=fst[u];i;i=pre[i]) modify(1,1,m,i,m);
}
void cal(int u){
int tim=(1ll*t[1].od*((m+1)/2-t[1].od)+1ll*t[1].en*(m/2+1-t[1].en))%mod;
ans=(1ll*dis[u]*tim+ans)%mod;
}
void upt(int u,int f){
add(u);
for (int e=head[u];e;e=nxt[e])
if (to[e]!=f) upt(to[e],u);
}
void dfs(int u,int f,bool keep){
for (int e=head[u];e;e=nxt[e])
if (to[e]!=f&&to[e]!=son[u])
dfs(to[e],u,0);
if (son[u]) dfs(son[u],u,1);
for (int e=head[u];e;e=nxt[e])
if (to[e]!=f&&to[e]!=son[u]) upt(to[e],u);
add(u);cal(u);
if (!keep) upt(u,f);
}
int main(){
n=gi();m=gi();
for (int i=1;i<n;++i){
int u=gi(),v=gi(),w=gi();
link(u,v,w);link(v,u,w);
}
for (int i=1;i<=m;++i){
int v=gi();pre[i]=fst[v];fst[v]=i;
}
dfs_pre(1,0);dfs(1,0,1);printf("%d
",ans);return 0;
}