抄题解.jpg
发现原来的(O(n^2))的换根(dp)好像行不通了呀
我们考虑非常牛逼的长链剖分
我们设(f[x][j])表示在(x)的子树中距离(x)为(j)的点有多少个
(g[x][j])表示在(x)的子树里,满足如下条件的点对((u,v))的个数
-
设(k=LCA(u,v)),满足(dis(u,k)=dis(v,k)=d)
-
满足(dis(k,x)=d-j)
我们发现可以如果(v)是(x)的儿子,那么距离(v)为(j-1)的点和(x)的距离就是(j),那么到(k)的距离就是(d-j+j=d),和点对到(k)的距离相等
于是我们可以这样合并
[ans+=f[v][j-1] imes g[x][j]
]
自然还有
[ans+=f[x][j-1] imes g[v][j]
]
(f)数组的更新非常简单啊,就是(f[x][j]+=f[v][j-1]),这个我们可以用长链剖分优化到(O(n))
之后是(g)的更新
首先我们有(g[x][j]+=g[v][j+1]),就是到(x)距离为(d-j)的(k)到(v)的距离必然是(d-j-1),这里我们也可以直接长链剖分
之后(g[x][j+1]+=f[x][j+1] imes f[v][j]),这样产生的点对的(LCA)就是(x),到(x)的距离也就是(j+1),符合条件,这里直接暴力转移就好了
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
const int maxn=100006;
inline int read() {
char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt;}e[maxn<<1];
int head[maxn],len[maxn],n,num,son[maxn],deep[maxn];
LL tax[maxn*6],*id=tax,*f[maxn],*g[maxn],ans;
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void dfs1(int x) {
for(re int i=head[x];i;i=e[i].nxt) {
if(deep[e[i].v]) continue;
deep[e[i].v]=deep[x]+1;
dfs1(e[i].v);
if(len[e[i].v]>len[son[x]]) son[x]=e[i].v;
}
len[x]=len[son[x]]+1;
}
void dfs(int x) {
f[x][0]=1;
if(son[x]) {
g[son[x]]=g[x]-1;
f[son[x]]=f[x]+1;
dfs(son[x]);
}
ans+=g[x][0];
for(re int i=head[x];i;i=e[i].nxt) {
if(deep[e[i].v]<deep[x]||son[x]==e[i].v) continue;
f[e[i].v]=id;id+=len[e[i].v]+1;
g[e[i].v]=id+len[e[i].v]+1;id+=2*len[e[i].v]+2;
dfs(e[i].v);
for(re int j=len[e[i].v];j>=0;--j) {
if(j) ans+=f[x][j-1]*g[e[i].v][j];
ans+=g[x][j+1]*f[e[i].v][j];
g[x][j+1]+=f[e[i].v][j]*f[x][j+1];
}
for(re int j=0;j<=len[e[i].v];j++) {
if(j) g[x][j-1]+=g[e[i].v][j];
f[x][j+1]+=f[e[i].v][j];
}
}
}
int main() {
n=read();
for(re int x,y,i=1;i<n;i++)
x=read(),y=read(),add(x,y),add(y,x);
deep[1]=1;dfs1(1);
f[1]=id;id+=len[1]+1;
g[1]=id+len[1]+1;//由于我们继承重儿子是g[son[x]]=g[x]-1,所以得在这个指针前面留一些空位置来让后面的状态继承
id+=2*len[1]+2;
dfs(1);printf("%lld
",ans);
return 0;
}