I.II.[ZJOI2019]语言
一开始看错题,以为同一种语言会被普及多次,然后就成了神题不会做。一看题解,发现自己看错题了,原来是垃圾题。
一个点所能到达的点,只有与它在同一条路径上出现过的点,换句话说就是经过它全部路径的并。
全部路径的并很好搞,就是全部路径端点建出虚树的大小。虚树大小也很好搞,就是将所有点按dfs序排序后,两两相邻点(包括一头一尾)间距离之和的二分之一。
于是就要对每个点维护所有经过它的路径的端点。对于一条路径,其相当于是对经过的所有点的端点集合内多加了两个点,也即静态路径加,树上差分随手搞。
但是树上差分就意味着一个节点的端点集合是其所有儿子的端点集合之并。于是随手线段树合并维护一下即可。
需要注意的是,这样搞会使得一个点对被计算两遍,答案应该除以 \(2\)。
时间复杂度 \(O(n\log n)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll res;
int n,m,dfn[100100],pa[100100],rev[100100],fir[100100],tot,lim,st[200100][20],LG[200100],dep[100100],rt[100100],cnt;
vector<int>v[100100];
void dfs(int x,int fa){
pa[x]=fa,dfn[x]=++tot,rev[tot]=x,st[++lim][0]=x,fir[x]=lim,dep[x]=dep[fa]+1;
for(auto y:v[x])if(y!=fa)dfs(y,x),st[++lim][0]=x;
}
int MIN(int x,int y){return dep[x]<dep[y]?x:y;}
int LCA(int x,int y){
x=fir[x],y=fir[y];
if(x>y)swap(x,y);
int k=LG[y-x+1];
return MIN(st[x][k],st[y-(1<<k)+1][k]);
}
int DIS(int x,int y){return dep[x]+dep[y]-2*dep[LCA(x,y)];}
#define mid ((l+r)>>1)
struct SegTree{
int lson,rson,lval,rval,sum,tms;
}seg[3201000];
void pushup(int &x){
if(!seg[x].lson&&!seg[x].rson){x=0;return;}
if(seg[x].lson&&!seg[x].rson){seg[x].lval=seg[seg[x].lson].lval,seg[x].rval=seg[seg[x].lson].rval,seg[x].sum=seg[seg[x].lson].sum;return;}
if(!seg[x].lson&&seg[x].rson){seg[x].lval=seg[seg[x].rson].lval,seg[x].rval=seg[seg[x].rson].rval,seg[x].sum=seg[seg[x].rson].sum;return;}
seg[x].lval=seg[seg[x].lson].lval,seg[x].rval=seg[seg[x].rson].rval;
seg[x].sum=seg[seg[x].lson].sum+seg[seg[x].rson].sum+DIS(seg[seg[x].lson].rval,seg[seg[x].rson].lval);
}
void turnon(int &x,int l,int r,int P){
if(l>P||r<P)return;
if(!x)x=++cnt;
if(l==r)seg[x].lval=seg[x].rval=rev[P],seg[x].tms++;
else turnon(seg[x].lson,l,mid,P),turnon(seg[x].rson,mid+1,r,P),pushup(x);
}
void turnoff(int &x,int l,int r,int P){
if(l>P||r<P||!x)return;
if(l==r){seg[x].tms--;if(!seg[x].tms)x=0;}
else turnoff(seg[x].lson,l,mid,P),turnoff(seg[x].rson,mid+1,r,P),pushup(x);
}
void merge(int &x,int y,int l,int r){
if(!x){x=y;return;}
if(!y)return;
if(l==r){seg[x].tms+=seg[y].tms;return;}
merge(seg[x].lson,seg[y].lson,l,mid),merge(seg[x].rson,seg[y].rson,mid+1,r),pushup(x);
}
vector<int>in[100100],out[100100];
void dfs2(int x,int fa){
for(auto y:v[x])if(y!=fa)dfs2(y,x),merge(rt[x],rt[y],1,n);
for(auto i:in[x])turnon(rt[x],1,n,i);
for(auto i:out[x])turnoff(rt[x],1,n,i);
res+=seg[rt[x]].sum+DIS(seg[rt[x]].lval,seg[rt[x]].rval);
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),v[x].push_back(y),v[y].push_back(x);
dfs(1,0);
for(int i=2;i<=lim;i++)LG[i]=LG[i>>1]+1;
for(int j=1;j<=LG[lim];j++)for(int i=1;i+(1<<j)-1<=lim;i++)st[i][j]=MIN(st[i][j-1],st[i+(1<<(j-1))][j-1]);
for(int i=1,x,y;i<=m;i++){
scanf("%d%d",&x,&y);
in[x].push_back(dfn[x]),in[x].push_back(dfn[y]);
in[y].push_back(dfn[x]),in[y].push_back(dfn[y]);
int lca=LCA(x,y);
// printf("%d %d %d\n",x,y,lca);
out[lca].push_back(dfn[x]),out[lca].push_back(dfn[y]);
out[pa[lca]].push_back(dfn[x]),out[pa[lca]].push_back(dfn[y]);
}
dfs2(1,0),printf("%lld\n",res>>2);
return 0;
}