题意:给定带点权的树,问多少个连通块,其乘积<=M; N<=2000,M<1e6;
思路:连通块-->分治; 由于普通的树DP在合并的时候复杂度会高一个M,所以用依赖背包来做。 (当然,由于体积分布是离散的,可能有些选手用map也可以过,这样避免了每次都for(i,1,M),取决于数据吧)。
那么现在的复杂度就是O(NlogN*M) ,空间为O(N*M),尚待优化。
这里非常巧妙的把<sqrt(M)的和大于sqrt(M)的分开保存,那么前者就是正常的背包,表示背包里存了多少东西; 后者可以看成背包里最多还可以存多少东西。 那么复杂度就变成了O(Nsqrt(M)logN); 空间O(Nsqrt(M)); 就可以过了。
#include<bits/stdc++.h> #define rep(i,a,b) for(int i=a;i<=b;i++) #define ll long long using namespace std; const int maxn=2010; const int Mod=1e9+7; int w[maxn],ans; int dp1[maxn][maxn>>1],dp2[maxn][maxn>>1],dp[maxn][maxn]; int Laxt[maxn],Next[maxn<<1],To[maxn<<1],cnt; int son[maxn],sz[maxn],vis[maxn],rt,SZ; int times,p[maxn],M,qM; void MOD(int &X){if(X>Mod) X-=Mod;} void add(int u,int v) { Next[++cnt]=Laxt[u]; Laxt[u]=cnt; To[cnt]=v; } void getroot(int u,int f) //得到重心 { sz[u]=1; son[u]=0; for(int i=Laxt[u];i;i=Next[i]){ int v=To[i]; if(vis[v]||v==f) continue; getroot(v,u); sz[u]+=sz[v]; son[u]=max(son[u],sz[v]); } son[u]=max(son[u],SZ-son[u]); if(rt==0||son[u]<son[rt]) rt=u; } void dfs(int u,int f) //得到dfs序。 { p[++times]=u; sz[u]=1; for(int i=Laxt[u];i;i=Next[i]){ if(To[i]==f||vis[To[i]]) continue; dfs(To[i],u); sz[u]+=sz[To[i]]; } } void cal() { rep(i,1,times+1){ memset(dp1[i],0,sizeof(dp1[i])); memset(dp2[i],0,sizeof(dp2[i])); } dp1[times+1][1]=1; for(int i=times;i>=1;i--){ int x=w[p[i]]; rep(j,1,min(qM,M/x)) { int k=j*x; if(k<=qM) MOD(dp1[i][k]+=dp1[i+1][j]); else MOD(dp2[i][M/k]+=dp1[i+1][j]); } rep(j,x,qM) { MOD(dp2[i][j/x]+=dp2[i+1][j]); } rep(j,1,qM) MOD(dp1[i][j]+=dp1[i+sz[p[i]]][j]); rep(j,1,qM) MOD(dp2[i][j]+=dp2[i+sz[p[i]]][j]); } rep(i,1,qM) MOD(ans+=dp1[1][i]); rep(i,1,qM) MOD(ans+=dp2[1][i]); ans--; //减去为空的情况 if(ans<0) ans+=Mod; } void solve(int u) //分治 { vis[u]=1; times=0; dfs(u,0); cal(); for(int i=Laxt[u];i;i=Next[i]){ if(vis[To[i]]) continue; SZ=sz[To[i]]; rt=0; getroot(To[i],0); solve(rt); } } int main() { int T,N,u,v; scanf("%d",&T); while(T--){ scanf("%d%d",&N,&M); qM=sqrt(M); rep(i,1,N) scanf("%d",&w[i]); rep(i,1,N) Laxt[i]=vis[i]=0; cnt=0; rep(i,1,N-1){ scanf("%d%d",&u,&v); add(u,v); add(v,u); } SZ=N; rt=0; getroot(1,0); ans=0; solve(rt); printf("%d ",ans); } return 0; }