首先考虑$prufer$序列,那么问题转化为求
一个长为$n - 2$的序列,总共有$n$个元素,恰有$m$个元素不出现在序列中的方案数
考虑容斥,答案即为 至少$m$个元素不出现 - 至少$m + 1$个不出现 + 至少$m + 2$个不出现......
至少$m$个元素不出现的方案数为$C(n, m) * (n - i)^{n - 2}$
接着考虑容斥系数,通过数学归纳法,我们发现是$C(i, m)$
然后就没了,复杂度$O(n log n)$
注:$n = 1$或者$n = 2$时,树没有$prufer$序列,记得特判
#include <cstdio> #include <cstring> #include <iostream> using namespace std; #define ri register int #define sid 1005000 #define mod 1000000007 int n, m, ans; int inv[sid], fac[sid]; void Init_C() { fac[0] = inv[0] = fac[1] = inv[1] = 1; for(ri i = 2; i <= n; i ++) { inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod; fac[i] = 1ll * fac[i - 1] * i % mod; } for(ri i = 2; i <= n; i ++) inv[i] = 1ll * inv[i] * inv[i - 1] % mod; } int C(int n, int m) { if(n < m) return 0; return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod; } int fp(int a, int k) { int ret = 1; for( ; k; k >>= 1, a = 1ll * a * a % mod) if(k & 1) ret = 1ll * ret * a % mod; return ret; } int main() { cin >> n >> m; if(n == 1 || n == 2) { printf("1 "); return 0; } Init_C(); for(ri i = m, j = 1; i <= n; i ++, j *= -1) { ans += (1ll * j * C(i, m) * C(n, i) % mod * fp(n - i, n - 2) % mod); if(ans < 0) ans += mod; if(ans >= mod) ans -= mod; } printf("%d ", ans); return 0; }