zoukankan      html  css  js  c++  java
  • 树形背包学习笔记

    树形背包的一般形式

    给定一棵有$n$个节点的点权树,要求你从中选出$m$个节点,使得这些选出的节点的点权和最大,一个节点能被选当且仅当其父亲节点被选中,根节点可以直接选。

    $n^3$解法

    原理

    考虑设$f[u][i]$表示在$u$的子树中选择$i$个节点(包括它本身)的最大贡献,则可列出以下转移方程。
    $$
    f[u][i]=max(f[u][j]+f[v][i-j]+d[v]) [j=1...i-1]
    $$
    其中$d[v]$表示点$v$的点权,$i-j$表示在子树$v$中选择$i-j$个节点。

    由于遍历整棵树是$Theta(n)$的,而选取$i$和$j$是$O(m2)$的,所以整个程序的复杂度就是$O(nm2)$的。

    例题

    Luogu P2014 选课

    这是一道树形背包的模板题,可以将题目转化为在$n+1$个节点中选$m+1$个节点。于是最后的答案就是$f[0][m+1]$。

    #include <cstdio>
    #include <algorithm>
    using std::max;
    
    const int N = 3e2 + 10, M = 3e2 + 10;
    int n, m, f[N][N], s[N], son[N][N];
    
    void dfs (int u) {
        for (int i = 1; i <= son[u][0]; ++i) {
            int v = son[u][i]; dfs(v);
            for (int j = m + 1; j >= 1; --j)
                for (int k = 0; k < j; ++k)
                    f[u][j] = max(f[u][j], f[u][j - k] + f[v][k]); 
        }
    }
    
    int main () {
        scanf ("%d%d", &n, &m);
        for (int i = 1, fa; i <= n; ++i) {
            scanf ("%d%d", &fa, s + i);
            f[i][1] = s[i];
            son[fa][++son[fa][0]] = i;
        }
        dfs(0);
        printf ("%d
    ", f[0][m + 1]);
        return 0;
    }
    

    $n^2$解法

    警告:此算法可能思维难度较大,而且一般联赛不会考(但不排除作为压轴题考出),视情况阅读!


    原理

    显然,$n^3$算法的时间开销是很$Big$的,比如这道题:洛谷 P4322 最佳团体

    此题在$01$分数规划后采取树形背包$check$,但是,$nm^2log$的时间复杂度是不允许,考虑优化树形背包的$check$过程

    首先,既然要优化,我们就得知道瓶颈在哪。瓶颈在于,我们是一边$dfs$一边更新的,由于要遍历子树,我们同时还要知道选择多少个节点,那么我们是否可以先跑一遍$dfs$处理出$dfs$序然后根据$dfs$序,来更新。

    设$f[i][j]$为当前$dp$到$dfs$序为$i$的点,目前已经选了$j$个节点。则有转移方程($d[i]$表示点权):

    1.选取当前节点:

    $$
    f[i+1][j+1]=f[i][j]+d[i]
    $$

    如果选了这个点,则在$dfs$序后一个节点要么是它的子节点,要么下一棵子树(则证明其没有子节点)。

    2.不选当前节点:

    $$
    f[nx[i]][j]=f[i][j]
    $$

    其中$nx[i]$表示下一棵子树,因为你没选这个点,当然不能选择其子节点。

    由于$dfs$序为$Theta(n)$的,然后枚举$j$为$O(m)$的,所以总复杂度为$O(nmlog)$。

    例题

    同样是Luogu P2014 选课

    #include <cstdio>
    #include <algorithm>
    using std::min;
    typedef long long ll;
    
    const int N = 3e2 + 10, M = 3e2 + 10, Inf = 1e9 + 7;
    int n, m, d[N], s[N], dfn[N], son[N][N], time, f[N][N], nx[N];
    inline void upt (int &a, int b) { if(a < b) a = b; }
    
    void Init_dfs(int u) {
    	dfn[u] = time++;
    	for (int i = 1; i <= son[u][0]; ++i)
    		Init_dfs(son[u][i]);
    	nx[dfn[u]] = time;
    }
    
    void Doit_dp() {
    	for (int i = 1; i <= n; ++i)
    		d[dfn[i]] = s[i];
    	for (int i = 1; i <= n + 1; ++i)
    		for (int j = 0; j <= m; ++j)
    			f[i][j] = -Inf;
    	for (int i = 0; i <= n; ++i)
    		for (int j = 0; j <= min(i, m); ++j) {
    			upt(f[i + 1][j + 1], f[i][j] + d[i]);
    			upt(f[nx[i]][j], f[i][j]);
    		}
    }
    
    int main () {
    	scanf("%d%d", &n, &m); ++m;
    	for (int i = 1, fa; i <= n; ++i) {
    		scanf("%d%d", &fa, s + i);
    		son[fa][++son[fa][0]] = i;
    	}
    	Init_dfs(0);//预处理dfs
    	Doit_dp();//动态规划
    	printf("%d
    ", f[n + 1][m]);
    	return 0;
    }
    

    之前我们提到的洛谷 P4322 最佳团体,就是用$01$分数规划&树形背包来解决的

    // luogu-judger-enable-o2
    #include <cstdio>
    #include <algorithm>
    using std::min;
    using std::max;
    
    const int N = 3e3 + 10, inf = 1e9 + 7;
    const double eps = 1e-5;
    int n, K, s[N], p[N], son[N][N], dfn[N], time, nx[N];
    int from[N], to[N], nxt[N], cnt;//Edges
    double f[N][N], d[N];
    
    inline void addEdge (int u, int v) {
    	to[++cnt] = v, nxt[cnt] = from[u], from[u] = cnt;
    }
    
    inline void upt(double &a, double b) {
    	if (a < b) a = b;
    }
    
    void dfs (int u) {
    	dfn[u] = time++;
    	for (int i = from[u]; i; i = nxt[i]) dfs(to[i]);
    	nx[dfn[u]] = time;
    }
    
    inline bool check (double k) {
    	for (int i = 1; i <= n; ++i) 
    		d[dfn[i]] = p[i] - k * s[i];
    	for (int i = 1; i <= n + 1; ++i)
    		for (int j = 0; j <= K; ++j)
    			f[i][j] = -inf;
    	for (int i = 0; i <= n; ++i)
    		for (int j = 0; j <= min(i, K); ++j) {
    			upt(f[i + 1][j + 1], f[i][j] + d[i]);
    			upt(f[nx[i]][j], f[i][j]);
    		}
    	return f[n + 1][K] >= eps;
    }
    
    int main () {
    	scanf("%d%d", &K, &n); ++K;
    	for (int i = 1, fa; i <= n; ++i)  {
    		scanf("%d%d%d", s + i, p + i, &fa);
    		addEdge(fa, i);
    	}
    	dfs(0);
    	double l = 0, r = 10000, ans;
    	while (r - l >= eps) {
    		double mid = (l + r) * 0.5;
    		if (check(mid)) ans = mid, l = mid + eps;
    		else r = mid - eps;
    	}
    	printf ("%.3lf
    ", ans);
    	return 0;
    }
    
  • 相关阅读:
    二分法查找
    重构方法之一:提炼方法(Extract Method)读书笔记
    使用SQL命令手动写入Discuz帖子内容
    调整linux系统时间和时区
    怎样给访问量过大的mysql数据库减压
    MySQL提示“too many connections”的解决办法
    CentOS 6安装php加速软件Zend Guard
    CmsTop 大众版运行环境搭建 (CentOS+Nginx+PHP FastCGI)
    LEMP构建高性能WEB服务器(CentOS+Nginx+PHP+MySQL)
    CentOS-6.3安装配置Nginx
  • 原文地址:https://www.cnblogs.com/water-mi/p/9818622.html
Copyright © 2011-2022 走看看