原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ33.html
题解
首先我们把问题转化成处理一个数组 ans ,其中 ans[i] 表示 d(u,a) 和 d(v,a) 同时为 i 的倍数的 (u,v) 个数。(最后求答案的时候只要莫比乌斯反演回来就好了。)
注意一下我的代码中对于 (u,v) 有祖先关系的是分开考虑的。
先点分治。
对于一个点分中心 x ,我们把答案分两部分考虑。
1. 在子树 x 中满足 LCA(u,v) = x 的 (u,v) 对于答案的贡献。
2. u,v 其中一个点在子树 x 中,另一个不在。
第一部分非常好求,不加赘述。
第二部分,我们考虑定义一个阀值 S ,我们预处理出 Smod[i][j] 表示 子树 x 中,到 x 的距离 mod i = j 的点的个数。这样,我们就可以 O(1) 得到 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 。这样,我们就可以在 $O(nS)$ 的复杂度内求出对于 $ans[i](ileq S)$ 的贡献。 对于 i>S 的,我们可以直接暴力计算 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 ,复杂度为 $O(n^2/S)$ 。取 $S = O(sqrt{n})$ 最优。故处理一个点分中心的复杂度为 $O(nsqrt{n})$ (假设当前连通块大小为 n)。
所以总的时间复杂度为 $O(nsqrt{n})$ 。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
LL read(){
LL x=0,f=0;
char ch=getchar();
while (!isdigit(ch))
f|=ch=='-',ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
}
const int N=200005,M=500;
int n;
vector <int> e[N];
LL ans[N],ans2[N];
int depth[N],fa[N];
void dfs(int x,int pre,int d){
fa[x]=pre,depth[x]=d;
for (auto y : e[x])
if (y!=pre)
dfs(y,x,d+1);
}
int vis[N],size[N],Size;
int Maxsize[N],rt;
void get_root(int x,int pre){
size[x]=1,Maxsize[x]=0;
for (auto y : e[x])
if (y!=pre&&!vis[y]){
get_root(y,x);
size[x]+=size[y];
Maxsize[x]=max(Maxsize[x],size[y]);
}
Maxsize[x]=max(Maxsize[x],Size-size[x]);
if (!rt||Maxsize[rt]>Maxsize[x])
rt=x;
}
vector <int> d[N];
void get_size(int x,int pre){
size[x]=1;
for (auto y : e[x])
if (y!=pre&&!vis[y])
get_size(y,x),size[x]+=size[y];
}
void getd(int x,int pre,int d,vector <int> &v){
while (d>=(int)v.size())
v.push_back(0);
v[d]++;
for (auto y : e[x])
if (y!=pre&&!vis[y])
getd(y,x,d+1,v);
}
LL S[N];
LL Smod[M][M];
void solve(int x){
rt=0;
get_root(x,0);
assert(rt!=0);
vis[x=rt]=1;
for (int i=0;i<=Size;i++)
S[i]=0;
int Mx=0;
for (auto y : e[x])
if (!vis[y]){
get_size(y,0);
if (depth[y]<depth[x])
continue;
d[y].clear();
getd(y,0,1,d[y]);
int t=d[y].size()-1;
for (int i=1;i<=t;i++){
for (int j=i<<1;j<=t;j+=i)
d[y][i]+=d[y][j];
ans[i]+=(LL)d[y][i]*S[i];
S[i]+=d[y][i];
}
Mx=max(Mx,t);
d[y].clear();
}
for (int i=Mx;i>=1;i--)
for (int j=i<<1;j<=Mx;j+=i)
S[i]-=S[j];
S[0]++;
int base=(int)(0.4*sqrt(Mx)+0.5);
for (int i=1;i<=base;i++){
for (int j=0;j<i;j++)
Smod[i][j]=0;
for (int j=0;j<=Mx;j++)
Smod[i][j%i]+=S[j];
}
for (int f=fa[x],pre=x;f&&!vis[f];pre=f,f=fa[f]){
d[f].clear();
for (auto y : e[f])
if (!vis[y]&&y!=pre&&y!=fa[f])
getd(y,f,1,d[f]);
int t=d[f].size()-1;
for (int i=1;i<=t;i++){
for (int j=i<<1;j<=t;j+=i)
d[f][i]+=d[f][j];
int tmp=(i-(depth[x]-depth[f])%i)%i;
if (i<=base)
ans[i]+=(LL)d[f][i]*Smod[i][tmp];
else
for (int j=tmp;j<=Mx;j+=i)
ans[i]+=(LL)d[f][i]*S[j];
}
d[f].clear();
}
for (auto y : e[x])
if (!vis[y])
Size=size[y],solve(y);
}
int main(){
n=read();
for (int i=2;i<=n;i++){
int x=read();
e[i].push_back(x);
e[x].push_back(i);
}
dfs(1,0,0);
Size=n;
solve(1);
for (int i=n;i>=1;i--)
for (int j=i<<1;j<=n;j+=i)
ans[i]-=ans[j];
for (int i=1;i<=n;i++)
ans2[depth[i]]++;
for (int i=n;i>=1;i--)
ans2[i]+=ans2[i+1];
for (int i=1;i<n;i++)
printf("%lld
",ans[i]+ans2[i]);
return 0;
}