题目
\(n\leq 3\times 10^5,K\leq 10\)。
思路
考虑用总方案数减去空间简单度不超过 \(k\) 的方案数。
发现 \(k\) 很小,可以枚举所有点 \(i\),那么对于一个 \(|i-j|\leq k\) 的点 \(j\),发现这个点对贡献了路径 \(i\to j\) “两端”点的数量之积。
但是直接计算容易重复,发现每次将是 \(dfs\) 序不超过 3 个区间的点的乘积,那么求出每个点字数点的 \(dfs\) 序区间,然后扔到二维平面上,转换成求矩形面积并的问题。
扫描线+线段树即可。
时间复杂度 \(O(nk\log n)\)。
代码
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=300010,LG=20;
int n,m,tot,cnt1,cnt2,head[N],dfn[N],size[N],f[N][LG+1],dep[N],L[N],R[N];
ll ans;
struct edge
{
int next,to;
}e[N*2];
struct node
{
int x,l,r;
}line1[N*40],line2[N*40];
bool operator <(node x,node y)
{
return x.x<y.x;
}
void add(int from,int to)
{
e[++tot].to=to;
e[tot].next=head[from];
head[from]=tot;
}
void dfs(int x,int fa)
{
dfn[x]=++tot; size[x]=1;
f[x][0]=fa; dep[x]=dep[fa]+1;
for (int i=1;i<=LG;i++)
f[x][i]=f[f[x][i-1]][i-1];
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs(v,x);
size[x]+=size[v];
}
}
L[x]=dfn[x]; R[x]=dfn[x]+size[x]-1;
}
int lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (int i=LG;i>=0;i--)
if (dep[f[x][i]]>=dep[y]) x=f[x][i];
if (x==y) return x;
for (int i=LG;i>=0;i--)
if (f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
return f[x][0];
}
int findson(int x,int y)
{
for (int i=LG;i>=0;i--)
if (dep[f[y][i]]>dep[x]) y=f[y][i];
return y;
}
void insert(int l1,int l2,int r1,int r2)
{
line1[++cnt1]=(node){min(l1,l2),min(r1,r2),max(r1,r2)};
line2[++cnt2]=(node){max(l1,l2)+1,min(r1,r2),max(r1,r2)};
}
bool cmp(node x,node y)
{
return x.x<y.x;
}
struct SegTree
{
int l[N*4],r[N*4],sum[N*4],cnt[N*4];
void build(int x,int ql,int qr)
{
l[x]=ql; r[x]=qr;
if (ql==qr) return;
int mid=(ql+qr)>>1;
build(x*2,ql,mid); build(x*2+1,mid+1,qr);
}
void update(int x,int ql,int qr,int val)
{
if (l[x]==ql && r[x]==qr)
{
sum[x]+=val;
if (sum[x]>0) cnt[x]=r[x]-l[x]+1;
else if (ql==qr) cnt[x]=0;
else cnt[x]=cnt[x*2]+cnt[x*2+1];
return;
}
int mid=(l[x]+r[x])>>1;
if (qr<=mid) update(x*2,ql,qr,val);
else if (ql>mid) update(x*2+1,ql,qr,val);
else update(x*2,ql,mid,val),update(x*2+1,mid+1,qr,val);
if (sum[x]) cnt[x]=r[x]-l[x]+1;
else cnt[x]=cnt[x*2]+cnt[x*2+1];
}
}seg;
int main()
{
int size = 256 << 20; //250M
char*p=(char*)malloc(size) + size;
__asm__("movl %0, %%esp\n" :: "r"(p) );
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
tot=0;
dfs(1,0);
for (int i=1;i<=n;i++)
for (int j=i+1;j<=min(n,i+m);j++)
{
bool flag=0;
if (dfn[i]>dfn[j]) swap(i,j),flag=1;
int p=lca(i,j);
if (p==i)
{
int soni=findson(i,j);
if (L[soni]>1) insert(1,L[soni]-1,L[j],R[j]);
if (R[soni]<n) insert(L[j],R[j],R[soni]+1,n);
}
else insert(L[i],R[i],L[j],R[j]);
if (flag) swap(i,j);
}
seg.build(1,1,n);
sort(line1+1,line1+1+cnt1);
sort(line2+1,line2+1+cnt2);
for (int i=1,j=1,k=1;i<=n;i++)
{
for (;line1[j].x==i && j<=cnt1;j++)
seg.update(1,line1[j].l,line1[j].r,1);
for (;line2[k].x==i && k<=cnt2;k++)
seg.update(1,line2[k].l,line2[k].r,-1);
ans+=seg.cnt[1];
}
printf("%lld\n",1LL*n*(n-1)/2LL-ans+n);
return 0;
}