首先,对于从每个点出发的路径,答案一定是过这个点的路径所覆盖的点数。然后可以做树上差分,对每个点记录路径产生总贡献,然后做一个树剖维护,对每个点维护一个动态开点线段树。最后再从根节点开始做一遍dfs,把每个节点对应的线段树启发式合并即可。时空复杂度均为O(nlog2n)。听说还有一个log的做法,但感觉太神仙不会,不过2个log能过就不管了。
#include<bits/stdc++.h> #define lson l,mid,tr[rt].lc #define rson mid+1,r,tr[rt].rc using namespace std; const int N=1e5+7; struct Seg{int lc,rc,tag,sz;}tr[N*300]; int n,m,cnt,fa[N],sz[N],dep[N],son[N],top[N],dfn[N],rt[N]; long long ans; vector<int>G[N]; void dfs1(int u,int f) { sz[u]=1,dep[u]=dep[f]+1,fa[u]=f; for(int i=0;i<G[u].size();i++) if(G[u][i]!=f) { dfs1(G[u][i],u),sz[u]+=sz[G[u][i]]; if(sz[son[u]]<sz[G[u][i]])son[u]=G[u][i]; } } void dfs2(int u,int tp) { top[u]=tp,dfn[u]=++cnt; if(son[u])dfs2(son[u],tp); for(int i=0;i<G[u].size();i++)if(G[u][i]!=fa[u]&&G[u][i]!=son[u])dfs2(G[u][i],G[u][i]); } void pushup(int rt,int len) {if(tr[rt].tag)tr[rt].sz=len;else tr[rt].sz=tr[tr[rt].lc].sz+tr[tr[rt].rc].sz;} void modify(int rt,int v,int len){tr[rt].tag+=v;pushup(rt,len);} void update(int L,int R,int v,int l,int r,int&rt) { if(!rt)rt=++cnt; if(L<=l&&r<=R){modify(rt,v,r-l+1);return;} int mid=l+r>>1; if(L<=mid)update(L,R,v,lson); if(R>mid)update(L,R,v,rson); pushup(rt,r-l+1); } void Update(int&rt,int x,int y,int v) { while(top[x]!=top[y])update(dfn[top[x]],dfn[x],v,1,n,rt),x=fa[top[x]]; if(x!=y)update(dfn[y]+1,dfn[x],v,1,n,rt); } int lca(int x,int y) { while(top[x]!=top[y])if(dep[top[x]]<dep[top[y]])y=fa[top[y]];else x=fa[top[x]]; return dep[x]<dep[y]?x:y; } int merge(int u,int v,int l,int r) { if(!u||!v)return u+v; tr[u].tag+=tr[v].tag; if(l==r){pushup(u,1);return u;} int mid=l+r>>1; tr[u].lc=merge(tr[u].lc,tr[v].lc,l,mid); tr[u].rc=merge(tr[u].rc,tr[v].rc,mid+1,r); pushup(u,r-l+1); return u; } void dfs3(int u,int f) { for(int i=0;i<G[u].size();i++) if(G[u][i]!=f)dfs3(G[u][i],u),rt[u]=merge(rt[u],rt[G[u][i]],1,n); ans+=max(tr[rt[u]].sz-1,0); } int main() { scanf("%d%d",&n,&m); for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x); dfs1(1,0); dfs2(1,1); cnt=0; for(int i=1,x,y,f;i<=m;i++) { scanf("%d%d",&x,&y); f=lca(x,y); Update(rt[x],x,fa[f],1),Update(rt[x],y,f,1); Update(rt[y],x,fa[f],1),Update(rt[y],y,f,1); Update(rt[f],x,fa[f],-1),Update(rt[f],y,f,-1); if(fa[f])Update(rt[fa[f]],x,fa[f],-1),Update(rt[fa[f]],y,f,-1); } dfs3(1,0); cout<<ans/2; }