这道题卡常啊 !
出题人说 $O(n log^2 n)$ 可过,但我写了个 $O(n log^2 n)$ 的树剖卡了半天常数.
最暴力的做法:枚举区间,然后跑一个树形DP 来求最小匹配.
显然,因为要求匹配值最小,所以一定是能匹配就先匹配.
也就是说递归完 $x$ 的所有儿子后,$x$ 的每一个儿子最多只有 1 个点还没有匹配.
这个时间复杂度是 $O(n^3)$ 的.
然后我们对每一条边分别考虑:
令 $v[x]$ 表示点 $x$ 到其父亲的边权(以 1 为根),那么 $v[x]$ 能产生贡献,当且仅当一个区间中 $x$ 子树中有奇数个点.
这个很好理解,因为如果有奇数个点,就意味着 1 个点没有被匹配到,而需要向上延伸的 $x$ 的父亲,依此类推......
那么就枚举右端点,然后令 $f[x][0/1]$ 分别表示多少个长度为偶数的区间满足在 $x$ 的子树中有偶数/奇数个点.
由于要求区间长度是偶数,我们可以分别以 $1,2$ 为起点各跑一次,每次同时加入两个点来保证长度为偶数.
考虑加入 $x,y$ 后的影响:
$x$ 到 $lca$ 与 $y$ 到 $lca$ (不包括 lca 这个点)的路径上 $f[x][0]=f[x][1]$,$f[x][1]=f[x][0]+1$
不在 $x,y$ 路径上的点 $f[x][1]$ 不变,$f[x][0] leftarrow f[x][0]+1$.
这个暴力修改的话是 $O(n^2)$ 的,可以获得 $50$pts.
满分算法的话就是用树链剖分+线段树来维护上面的东西.
我们无外乎就是要支持:每个节点维护 $f[x][0],f[x][1]$,区间加,区间交换.
然后定义标记 $(rev,x,y)$ 表示是否要交换 $f[x][0],f[x][1]$ 的值,交换后对 $f[x][0]$,$f[x][1]$ 分别加上 $x,y$.
时间复杂度为 $O(n log^2 n)$,但是会有点卡常.
这里说几个卡常技巧:
1. 读入优化
2. 开 long long 要比取模快.
3. 由于上述操作中每次加的数是 1 或 -1,所以这个标记可以直接开 int,然后区间和开 long long.
code:
#include <cstdio> #include <ctime> #include <cstring> #include <algorithm> #define N 100008 #define ll long long #define mod 998244353 #define lson now<<1 #define rson now<<1|1 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int edges,n,m,tim; int nd[N],f[N][2],fa[N]; int hd[N],to[N<<1],nex[N<<1],val[N<<1]; int dep[N],a[N],size[N],top[N],son[N],dfn[N],bu[N]; ll ans; struct data { int rev; int vx,vy; ll sx,sy,sum; data(int rev=0,int vx=0,int vy=0):rev(rev),vx(vx),vy(vy){} }s[N<<2]; inline void add(int u,int v,int c) { nex[++edges]=hd[u]; hd[u]=edges,to[edges]=v,val[edges]=c; } void dfs(int x,int ff) { size[x]=1; fa[x]=ff,dep[x]=dep[ff]+1; for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(y==ff) continue; nd[y]=val[i],dfs(y,x); size[x]+=size[y]; if(size[y]>size[son[x]]) son[x]=y; } } void dfs2(int x,int tp) { top[x]=tp; dfn[x]=++tim; bu[tim]=x; if(son[x]) dfs2(son[x],tp); for(int i=hd[x];i;i=nex[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) dfs2(to[i],to[i]); } inline int get_lca(int x,int y) { while(top[x]!=top[y]) { dep[top[x]]>dep[top[y]]?x=fa[top[x]]:y=fa[top[y]]; } return dep[x]<dep[y]?x:y; } inline void pushup(int now) { s[now].sx=(ll)(s[lson].sx+s[rson].sx); s[now].sy=(ll)(s[lson].sy+s[rson].sy); } inline void mark_rev(int now) { swap(s[now].sx,s[now].sy); swap(s[now].vx,s[now].vy); s[now].rev^=1; } inline void mark_add(int now,int vx,int vy) { if(vx) (s[now].sx+=(ll)vx*s[now].sum); if(vy) (s[now].sy+=(ll)vy*s[now].sum); if(vx) (s[now].vx+=vx); if(vy) (s[now].vy+=vy); } inline void pushdown(int now) { if(s[now].rev) { s[now].rev=0; mark_rev(lson); mark_rev(rson); } if(s[now].vx||s[now].vy) { mark_add(lson,s[now].vx,s[now].vy); mark_add(rson,s[now].vx,s[now].vy); s[now].vx=s[now].vy=0; } } void build(int l,int r,int now) { s[now]=data(); s[now].sx=0; s[now].sy=0; if(l==r) { s[now].sum=nd[bu[l]]; return; } int mid=(l+r)>>1; build(l,mid,lson),build(mid+1,r,rson); s[now].sum=(ll)(s[lson].sum+s[rson].sum)%mod; } void REV(int l,int r,int now,int L,int R) { if(l>=L&&r<=R) { mark_rev(now); return; } pushdown(now); int mid=(l+r)>>1; if(L<=mid) REV(l,mid,lson,L,R); if(R>mid) REV(mid+1,r,rson,L,R); pushup(now); } void ADD(int l,int r,int now,int L,int R,int vx,int vy) { if(l>=L&&r<=R) { mark_add(now,vx,vy); return; } pushdown(now); int mid=(l+r)>>1; if(L<=mid) ADD(l,mid,lson,L,R,vx,vy); if(R>mid) ADD(mid+1,r,rson,L,R,vx,vy); pushup(now); } inline void upd(int x,int y) { while(top[y]!=top[x]) { ADD(1,n,1,dfn[top[y]],dfn[y],-1,0); REV(1,n,1,dfn[top[y]],dfn[y]); ADD(1,n,1,dfn[top[y]],dfn[y],0,1); y=fa[top[y]]; } if(y!=x) { ADD(1,n,1,dfn[x]+1,dfn[y],-1,0); REV(1,n,1,dfn[x]+1,dfn[y]); ADD(1,n,1,dfn[x]+1,dfn[y],0,1); } } void sol(int st) { int x,y,lca; build(1,n,1); for(int i=st;i<=m;i+=2) { if(i+1>m) break; x=a[i],y=a[i+1]; if(dep[x]>dep[y]) swap(x,y); lca=get_lca(x,y); mark_add(1,1,0); upd(lca,x); upd(lca,y); (ans+=s[1].sy)%=mod; } } char *p1,*p2,buf[100000]; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int rd() { int x=0; char c; while(c<48) c=nc(); while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x; } int main() { // setIO("input"); n=rd(),m=rd(); int x,y,z; for(int i=1;i<n;++i) { x=rd(),y=rd(),z=rd(); if(z>=mod) z-=mod; add(x,y,z),add(y,x,z); } dfs(1,0); dfs2(1,1); for(int i=1;i<=m;++i) a[i]=rd(); sol(1),sol(2); printf("%lld ",ans); return 0; }