今天才发现自己根本不会树形背包,我太菜了。
一般的树形背包是这样做的:
看上去,它的复杂度是 $O(nk^2)$ 的。
第一种优化:
这里,如果第二维的大小和子树大小有关,同时又不超过一个常数 $k$ 。例如:第二维表示子树内选了多少个点,那么通过一些精妙的分析和上界优化,复杂度就可以变成 $O(nk)$ 了。
以下的 $siz_x$ 表示合并 $son$ 这个子树前 $x$ 子树的大小(注意:不是 $x$ 的真实子树大小,这里很重要)。
这样分析出来的复杂度就是 $O(nk)$ .
证明:摘自这里;
首先,定义 $T(n)$ 为处理 $n$ 这棵子树时所用的时间,$f(n)$ 为处理 $n$ 这个点时所用的时间。
$T(x)=left(sum_{f_y=x} T_{y} ight)+f(x)\f(x)=min(m,siz(y_1)) imes min(m,siz(y_1))+min(m,siz(y_1)+siz(y_2)) imes min(m,siz(y_1))\ ~~~~~~~~~~~+cdots+min(m,siz(x)) imes min(m,siz(y_n))$
现在进行一番放缩,把每个乘法的前一项统一变成 $min(m,siz(x))$ ,这样显然只会使答案变大,所以分析出来的复杂度上界就应该是正确的。
$f(x)=min(m,siz(x)) imes left(sumlimits_{f_y=x} min(m,siz(y)) ight)$
再次放缩,把后面括号里的 $min$ 直接扔掉,得:
$f(x)=min(m,siz(x)) imes left(sumlimits_{f_y=x} siz(y) ight)\~~~~~~~~=min(m,siz(x)) imes siz(x)$
对于 $siz(x)<m$ 的点,首先考虑他的子树都是叶子的情况:
$T(x)=siz(x)^2+sum 1$
对于任意 $siz(x)<m$ 的点,递归证明,由于 “平方和小于和的平方” ,所以 $T(x)$ 与 $siz(x)^2$ 同阶;
对于 $siz(x)>m$ 的点,首先考虑它的所有子树都小于 $m$ 的情况:
$T(x)=m imes siz(x)+sum siz(j)^2$
接着放缩可得,$T(x)$ 与 $m imes siz(x)$ 同阶;
继续使用递归证明的技巧,考虑某一层出现了子树大于 $m$ 的情况:
$T(x)=m imes siz(x)+sum siz(j)^2+sum siz(j) imes m$
所以,$T(x)$ 还是与 $m imes siz(x)$ 同阶;
综上所述,这种做法的复杂度是 $n imes k$ 。
选课加强版:https://www.luogu.org/problem/U53204
1 # include <cstdio> 2 # include <iostream> 3 # include <cstring> 4 # include <vector> 5 # define R register int 6 7 using namespace std; 8 9 const int N=100005; 10 struct edge 11 { 12 int to,nex; 13 }; 14 int si,h=0,n,m; 15 edge g[N<<1]; 16 int firs[N],a[N],siz[N]; 17 bool vis[N]; 18 int dp[100000100]; 19 20 void add(int u,int v) 21 { 22 g[++h].to=v; 23 g[h].nex=firs[u]; 24 firs[u]=h; 25 } 26 27 void dfs(int x) 28 { 29 dp[x*(m+1)+1]=a[x]; 30 siz[x]=1; 31 vis[x]=true; 32 int j; 33 for (R i=firs[x];i;i=g[i].nex) 34 { 35 j=g[i].to; 36 if(vis[j]) continue; 37 dfs(j); 38 for (int k=min(siz[x]+siz[j],m);k>=1;--k) 39 for (int z=max(1,k-siz[x]);z<=min(siz[j],k-1);++z) 40 dp[x*(m+1)+k]=max(dp[x*(m+1)+k],dp[x*(m+1)+k-z]+dp[j*(m+1)+z]); 41 siz[x]+=siz[j]; 42 } 43 } 44 45 int read() 46 { 47 int x=0; 48 char c=getchar(); 49 while (!isdigit(c)) c=getchar(); 50 while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar(); 51 return x; 52 } 53 54 int main() 55 { 56 scanf("%d%d",&n,&m); m++; 57 memset(g,0,sizeof(g)); 58 for (R i=1;i<=n;i++) 59 { 60 si=read(),a[i]=read(); 61 add(i,si); 62 add(si,i); 63 } 64 dfs(0); 65 printf("%d",dp[m]); 66 return 0; 67 }
这种做法比较好写,而且还有一个优点,就是它事实上求出了每棵子树的 $dp$ 值,换句话说,它可以统计到每个连通块的答案。当然,它也有一定的局限性,那就是第二维必须和子树的大小有关,否则复杂度就不对了。下面,再来介绍另一种不要求第二维大小的做法。
第二种优化:
首先对树求出后序遍历序,设 $f[i][j]$ 表示:dfs序编号在i之前的点当前都满足依赖条件时的背包;$j$ 表示什么因题目而异;看上去有点难以理解?解释一下,“当前满足依赖条件”是指,在仅考虑前 $i$ 个点构成的森林的情况下,每个点都满足依赖关系(当前已经出现的祖先都被选了,还没出现的祖先不用考虑)。转移方程十分简单,在往森林里一个点时,如果不选它,那它的子树就都不能选,因为它的子树的dfs序是一段连续的区间,我们直接跳回到还没有考虑过这棵子树时的状态;如果选它,那就从上一个点进行转移即可。复杂度显然为 $n imes m$ 。
这种方法比上一种还好写,但是它也有一个问题,那就是只能算出以指定点为根时的答案,而不能做任意联通块。
1 # include <cstdio> 2 # include <iostream> 3 # include <cstring> 4 # include <vector> 5 # define R register int 6 7 using namespace std; 8 9 const int N=100005; 10 struct edge 11 { 12 int to,nex; 13 }; 14 int si,h=0,n,m; 15 edge g[N<<1]; 16 int firs[N],a[N],siz[N]; 17 bool vis[N]; 18 int dp[100000100]; 19 int no[N],cnt; 20 21 void add(int u,int v) 22 { 23 g[++h].to=v; 24 g[h].nex=firs[u]; 25 firs[u]=h; 26 } 27 28 int read() 29 { 30 int x=0; 31 char c=getchar(); 32 while (!isdigit(c)) c=getchar(); 33 while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar(); 34 return x; 35 } 36 37 void dfs (int x) 38 { 39 int j; 40 siz[x]=1; 41 for (R i=firs[x];i;i=g[i].nex) 42 { 43 j=g[i].to; 44 if(vis[j]) continue; 45 vis[j]=1; 46 dfs(j); 47 siz[x]+=siz[j]; 48 } 49 no[++cnt]=x; 50 } 51 52 int main() 53 { 54 scanf("%d%d",&n,&m); m++; 55 memset(g,0,sizeof(g)); 56 for (R i=1;i<=n;i++) 57 { 58 si=read(),a[i]=read(); 59 add(i,si); 60 add(si,i); 61 } 62 vis[0]=1; 63 dfs(0); 64 int x; 65 for (R i=1;i<=cnt;++i) 66 { 67 x=no[i]; 68 for (R j=1;j<=m;++j) 69 dp[i*(m+1)+j]=max(dp[(i-1)*(m+1)+j-1]+a[x],dp[(i-siz[x])*(m+1)+j]); 70 } 71 printf("%d",dp[cnt*(m+1)+m]); 72 return 0; 73 }
学习了以上知识后,我们来做一道题?
Shopping:https://www.lydsy.com/JudgeOnline/problem.php?id=4182
这好像是个权限题?那我来概述一下题意:
给定一棵 $n$ 个点的树,每个点上有一种物品 $(w,c,d)$ 表示它的价值是 $w$ ,价格是 $c$ ,有 $d$ 个。你有 $m$ 元钱,并希望它们能买到价值和最大的物品,还有一个限制是买了物品的点必须是树上的一个连通块,求最大价值。$n<=500,m<=4000,d<=100$
一个显然的思路是直接上树形背包的第一种做法,因为它事实上是在每个连通块最高的点处对这个连通块进行了处理,可以直接求出这道题的答案。
不过,别忘了第一种优化的前提,如果你以为它任何条件下都适用,那就会 $TLE$ 得很惨。在这道题中,即使是很小的子树也可以有满的 $dp$ 数组,所以复杂度就是 $O(nm^2)$
看起来第一种做法已经走进死路,让我们来考虑一下第二种做法吧。
第二种做法可以枚举根,复杂度 $n^2m$ ,感觉已经有了不少改进呢!可以发现,枚举根是一个比较愚蠢的方法,因为在做第一次的时候,就已经把所有与这个根有交的连通块都算过了,接下来只需要对每个子树再做就好了。子树大小有可能不平均?点分治!
1 # include <cstdio> 2 # include <iostream> 3 # include <cstring> 4 # include <vector> 5 # define R register int 6 7 using namespace std; 8 9 const int N=502; 10 int T,n,m,h,x,y,cnt,rt,ans,S,d; 11 int firs[N],siz[N],no[N],vis[N],w[N],c[N],maxs[N]; 12 int dp[N][4005]; 13 struct edge 14 { 15 int too,nex; 16 }g[N<<1]; 17 struct thi 18 { 19 int c,w; 20 thi (int a=0,int b=0) { c=a; w=b; } 21 }; 22 vector <thi> v[N]; 23 24 void add (int x,int y) 25 { 26 g[++h].nex=firs[x]; 27 firs[x]=h; 28 g[h].too=y; 29 } 30 31 int read() 32 { 33 int x=0; 34 char c=getchar(); 35 while (!isdigit(c)) c=getchar(); 36 while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar(); 37 return x; 38 } 39 40 void get_root (int x,int f) 41 { 42 siz[x]=1,maxs[x]=0; 43 int j; 44 for (R i=firs[x];i;i=g[i].nex) 45 { 46 j=g[i].too; 47 if(vis[j]||f==j) continue; 48 get_root(j,x); 49 siz[x]+=siz[j]; 50 maxs[x]=max(maxs[x],siz[j]); 51 } 52 maxs[x]=max(maxs[x],S-siz[x]); 53 if(maxs[x]<maxs[rt]) rt=x; 54 } 55 56 void dfs (int x,int f) 57 { 58 int j; siz[x]=1; 59 for (R i=firs[x];i;i=g[i].nex) 60 { 61 j=g[i].too; 62 if(vis[j]||j==f) continue; 63 dfs(j,x); 64 siz[x]+=siz[j]; 65 } 66 no[++cnt]=x; 67 } 68 69 void pdc (int x) 70 { 71 cnt=0; 72 dfs(x,0); 73 for (R i=0;i<=cnt;++i) 74 for (R j=0;j<=m;++j) 75 dp[i][j]=0; 76 for (R i=1;i<=cnt;++i) 77 { 78 int a=no[i],vs=v[a].size(); 79 for (R j=0;j<=m;++j) 80 dp[i][j]=max(dp[i][j],dp[ i-siz[a] ][j]); 81 for (R k=0;k<vs;++k) 82 for (R j=m;j>=v[a][k].c;--j) 83 dp[i][j]=max(dp[i][j],max(dp[i-1][ j-v[a][k].c ]+v[a][k].w,dp[i][ j-v[a][k].c ]+v[a][k].w)); 84 } 85 for (R i=1;i<=m;++i) 86 ans=max(ans,dp[cnt][i]); 87 } 88 89 void solve (int x) 90 { 91 vis[x]=1; 92 pdc(x); 93 int j; 94 for (R i=firs[x];i;i=g[i].nex) 95 { 96 j=g[i].too; 97 if(vis[j]) continue; 98 rt=0; maxs[rt]=n; S=siz[j]; 99 get_root(j,0); 100 solve(rt); 101 } 102 } 103 104 void t4182() 105 { 106 n=read(),m=read(); 107 memset(firs,0,sizeof(firs)); 108 memset(g,0,sizeof(g)); 109 memset(vis,0,sizeof(vis)); 110 h=0; 111 for (R i=1;i<=n;++i) 112 v[i].clear(); 113 for (R i=1;i<=n;++i) 114 w[i]=read(); 115 for (R i=1;i<=n;++i) 116 c[i]=read(); 117 for (R i=1;i<=n;++i) 118 { 119 d=read(); 120 int x=1; 121 while(x<=d) 122 { 123 d-=x; 124 v[i].push_back(thi(c[i]*x,w[i]*x)); 125 x<<=1; 126 } 127 if(d>0) v[i].push_back(thi(c[i]*d,w[i]*d)); 128 } 129 for (R i=1;i<n;++i) 130 { 131 x=read(),y=read(); 132 add(x,y); add(y,x); 133 } 134 rt=0; 135 S=maxs[rt]=n; 136 get_root(1,0); 137 ans=0; 138 solve(rt); 139 printf("%d ",ans); 140 } 141 142 int main() 143 { 144 scanf("%d",&T); 145 while(T--) 146 t4182(); 147 return 0; 148 }
---shzr