Solution
对每个点进行贡献统计,搞了半天搞不出来。结果是对每条边进行统计,计算会被经过多少次。能想到这个应该是由于题目要求算路径长度,这和每条边被经过的次数挂钩。这个转换真是巧妙。
首先有一个显然的性质,集合点一定在若干关键点的中间位置。
考虑枚举每一条边,由于是一棵树,所以断掉这条边会把图分成两半。记其中一边的节点个数为 (s),那么另一边有 (n-s) 个。考虑在其中一边选 (i) 个关键点((iin[1,m-1]),由于一边至少要有一个关键点,不然根本不会经过该条边),有 (inom{s}{i}) 种情况,对于另一边有 (inom{n-s}{m-i}) 种情况。根据上面的性质,集合点的若干子树关键点个数应该尽量平均。那么对于这条枚举的边,集合点一定在关键点多的那一边,所以应该是关键点少的一边走向多的一边,会有 (min{i,m-i}) 的贡献。那么答案就是
最小值不好直接搞,所以考虑分类讨论。会发现折半讨论后,两种情况是相通的,所以现只讨论一种。注意 (m-1) 是奇数的时候会少一种,直接暴力加上即可。式子化简为
后面一坨直接求,可以不看。前面的 (i) 显然可以吸进去,再把吐出来的 (s) 提到式子前面,变成
上标是个常数,为了方便记为 (k)。现在的任务就是求出后面的式子了。设函数 (G(s)) 为后面的一坨,不能直接化简了,所以考虑其组合意义来递推。
其组合意义是:有一个长为 (n-1) 的序列,一共选 (m-1) 个数,而在前 (s-1) 个数中至多选 (k-1) 个数,剩下的数在后面选。考虑如何推到 (G(s+1)),其计数在前 (s) 个数中至多选 (k-1),只与 (G(s)) 不同如果在前 (s-1) 个数中选了 (k-1) 个数,位置 (s) 就不能选了,所以减去在前面 (s-1) 个数中强制选 (k-1) 个数,强制选位置 (s),在后面剩下的 (n-s) 个数中随便选 (m-k-1) 个数的方案数。那么
可以做到 (O(n)) 递推。最后处理一下边界,即 (G(1))。
注意到当 (i) 是大于 (1) 的正整数的时候,(inom{0}{i-1}) 是 (0),所以 (G(1)) 就只有一项了。
再注意,当 (k=0) 时,有 (G(1)=0)。另一种对应的情况是 (G(n-s))
#include<stdio.h>
#define N 1000007
#define ll long long
#define Mod 1000000007
#define re register
inline int read(){
int x=0,flag=1; char c=getchar();
while(c<'0'||c>'9'){if(c=='-') flag=0;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-48;c=getchar();}
return flag? x:-x;
}
struct E{
int next,to;
}e[N];
int head[N],cnt=0;
inline void add(int id,int to){
e[++cnt]=(E){head[id],to};
head[id]=cnt;
}
ll n,m,s[N];
ll fac[N],K,inv[N],G[N];
ll C(int x,int y){
if(y<0) return 0;
if(x<y) return 0;
if(!x||x==y) return 1;
return fac[x]*inv[y]%Mod*inv[x-y]%Mod;
}
ll qpow(ll x,ll y){
ll ret=1,cnt=0;
while(y>=(1LL<<cnt)){
if(y&(1LL<<cnt)) ret=(ret*x)%Mod;
x=(x*x)%Mod,cnt++;
}
return ret;
}
ll ans=0;
void dfs(int u){
s[u]=1;
for(int i=head[u];i;i=e[i].next){
int v=e[i].to,S;
dfs(v);
s[u]+=(S=s[v]);
ans=(ans+G[S]+G[n-S])%Mod;
if(!(m&1)) ans=(ans+C(S,m/2)*C(n-S,m/2)%Mod*(m/2)%Mod)%Mod;
}
}
int main(){
freopen("meeting.in","r",stdin);
freopen("meeting.out","w",stdout);
fac[0]=1;
for(re int i=1;i<N;i++)
fac[i]=1LL*fac[i-1]*i%Mod;
inv[N-1]=qpow(fac[N-1],Mod-2);
for(re int i=N-1;i>=1;i--)
inv[i-1]=inv[i]*i%Mod;
n=read(),m=read();
for(re int i=2;i<=n;i++) add(read(),i);
K=(m-1)>>1;
if(K) G[1]=C(n-1,m-1);
for(re int i=2;i<=n;i++)
G[i]=((G[i-1]-C(i-2,K-1)*C(n-i,m-K-1)%Mod)%Mod+Mod)%Mod;
for(re int i=2;i<=n;i++) G[i]=1LL*G[i]*i%Mod;
dfs(1);
printf("%lld",ans);
}