一、题目
二、解法
其实这题挺难的,而且我觉得网上的题解讲的有点不清楚 (...)
看到题目要求的是 (f(x)^k) 并且 (kleq 200),搞一个傻逼斯特林反演即可:
问题变成了对于所有虚树,求出恰好选 (i) 条边的方案数。这个问题并不好做,我们严格遵循 性质->状态->dp
的思路:
首先分析性质:虚树是由点集生成的,所以我们计数时要考虑点的状态(以点为主体);再此基础上我们统计选出 (i) 条边的方案,这两步计数是分开的,但是我们可以放在一个 (dp) 中处理。
然后分析状态:我们想维护选出 (i) 条边的方案,而树上 (dp) 基本上是从子树合并这种情况入手的,我们考虑两棵虚树如何合并。我们已知的是 (u) 为根的原始树选出 (j) 条边的方案,还已知 (v) 为根的子树选出 (k) 条边的方案。相当于我们把某序列分成两部分,已知各部分的组合方案要求总数的组合方案,那么由枚举法可知把 (j+k=i) 的情况相乘然后累加上去即可,这样我们完成了虚树的合并。
最后设计 (dp):设 (f[u][i]) 表示以 (u) 为根的所有虚树中选出 (i) 条边的方案数,虚树合并转移:
还有一个部分是考虑点的状态,在合并完之后我们考虑算出 以 (u) 为根的原始虚树(+)以 (u) 为根的 (v) 的虚树(+)以 (u) 为根的合并虚树即可。第一三种情况好算,第二种情况讨论一下父边的是否选即可。
最后证明一下时间复杂度,合并转移时的边界是 (min(siz[u],k) imes min(siz[v],k)),相当于从 (u) 子树中选出 ( t dfs) 序后 (k) 个和 (v) 子树中 ( t dfs) 序的前 (k) 个来算。考虑单个点的贡献,发现它只会和 ( t dfs) 序相邻的左右 (2k) 个点转移并产生 (1) 的贡献,所以时间复杂度 (O(nk))
三、总结
这道题说明组合数也是可合并的,不要被吓到了,按照思路步步分析即可。
对于第二维是背包合并的树 (dp),设第二维大小是 (k),可以知道复杂度大小是 (O(nk)) 的。
#include <cstdio>
#include <iostream>
using namespace std;
const int N = 205;
const int M = 100005;
const int MOD = 1e9+7;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,k,tot,f[M],ans[M],siz[M],dp[M][N],g[N],S[N][N];
struct edge
{
int v,next;
}e[2*M];
void add(int &x,int y) {x=(x+y)%MOD;}
void dfs(int u,int fa)
{
dp[u][0]=1;siz[u]=1;
for(int x=f[u];x;x=e[x].next)
{
int v=e[x].v;
if(v==fa) continue;
dfs(v,u);
for(int i=0;i<=min(siz[u],k);i++)
for(int j=0;j<=min(siz[v],k);j++)
{
if(i+j>k) break;
add(g[i+j],dp[u][i]*dp[v][j]);
add(ans[i+j],dp[u][i]*dp[v][j]);
add(g[i+j+1],dp[u][i]*dp[v][j]);
add(ans[i+j+1],dp[u][i]*dp[v][j]);
}
siz[u]+=siz[v];
for(int i=0;i<=k;i++)
{
add(dp[u][i],g[i]),g[i]=0;
add(dp[u][i],dp[v][i]);
if(i) add(dp[u][i],dp[v][i-1]);
}
}
}
signed main()
{
n=read();k=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
e[++tot]=edge{v,f[u]},f[u]=tot;
e[++tot]=edge{u,f[v]},f[v]=tot;
}
dfs(1,0);
S[0][0]=1;int res=0,fac=1;
for(int i=1;i<=k;i++)
for(int j=1;j<=i;j++)
S[i][j]=(S[i-1][j-1]+S[i-1][j]*j)%MOD;
for(int i=1;i<=k;i++,fac=fac*i%MOD)
add(res,S[k][i]*fac%MOD*ans[i]);
printf("%lld
",res);
}