原文链接www.cnblogs.com/zhouzhendong/p/UOJ470.html
前言
做完情报中心来看这个题突然发现两题有相似之处然后就会做了。
题解
首先,我们考虑将所有答案点对分为两类。
- 一个节点对其祖先的贡献。
- 来自一个节点的不同子树之间节点的贡献。
第一种情况非常简单,这里不加赘述。
对于第二种情况,我们首先考虑简单做法:
考虑对于每一个节点分开处理。
按照某一种顺序枚举它的子树,对于所有“一端在当前子树内,另一端在当前子树之前的子树”的路径,我们求它们的贡献。
接下来提到的“虚树“中默认加入当前节点。
考虑对当前子树内路径端点建立虚树,然后在虚树上 dfs。对于虚树上的一个节点,它在另外一个子树中有相同语言的节点就是它在虚树上的子树中的所有端点的另一端点构成的虚树大小。
一个节点的子树中所有端点对应的点构成的虚树可以由儿子节点的虚树合并而来。
如果事先将虚树内的节点存在 set 中,则可以在关于点数较少的虚树的复杂度内合并两棵虚树,具体地说是 size * log(n) 。
考虑使用 DSU on tree,我们可以得到一个 (O(nlog ^ 3n)) 的做法。
注意到,在很多问题里,线段树合并都可以处理树上启发式合并的问题,而且复杂度都会下降。这里也类似,考虑合并两个 dfs序 分别独立的虚树时,只需要特殊考虑 dfs序 小的虚树的 dfs序最大节点和 dfs序 大的虚树的dfs序最小节点到根的路径交即可。
于是,我们考虑采用线段树合并维护子树虚树 size,由于线段树合并中需要求 LCA,所以我们考虑用 ST表 来求 LCA,做到单次询问 (O(1)),即可得到一个总时间复杂度 (O((n+m)log n)) 的做法。
代码
#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof x)
#define For(i,a,b) for (int i=(a);i<=(b);i++)
#define Fod(i,b,a) for (int i=(b);i>=(a);i--)
#define fi first
#define se second
#define pb(x) push_back(x)
#define mp(x,y) make_pair(x,y)
#define outval(x) cerr<<#x" = "<<x<<endl
#define outtag(x) cerr<<"---------------"#x"---------------"<<endl
#define outarr(a,L,R) cerr<<#a"["<<L<<".."<<R<<"] = ";
For(_x,L,R)cerr<<a[_x]<<" ";cerr<<endl;
using namespace std;
typedef long long LL;
LL read(){
LL x=0,f=0;
char ch=getchar();
while (!isdigit(ch))
f|=ch=='-',ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
}
const int N=100005*2;
int n,m;
vector <int> e[N];
struct cha{
int x,y,lca;
int xf,yf;
}a[N];
int depth[N],fa[N][20];
int ett[N],c=0,I[N];
void dfs(int x,int pre,int d){
depth[x]=d,fa[x][0]=pre;
For(i,1,19)
fa[x][i]=fa[fa[x][i-1]][i-1];
ett[I[x]=++c]=x;
for (int y : e[x])
if (y!=pre)
dfs(y,x,d+1),ett[++c]=x;
}
int st[N][20],Log[N];
int min_dep(int x,int y){
return depth[x]<depth[y]?x:y;
}
void Get_ST(){
For(i,2,c)
Log[i]=Log[i>>1]+1;
For(i,1,c){
st[i][0]=ett[i];
For(j,1,19){
st[i][j]=st[i][j-1];
if (i-(1<<(j-1))>0)
st[i][j]=min_dep(st[i][j],st[i-(1<<(j-1))][j-1]);
}
}
}
int LCA(int x,int y){
x=I[x],y=I[y];
if (x>y)
swap(x,y);
int d=Log[y-x+1];
return min_dep(st[x+(1<<d)-1][d],st[y][d]);
}
int Dis(int x,int y){
return depth[x]+depth[y]-2*depth[LCA(x,y)];
}
namespace Seg{
const int S=N*20*2;
int sz[S],lp[S],rp[S],ls[S],rs[S];
int cnt=0;
void pushup(int rt){
if (!sz[ls[rt]]&&!sz[rs[rt]])
sz[rt]=lp[rt]=rp[rt]=0;
else if (!sz[rs[rt]])
sz[rt]=sz[ls[rt]],lp[rt]=lp[ls[rt]],rp[rt]=rp[ls[rt]];
else if (!sz[ls[rt]])
sz[rt]=sz[rs[rt]],lp[rt]=lp[rs[rt]],rp[rt]=rp[rs[rt]];
else {
sz[rt]=sz[ls[rt]]+sz[rs[rt]]-depth[LCA(rp[ls[rt]],lp[rs[rt]])];
lp[rt]=lp[ls[rt]],rp[rt]=rp[rs[rt]];
}
}
void Ins(int &rt,int L,int R,int x){
if (!rt)
rt=++cnt,sz[rt]=ls[rt]=rs[rt]=lp[rt]=rp[rt]=0;
if (L==R){
lp[rt]=rp[rt]=x,sz[rt]=depth[x];
return;
}
int mid=(L+R)>>1;
if (I[x]<=mid)
Ins(ls[rt],L,mid,x);
else
Ins(rs[rt],mid+1,R,x);
pushup(rt);
}
int Merge(int x,int y,int L,int R){
if (!x||!y)
return x|y;
if (L==R)
return x;
int mid=(L+R)>>1,rt=++cnt;
ls[rt]=Merge(ls[x],ls[y],L,mid);
rs[rt]=Merge(rs[x],rs[y],mid+1,R);
pushup(rt);
return rt;
}
}
int go_son(int x,int f){
Fod(i,19,0)
if (depth[x]-(1<<i)>depth[f])
x=fa[x][i];
return x;
}
LL ans=0;
vector <int> qid[N];
int up[N];
bool cmp_qid(int x,int y){
return I[a[x].xf]<I[a[y].xf];
}
bool cmpI(int x,int y){
return I[x]<I[y];
}
int rt[N];
void Solve(int x,int *id,int n){
static int t[N],st[N];
int tc=0,top=0;
For(i,0,n-1)
t[++tc]=a[id[i]].x;
t[++tc]=x;
sort(t+1,t+tc+1,cmpI);
tc=unique(t+1,t+tc+1)-t-1;
For(i,1,tc)
rt[t[i]]=0;
For(i,0,n-1)
Seg::Ins(rt[a[id[i]].x],1,c,a[id[i]].y);
For(_,1,tc){
int i=t[_];
if (top){
int lca=LCA(i,st[top]);
while (depth[st[top]]>depth[lca]){
int now=st[top];
if (depth[st[top-1]]>=depth[lca]){
ans+=(LL)(depth[now]-depth[st[top-1]])*(Seg::sz[rt[now]]-depth[x]);
rt[st[top-1]]=Seg::Merge(rt[st[top-1]],rt[now],1,c);
top--;
}
else {
ans+=(LL)(depth[now]-depth[lca])*(Seg::sz[rt[now]]-depth[x]);
rt[lca]=rt[now];
st[top]=lca;
break;
}
}
}
st[++top]=i;
}
while (top>1){
int now=st[top];
ans+=(LL)(depth[now]-depth[st[top-1]])*(Seg::sz[rt[now]]-depth[x]);
rt[st[top-1]]=Seg::Merge(rt[st[top-1]],rt[now],1,c);
top--;
}
}
void Solve(int x,int pre){
for (int y : e[x])
if (y!=pre)
Solve(y,x),up[x]=max(up[x],up[y]-1);
ans+=up[x];
sort(qid[x].begin(),qid[x].end(),cmp_qid);
int s=(int)qid[x].size();
for (int i=0,j;i<s;i=j+1){
for (j=i;j+1<s&&I[a[qid[x][i]].xf]==I[a[qid[x][j+1]].xf];j++);
Solve(x,&qid[x][i],j-i+1);
}
}
int main(){
n=read(),m=read();
For(i,1,n-1){
int x=read(),y=read();
e[x].pb(y),e[y].pb(x);
}
dfs(1,0,0);
Get_ST();
For(i,1,m){
int x=a[i].x=read(),y=a[i].y=read(),lca=a[i].lca=LCA(x,y);
up[x]=max(up[x],depth[x]-depth[lca]);
up[y]=max(up[y],depth[y]-depth[lca]);
if (x!=lca&&y!=lca){
a[i].xf=go_son(x,lca);
a[i].yf=go_son(y,lca);
if (I[a[i].xf]<I[a[i].yf])
swap(a[i].xf,a[i].yf),swap(a[i].x,a[i].y);
qid[lca].pb(i);
}
}
Solve(1,0);
cout<<ans<<endl;
return 0;
}