ZJOI2019Day2的温暖题,然后考场上只会大常数的(O(nlog^3 n)),就懒得写拿了60pts走人
首先我们简化题意,容易发现每个点能到达的点形成了一个联通块,我们只需要统计出这个联通块的大小即可
再进一步,我们发现如果把每条经过(x)的路径((u,v))上的两个端点取出,并且维护它们之间的最小生成树,这棵生成树的大小就是最后的答案(可以画图或是感性理解)
接下来就考虑怎么维护每个点出去的生成树大小,首先我们强制选择(1)号点,然后用类似于建虚树的方法,每次加入一个新的点就通过LCA来计算距离,从而推出生成树的大小
所以大致思路也有了,我们需要一个能维护区间的支持插入删除的数据结构,那么很容易想到用线段树了
那么怎么维护经过一个点的所有路径呢,其实很套路,因为这里线段树上的基本信息就是一个点的存在与否,因此可以树上差分
具体的,对于一条路径((u,v)),我们在以(1)为根的树上将(u,v)两点打上标记,然后在(operatorname{LCA}(u,v),father_{operatorname{LCA}(u,v})上删除即可
离线之后就是套路的线段树合并了,然后中间转移有一个求(operatorname{LCA})的过程,用欧拉序+RMQ即可做到(O(nlog n))的复杂度
CODE
#include<cstdio>
#include<cctype>
#include<vector>
#define RI register int
#define CI const int
#define Tp template <typename T>
#define pb push_back
using namespace std;
const int N=200005;
struct edge
{
int to,nxt;
}e[N<<1]; vector <int> tag[N]; long long ans;
int head[N],n,m,cnt,x,y,rt[N],dfn[N],anc[N],dep[N],idx;
inline void addedge(CI x,CI y)
{
e[++cnt]=(edge){y,head[x]}; head[x]=cnt;
e[++cnt]=(edge){x,head[y]}; head[y]=cnt;
}
class FileInputOutput
{
private:
static const int S=1<<21;
#define tc() (A==B&&(B=(A=Fin)+fread(Fin,1,S,stdin),A==B)?EOF:*A++)
char Fin[S],*A,*B;
public:
Tp inline void read(T& x)
{
x=0; char ch; while (!isdigit(ch=tc()));
while (x=(x<<3)+(x<<1)+(ch&15),isdigit(ch=tc()));
}
#undef tc
}F;
#define to e[i].to
class Euler_Order_On_Tree
{
private:
static const int P=18;
int f[N<<1][P],log[N];
inline int mindep(CI x,CI y)
{
return dep[x]<dep[y]?x:y;
}
inline void swap(int& x,int& y)
{
int t=x; x=y; y=t;
}
public:
inline void DFS(CI now,CI fa=0)
{
anc[now]=fa; dep[now]=dep[fa]+1; f[++idx][0]=now; dfn[now]=idx;
for (RI i=head[now];i;i=e[i].nxt) if (to!=fa) DFS(to,now),f[++idx][0]=now;
}
inline void init(void)
{
RI i,j; for (log[0]=-1,i=1;i<=idx;++i) log[i]=log[i>>1]+1;
for (j=1;j<P;++j) for (i=1;i+(1<<j)-1<=idx;++i)
f[i][j]=mindep(f[i][j-1],f[i+(1<<j-1)][j-1]);
}
inline int getLCA(int x,int y)
{
if (!x||!y) return 0; x=dfn[x]; y=dfn[y]; if (x>y) swap(x,y);
int k=log[y-x+1]; return mindep(f[x][k],f[y-(1<<k)+1][k]);
}
}T;
class Segment_Tree
{
private:
static const int P=25;
struct segment
{
int ch[2],mi,mx,size;
}node[N*P]; int c[N*P],tot;
#define lc(x) node[x].ch[0]
#define rc(x) node[x].ch[1]
#define L(x) node[x].mi
#define R(x) node[x].mx
#define S(x) node[x].size
inline void pushup(CI now)
{
S(now)=S(lc(now))+S(rc(now))-dep[T.getLCA(R(lc(now)),L(rc(now)))];
L(now)=L(lc(now))?L(lc(now)):L(rc(now)); R(now)=R(rc(now))?R(rc(now)):R(lc(now));
}
public:
inline void modify(int& now,CI p,CI mv,CI l=1,CI r=idx)
{
if (!now) now=++tot; if (l==r)
return (void)(c[now]+=mv,S(now)=c[now]?dep[p]:0,L(now)=R(now)=c[now]?p:0);
int mid=l+r>>1; if (dfn[p]<=mid) modify(lc(now),p,mv,l,mid);
else modify(rc(now),p,mv,mid+1,r); pushup(now);
}
inline void merge(int& x,CI y,CI l=1,CI r=idx)
{
if (!x||!y) return (void)(x|=y); if (l==r)
return (void)(c[x]+=c[y],S(x)|=S(y),L(x)|=L(y),R(x)|=R(y)); int mid=l+r>>1;
merge(lc(x),lc(y),l,mid); merge(rc(x),rc(y),mid+1,r); pushup(x);
}
inline int query(CI now)
{
return S(now)-dep[T.getLCA(L(now),R(now))];
}
#undef lc
#undef rc
#undef L
#undef R
#undef S
}SEG;
inline void DFS(CI now,CI fa=0)
{
for (RI i=head[now];i;i=e[i].nxt) if (to!=fa) DFS(to,now);
for (int it:tag[now]) SEG.modify(rt[now],it,-1);
ans+=SEG.query(rt[now]); SEG.merge(rt[anc[now]],rt[now]);
}
#undef to
int main()
{
//freopen("CODE.in","r",stdin); freopen("CODE.out","w",stdout);
RI i; for (F.read(n),F.read(m),i=1;i<n;++i) F.read(x),F.read(y),addedge(x,y);
for (T.DFS(1),T.init(),i=1;i<=m;++i)
{
F.read(x); F.read(y); int fa=T.getLCA(x,y);
SEG.modify(rt[x],x,1); SEG.modify(rt[x],y,1);
SEG.modify(rt[y],x,1); SEG.modify(rt[y],y,1);
tag[fa].pb(x); tag[fa].pb(y); tag[anc[fa]].pb(x); tag[anc[fa]].pb(y);
}
return DFS(1),printf("%lld",ans>>1LL),0;
}