好久之前写的题。。今天不知道为什么拿出来看一眼
感谢qiqi20021026的指导
原题求多少对链的交中包含的边数 (geq k)。
设(g(k))表示每个长度为(k)的链被多少对链覆盖,则答案为(g(K+1)-g(K))
考虑加入每条链,对于直上直下不经过LCA的部分,树上差分维护那些点向上的长度为(K)的链会被覆盖,然后可以求出链的对数。
经过LCA的部分,考虑对树轻重链剖分。
对于每对存在距离为(K)的点对的两条重链,用数据结构维护一下哪些长度为(K)的链会被覆盖。
这里用了一个map套map,维护两个重链之间,在第一条链上每个深度的点的被覆盖次数。
感觉根本说不明白。。毕竟自己是抄的别人代码
#include<bits/stdc++.h>
using namespace std;
#define fp(i,l,r) for(register int (i)=(l);(i)<=(r);++(i))
#define fd(i,l,r) for(register int (i)=(l);(i)>=(r);--(i))
#define fe(i,u) for(register int (i)=front[(u)];(i);(i)=e[(i)].next)
#define mem(a) memset((a),0,sizeof (a))
#define O(x) cerr<<#x<<':'<<x<<endl
#define int long long
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return x*f;
}
void wr(int x){
if(x<0)putchar('-'),x=-x;
if(x>=10)wr(x/10);
putchar('0'+x%10);
}
const int MAXN=151000;
struct tEdge{
int v,next;
}e[MAXN<<1];
int n,m,K,front[MAXN],tcnt,dep[MAXN],top[MAXN],hson[MAXN],sze[MAXN],fa[MAXN],anc[18][MAXN];
inline void adde(int u,int v){
e[++tcnt]=(tEdge){v,front[u]};front[u]=tcnt;
}
void dfs1(int u,int f){
sze[u]=1;fa[u]=anc[0][u]=f;dep[u]=dep[f]+1;
fp(i,1,17)anc[i][u]=anc[i-1][anc[i-1][u]];
fe(i,u){
int v=e[i].v;if(v==f)continue;
dfs1(v,u);sze[u]+=sze[v];
if(sze[v]>sze[hson[u]])hson[u]=v;
}
}
void dfs2(int u,int Tp){
top[u]=Tp;
if(hson[u])dfs2(hson[u],Tp);
fe(i,u){
int v=e[i].v;if(v==fa[u]||v==hson[u])continue;
dfs2(v,v);
}
}
int sum[MAXN],qs[MAXN],qv[MAXN],ans,res;
unordered_map<int,map<int,int> >mp[MAXN];
inline int go(int u,int d){
fd(i,17,0)
if(dep[anc[i][u]]>=d)u=anc[i][u];
return u;
}
inline int glca(int a,int b){
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]])swap(a,b);
a=fa[top[a]];
}
return dep[a]<dep[b]?a:b;
}
inline int add1(int x,int z){
if(dep[z]+K>dep[x])return x;
int t=go(x,dep[z]+K-1);
++sum[x];--sum[t];return t;
}
inline void add2(int u1,int u2,int v1,int v2,int z){
if(u2>v2)swap(u1,v1),swap(u2,v2);//O(u1);O(u2);O(v1);O(v2);O(z);
int L=max(dep[u2],K+2*dep[z]-dep[v1]),R=min(dep[u1],K+2*dep[z]-dep[v2]);
mp[u2][v2][L]++;mp[u2][v2][R+1]--;
}
int b1[50],b2[50],t1,t2;
inline void add(int x,int y){
if(dep[x]<dep[y])swap(x,y);int z=glca(x,y);
if(y==z){add1(x,z);return;}
x=add1(x,z);y=add1(y,z);int u=x,v=y;
t1=0;t2=0;
for(;top[x]!=top[z];x=fa[top[x]])b1[++t1]=x;
for(;top[y]!=top[z];y=fa[top[y]])b2[++t2]=y;
if(x!=z)b1[++t1]=x;if(y!=z)b2[++t2]=y;int j=t2;
fp(i,1,t1)
{
int u1=b1[i],u2=top[u1];
for(j=min(t2,j+1);j;j--){
int v1=b2[j],v2=top[v1];
if(dep[u1]+dep[v1]-2*dep[z]<K)continue;
if(dep[u2]+dep[v2]-2*dep[z]>K)break;
add2(u1,u2,v1,v2,z);
}
}
}
inline int C(int x){return x*(x-1)/2;}
void dp(int u,int f){
fe(i,u){
int v=e[i].v;if(v==f)continue;
dp(v,u);sum[u]+=sum[v];
}
res+=C(sum[u]);
}
inline void solve(int Coe){
res=0;fp(i,1,n)sum[i]=0,mp[i].clear();
fp(i,1,m)add(qs[i],qv[i]);
dp(1,0);
fp(i,1,n)
for(auto x:mp[i]){
int tot=0,lst=0;
for(auto y:x.second){
res+=C(tot)*(y.first-lst);lst=y.first;tot+=y.second;
}
}
ans+=res*Coe;
}
main(){
n=read();m=read();K=read();
fp(i,1,n-1){
int u=read(),v=read();
adde(u,v);adde(v,u);
}
dfs1(1,0);dfs2(1,1);
fp(i,1,m)qs[i]=read(),qv[i]=read();
solve(1);++K;solve(-1);
printf("%lld
",ans);
return 0;
}