一、题目
有一棵大小为 (n) 的无根树,问有多少个连通块的点权之积小于等于 (m)
(nleq 2000,mleq 10^6)
二、解法
不难想到树上背包的做法,但是因为乘法并没有适于背包的性质所以直接 ( t T) 飞了(我还抱有幻想写过一发)
再深层地往下想其实是乘法不支持合并,那么我们就不合并,而把他转化成一个单点加入的问题。方法是先确定一个根,那么选取一个点就说明必须选它的父亲,所以要去选一个子树就必须选这个子树的根。
这本质是在 ( t dfn) 序序列上 (dp) 的过程,但是实际实现中却不需要真正求出这个序列。定义 (dp[u][i]) 为考虑点 (u) 已经加入的子树,乘积为 (i) 的选取方案数。那么访问到儿子的时候只需要用父亲更新它,回溯时再用它更新父亲即可。
但是这个状态的第二维太大了,注意到 (lfloorfrac{x}{nm} floor=lfloorfrac{lfloorfrac{x}{n} floor}{m} floor),所以我们可以把 (m) 整除 (i) 的值定义到状态里面。根据整除分块状态数变成了 (O(sqrt m)),根据结论值相同在以后的转移方法也相同所以正确性得到保证。
还有就是根需要枚举,但是本题做每个根的过程是相对独立的所以不能一边求出。那么我们尝试减少子树大小,使用点分治就可以把复杂度优化成 (O(nsqrt mlog n))
三、总结
不支持合并的树上问题可以考虑转成 ( t dfn) 序上 (dp)
本题使用了等价类的优化方法,逆向思维之后整除分块让状态数大幅减少(整除值相同的划分为等价类)
根独立的 (dp) 问题可以考虑用点分治优化,原理就是缩小子树中的需要考虑节点数。
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
const int M = 2005;
const int N = 1000005;
const int MOD = 1e9+7;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int T,n,m,k,ns,rt,a[M],f[N],w[M],siz[M],mx[M];
int ans,dp[M][M],vis[M];vector<int> g[M];
void add(int &x,int y) {x=(x+y)%MOD;}
void find(int u,int fa)
{
siz[u]=1;mx[u]=0;
for(auto v:g[u]) if(v^fa && !vis[v])
{
find(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],ns-siz[u]);
if(mx[u]<mx[rt]) rt=u;
}
void dfs(int u,int fa)
{
for(int i=1;i<=k;i++) dp[u][i]=0;
for(int i=1;i<=k;i++) if(w[i]>=a[u])//add the item
add(dp[u][f[w[i]/a[u]]],dp[fa][i]);
for(auto v:g[u]) if(!vis[v] && v^fa)
{
dfs(v,u);
for(int i=1;i<=k;i++)
//select the subtree/or just skip it
add(dp[u][i],dp[v][i]);
}
}
void solve(int u)
{
vis[u]=1;dp[0][k]=1;dfs(u,0);
for(int i=1;i<=k;i++) add(ans,dp[u][i]);
for(auto v:g[u]) if(!vis[v])
{
ns=siz[v];rt=0;
find(v,0);solve(rt);
}
dp[0][k]=0;//which is my mistake
}
signed main()
{
T=read();
while(T--)
{
for(int i=1;i<=k;i++) f[w[i]]=0,w[i]=0;
n=read();m=read();k=ans=0;
for(int i=m,ls=0;i>=1;i--)
{
int x=m/i;
w[f[x]=(x!=ls)?++k:k]=x;
ls=x;
}
for(int i=1;i<=n;i++)
a[i]=read(),g[i].clear(),vis[i]=0;
for(int i=1;i<n;i++)
{
int u=read(),v=read();
g[u].push_back(v);
g[v].push_back(u);
}
mx[0]=ns=n;rt=0;
find(1,0);solve(rt);
printf("%d
",ans);
}
}