原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ33.html
题解
首先我们把问题转化成处理一个数组 ans ,其中 ans[i] 表示 d(u,a) 和 d(v,a) 同时为 i 的倍数的 (u,v) 个数。(最后求答案的时候只要莫比乌斯反演回来就好了。)
注意一下我的代码中对于 (u,v) 有祖先关系的是分开考虑的。
先点分治。
对于一个点分中心 x ,我们把答案分两部分考虑。
1. 在子树 x 中满足 LCA(u,v) = x 的 (u,v) 对于答案的贡献。
2. u,v 其中一个点在子树 x 中,另一个不在。
第一部分非常好求,不加赘述。
第二部分,我们考虑定义一个阀值 S ,我们预处理出 Smod[i][j] 表示 子树 x 中,到 x 的距离 mod i = j 的点的个数。这样,我们就可以 O(1) 得到 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 。这样,我们就可以在 $O(nS)$ 的复杂度内求出对于 $ans[i](ileq S)$ 的贡献。 对于 i>S 的,我们可以直接暴力计算 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 ,复杂度为 $O(n^2/S)$ 。取 $S = O(sqrt{n})$ 最优。故处理一个点分中心的复杂度为 $O(nsqrt{n})$ (假设当前连通块大小为 n)。
所以总的时间复杂度为 $O(nsqrt{n})$ 。
代码
#include <bits/stdc++.h> using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=200005,M=500; int n; vector <int> e[N]; LL ans[N],ans2[N]; int depth[N],fa[N]; void dfs(int x,int pre,int d){ fa[x]=pre,depth[x]=d; for (auto y : e[x]) if (y!=pre) dfs(y,x,d+1); } int vis[N],size[N],Size; int Maxsize[N],rt; void get_root(int x,int pre){ size[x]=1,Maxsize[x]=0; for (auto y : e[x]) if (y!=pre&&!vis[y]){ get_root(y,x); size[x]+=size[y]; Maxsize[x]=max(Maxsize[x],size[y]); } Maxsize[x]=max(Maxsize[x],Size-size[x]); if (!rt||Maxsize[rt]>Maxsize[x]) rt=x; } vector <int> d[N]; void get_size(int x,int pre){ size[x]=1; for (auto y : e[x]) if (y!=pre&&!vis[y]) get_size(y,x),size[x]+=size[y]; } void getd(int x,int pre,int d,vector <int> &v){ while (d>=(int)v.size()) v.push_back(0); v[d]++; for (auto y : e[x]) if (y!=pre&&!vis[y]) getd(y,x,d+1,v); } LL S[N]; LL Smod[M][M]; void solve(int x){ rt=0; get_root(x,0); assert(rt!=0); vis[x=rt]=1; for (int i=0;i<=Size;i++) S[i]=0; int Mx=0; for (auto y : e[x]) if (!vis[y]){ get_size(y,0); if (depth[y]<depth[x]) continue; d[y].clear(); getd(y,0,1,d[y]); int t=d[y].size()-1; for (int i=1;i<=t;i++){ for (int j=i<<1;j<=t;j+=i) d[y][i]+=d[y][j]; ans[i]+=(LL)d[y][i]*S[i]; S[i]+=d[y][i]; } Mx=max(Mx,t); d[y].clear(); } for (int i=Mx;i>=1;i--) for (int j=i<<1;j<=Mx;j+=i) S[i]-=S[j]; S[0]++; int base=(int)(0.4*sqrt(Mx)+0.5); for (int i=1;i<=base;i++){ for (int j=0;j<i;j++) Smod[i][j]=0; for (int j=0;j<=Mx;j++) Smod[i][j%i]+=S[j]; } for (int f=fa[x],pre=x;f&&!vis[f];pre=f,f=fa[f]){ d[f].clear(); for (auto y : e[f]) if (!vis[y]&&y!=pre&&y!=fa[f]) getd(y,f,1,d[f]); int t=d[f].size()-1; for (int i=1;i<=t;i++){ for (int j=i<<1;j<=t;j+=i) d[f][i]+=d[f][j]; int tmp=(i-(depth[x]-depth[f])%i)%i; if (i<=base) ans[i]+=(LL)d[f][i]*Smod[i][tmp]; else for (int j=tmp;j<=Mx;j+=i) ans[i]+=(LL)d[f][i]*S[j]; } d[f].clear(); } for (auto y : e[x]) if (!vis[y]) Size=size[y],solve(y); } int main(){ n=read(); for (int i=2;i<=n;i++){ int x=read(); e[i].push_back(x); e[x].push_back(i); } dfs(1,0,0); Size=n; solve(1); for (int i=n;i>=1;i--) for (int j=i<<1;j<=n;j+=i) ans[i]-=ans[j]; for (int i=1;i<=n;i++) ans2[depth[i]]++; for (int i=n;i>=1;i--) ans2[i]+=ans2[i+1]; for (int i=1;i<n;i++) printf("%lld ",ans[i]+ans2[i]); return 0; }