题目大意:
给你一颗$n(nle5000)$个点的树,选3个点使得它们两两距离相等,问共有几种选法。
思路:
首先我们不难发现一个性质:对于每3个符合条件的点,我们总能找到一个点使得这个点到那3个点距离相等。
我们不妨称之为“中转点”。
显然答案就是对于每个中转点,不同子树中到这个点距离相等的三元点对的数量。
我们可以先枚举每个点作为中转点的情况。
暴力求出以这个点的每个子结点为根的子树,不同深度的结点的数量(显然深度就是到这个中转点的距离)。
我们可以用calc[i][j]表示对于当前中转点,来自j个不同子树的深度为i的结点共有多少种不同的组合。
转移方程为calc[i][j]+=calc[i][j-1]*cnt[i]。
1 #include<cstdio> 2 #include<cctype> 3 #include<vector> 4 #include<cstring> 5 typedef long long int64; 6 inline int getint() { 7 register char ch; 8 while(!isdigit(ch=getchar())); 9 register int x=ch^'0'; 10 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 11 return x; 12 } 13 const int N=5001; 14 std::vector<int> e[N]; 15 inline void add_edge(const int &u,const int &v) { 16 e[u].push_back(v); 17 e[v].push_back(u); 18 } 19 int n,cnt[N]; 20 int64 calc[N][4]; 21 void dfs(const int &x,const int &par,const int &dep) { 22 cnt[dep]++; 23 for(unsigned i=0;i<e[x].size();i++) { 24 const int &y=e[x][i]; 25 if(y==par) continue; 26 dfs(y,x,dep+1); 27 } 28 } 29 int main() { 30 n=getint(); 31 for(register int i=1;i<n;i++) { 32 add_edge(getint(),getint()); 33 } 34 int64 ans=0; 35 for(register int x=1;x<=n;x++) { 36 memset(calc,0,sizeof calc); 37 for(register int i=1;i<=n;i++) calc[i][0]=1; 38 for(register unsigned i=0;i<e[x].size();i++) { 39 memset(cnt,0,sizeof cnt); 40 const int &y=e[x][i]; 41 dfs(y,x,1); 42 for(register int j=3;j;j--) { 43 for(register int i=0;i<=n;i++) { 44 calc[i][j]+=calc[i][j-1]*cnt[i]; 45 } 46 } 47 } 48 for(register int i=1;i<=n;i++) { 49 ans+=calc[i][3]; 50 } 51 } 52 printf("%lld ",ans); 53 return 0; 54 } 现在考虑当$nle10^5$的情况。
考虑$nle10^5$的情况。
$f[i][j]$标示以$i$为根的子树中,与$i$距离为$j$的点数。$g[i][j]$标示以$i$为根的子树中,与$i$距离为$j$的点对数。则不难想到一种$O(n^2)$的转移:
$$
egin{align*}
&g[x][i-1]+=g[y][i]\
&g[x][i+1]+=f[x][i+1] imes f[y][i]\
&f[x][i+1]+=f[y][i]
end{align*}
$$
边界为$f[x][0]=1$。
考虑优化这个转移,不难发现,若$y$是$x$枚举到的第一个子结点,则转移时只进行第一、第三个转移。因此我们可以考虑通过指针来实现,免去转移的过程。
将原树进行长链剖分,对于重边直接修改指针,对于轻边暴力转移,可以证明这样是$O(n)$的。
1 #include<list> 2 #include<cstdio> 3 #include<cctype> 4 typedef long long int64; 5 inline int getint() { 6 register char ch; 7 while(!isdigit(ch=getchar())); 8 register int x=ch^'0'; 9 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 10 return x; 11 } 12 const int N=100001; 13 std::list<int> e[N]; 14 int dep[N],bot[N]; 15 int64 mem[N*6],ans,*f[N],*g[N],*ptr=mem; 16 inline void add_edge(const int &u,const int &v) { 17 e[u].push_back(v); 18 e[v].push_back(u); 19 } 20 void dfs(const int &x,const int &par) { 21 dep[bot[x]=x]=dep[par]+1; 22 for(std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) { 23 const int &y=*i; 24 if(y==par) continue; 25 dfs(y,x); 26 if(dep[bot[y]]>dep[bot[x]]) bot[x]=bot[y]; 27 } 28 for(register std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) { 29 const int &y=*i; 30 if(y==par||(bot[y]==bot[x]&&x!=1)) continue; 31 f[bot[y]]=ptr+=dep[bot[y]]-dep[x]+1; 32 g[bot[y]]=++ptr; 33 ptr+=(dep[bot[y]]-dep[x])*2+1; 34 } 35 } 36 void dp(const int &x,const int &par) { 37 for(std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) { 38 const int &y=*i; 39 if(y==par) continue; 40 dp(y,x); 41 if(bot[y]==bot[x]) { 42 f[x]=f[y]-1; 43 g[x]=g[y]+1; 44 } 45 } 46 ans+=g[x][0]; 47 f[x][0]=1; 48 for(register std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) { 49 const int &y=*i; 50 if(y==par||bot[y]==bot[x]) continue; 51 for(register int i=0;i<=dep[bot[y]]-dep[x];i++) { 52 ans+=f[x][i-1]*g[y][i]+g[x][i+1]*f[y][i]; 53 } 54 for(register int i=0;i<=dep[bot[y]]-dep[x];i++) { 55 g[x][i-1]+=g[y][i]; 56 g[x][i+1]+=f[x][i+1]*f[y][i]; 57 f[x][i+1]+=f[y][i]; 58 } 59 } 60 } 61 int main() { 62 const int n=getint(); 63 for(register int i=1;i<n;i++) { 64 add_edge(getint(),getint()); 65 } 66 dfs(1,0); 67 dp(1,0); 68 printf("%lld ",ans); 69 return 0; 70 }