zoukankan      html  css  js  c++  java
  • 树形dp

    树形dp,意思就是在树上的dp, 看了看紫书,讲了三个大点把,一个是树的最大独立集,另外一个是树的重心,最后一个是树的最长路径。给的三个例题,下面就从例题说起

    第一个:工人的请愿书 uva 12186

    这个题目给定一个公司的树状结构,每个员工都有唯一的一个直属上司,老板编号为0,员工1-n,只有下一级的工人请愿书不小于T%时,这个中级员工,才会签字传递给它的直属上司,问老板收到请愿书至少需要多少各个工人签字

    用dp(u)表示u给上级发信至少需要多少工人,那么可以假设u有k个节点,所以需要c = (T*k - 1) / 100 + 1个直接下属,把所有的子节点的dp值从小到大排序,前c个加起来就是答案。

    树形dp就是从根节点一层一层往下找,找到最优的子结构,树形的最优子结构就是子树,我是用的记忆化搜索。代码如下:

    #include <cstdio>
    #include <vector>
    #include <algorithm>
    using namespace std;
    const int maxn = 1e5 + 10;
    vector<int> sons[maxn];
    int n, t;
    int dp(int u)
    {
        if (sons[u].empty())
            return 1;
        int k = sons[u].size();
        vector<int> d;
        for (int i = 0; i < k; i++)
            d.push_back(dp(sons[u][i]));
        sort(d.begin(), d.end());
        int c = (k * t - 1) / 100 + 1;
        int ans = 0;
        for (int i = 0; i < c; i++)
            ans += d[i];
        return ans;
    }
    void init()
    {
        for (int i = 0; i <= n; i++)
            sons[i].clear();
    }
    int main()
    {
        while (~scanf("%d%d", &n, &t) && n + t)
        {
            init();
            int t;
            for (int i = 1; i <= n; i++)
            {
                scanf("%d", &t);    
                sons[t].push_back(i);
            }
            printf("%d
    ", dp(0));
        }
        return 0;
    }
    View Code

    第二个:Hali-Bula的晚会,poj3342, uva1220

    这个几乎是求树的最大独立集的模板题,就是多了一个要求,判断是否唯一

    考虑一个点,只有两种情况,选它,不选它

    所以用 d[u][0]表示以u为根的子树中,不选u点能得到的最大人数 f[u][0]表示方案唯一性,如果f[u][0] = 1说明唯一,0说明不唯一

    d[u][1]表示以u为根,选u点能得到的最大值。

    状态转移方程就是d[u][1]  = sum{d[v][0]}| v是u的子节点

    d[u][0] = sum{max(d[v][0], d[v][1])},所以代码如下:

    #include <cstdio>
    #include <iostream>
    #include <vector>
    #include <algorithm>
    #include <string>
    #include <cstring>
    #include <map>
    using namespace std;
    const int maxn = 240;
    vector<int> sons[maxn];
    map<string, int> mp;
    int f[maxn][2], d[maxn][2];
    int n, cnt;
    void init()
    {
        memset(d, -1, sizeof(d));
        for (int i = 0; i <= n; i++)
            sons[i].clear();
        mp.clear();
    }
    int dp(int u, int flag)
    {
        f[u][flag] = 1;
        if (sons[u].empty() && flag)
            return 1;
        if (sons[u].empty() && !flag)
            return 0;
        int k = sons[u].size();
        int sum = 0;
        if (flag == 1)
        {
            sum++;
            for (int i = 0; i < k; i++)
            {
                int v = sons[u][i];
                if (d[v][0] == -1)
                    d[v][0] = dp(v, 0);
                sum += d[v][0];
                f[u][1] &= f[v][0];
            }
        }
        else
        {
            for (int i = 0; i < k; i++)
            {
                int v = sons[u][i];
                if (d[v][1] == -1)
                    d[v][1] = dp(v, 1);
                if (d[v][0] == -1)
                    d[v][0] = dp(v, 0);
                if (d[v][0] == d[v][1])
                    f[u][flag] = 0;
                if (d[v][0] > d[v][1])
                    f[u][0] &= f[v][0];
                else
                    f[u][0] &= f[v][1];
                sum += max(d[v][1], d[v][0]);
            }
        }
        return sum;
    }
    
    int main()
    {
        while (~scanf("%d", &n) && n)
        {
            init();
            string s, tmp;
            cin >> s;
            mp[s] = 0;
            cnt = 1;
            for (int i = 1; i < n; i++)
            {
                cin >> s >> tmp;
                if (mp.count(s) == 0)
                    mp[s] = cnt++;
                if (mp.count(tmp) == 0)
                    mp[tmp] = cnt++;
                sons[mp[tmp]].push_back(mp[s]);
            }
            int t1 = dp(0, 0); int t2 = dp(0, 1);
            bool flag = true;
            //cout << f[0][0] << endl;
            //cout << f[0][1] << endl;
            //cout << t2 << endl;
            if (t1 == t2)
                flag = false;
            if (t1 < t2 && f[0][1] == 0)
                flag = false;
            if (t1 > t2 && f[0][0] == 0)
                flag = false;
            printf("%d %s
    ", max(t1, t2), flag ? "Yes" : "No");
        }
        return 0;
    }
    View Code

    第三个:完美的服务 poj3398, uva 1218

    这个和第二题差不多,就是状态多了一个,因为有个条件是每台计算机连接的服务器恰好是一个,所以多了一个状态,我是用记忆化搜索来写的,写完之后一直wrong answer,后来对比了网上的代码发现把inf设的太大了,估计是后来加着加着越界了,所以wrong了,后来改了就好了,不过记忆化写起来比递推代码长多了。。。

    记忆化代码:

    #include <cstdio>
    #include <iostream>
    #include <vector>
    using namespace std;
    const int maxn = 10010;
    const int inf = (1e6);
    vector<int> sons[maxn];
    int n;
    int d[maxn][3];
    void init()
    {
        for (int i = 0; i <= n; i++)
            d[i][0] = d[i][1] = d[i][2] = inf;
        for (int i = 0; i <= n; i++)
            sons[i].clear();
    }
    int dp(int u, int pre, int f)
    {
        if (sons[u].size() == 1 && sons[u][0] == pre)
            if (f == 0)
                return 1;
            else if (f == 1)
                return 0;
            else
                return inf;
        int k = sons[u].size();
        int sum = 0;
        if (f == 0)//u is server
        {
            sum++;
            for (int i = 0; i < k; i++)
            {
                int v = sons[u][i];
                if (pre == v)
                    continue;
                if (d[v][0] == inf)
                    d[v][0] = dp(v, u, 0);
                if (d[v][1] == inf)
                    d[v][1] = dp(v, u, 1);
                sum += min(d[v][0], d[v][1]);
            }
        }
        else if (f == 1)//u is not server, his father is server
        {
            for (int i = 0; i < k; i++)
            {
                int v = sons[u][i];
                if (v == pre)
                    continue;
                if (d[v][2] == inf)
                    d[v][2] = dp(v, u, 2);
                sum += d[v][2];
            }
        }
        else//u is not server and his father is not server;
        {
            int ans = inf;
            for (int i = 0; i < k; i++)
            {
                int v = sons[u][i];
                if (pre == v)
                    continue;
                if (d[u][1] == inf)
                    d[u][1] = dp(u, pre, 1);
                if (d[v][2] == inf)
                    d[v][2] = dp(v, u, 2);
                if (d[v][0] == inf)
                    d[v][0] = dp(v, u, 0);
                ans = min(ans, d[u][1] - d[v][2] + d[v][0]);
            }
            sum += ans;
        }
        return sum;
    }
    int main()
    {
        int a, b;
        while (~scanf("%d", &n) && n != -1)
        {
            if (n == 0)
                continue;
            init();
            for (int i = 1; i < n; i++)
            {
                scanf("%d %d", &a, &b);
                sons[a].push_back(b);
                sons[b].push_back(a);
            }
            int t1 = dp(1, 0, 0); int t2 = dp(1, 0, 2);
            //cout << t1 << endl;
            //cout << t2 << endl;
            printf("%d
    ", min(t1, t2));
        }
        return 0;
    }
    View Code

    递推代码:

    #include <cstdio>
    #include <cstring>
    #include <iostream>
    #include <vector>
    using namespace std;
    const int maxn = 11000;
    const int inf = 1e6;
    vector<int> sons[maxn];
    int n;
    int d[maxn][3];
    void init()
    {
        for (int i = 0; i <= n + 10; i++)
        {
            d[i][0] = 1; d[i][1] = 0; d[i][2] = inf;
        }
        for (int i = 0; i <= n; i++)
            sons[i].clear();
    }
    void dp(int u, int pre)
    {
        int k = sons[u].size();
        for (int i = 0; i < k; i++)
        {
            int v = sons[u][i];
            if (pre == v)
                continue;
            dp(v, u);
            d[u][0] += min(d[v][0], d[v][1]);
            d[u][1] += d[v][2];
            d[u][2] = min(d[u][2], d[v][0] - d[v][2]);
        }
        d[u][2] += d[u][1];
    }
    int main()
    {
        int a, b;
        while (~scanf("%d", &n) && n != -1)
        {
            if (n == 0)
                continue;
            init();
            for (int i = 1; i < n; i++)
            {
                scanf("%d %d", &a, &b);
                sons[a].push_back(b);
                sons[b].push_back(a);
            }
            dp(1, 0);
            printf("%d
    ", min(d[1][0], d[1][2]));
        }
        return 0;
    }
    View Code
  • 相关阅读:
    deepin系统换软件下载源&商店卡死刷新空白问题解决
    php数组和json数组之间的互相转化
    php 获取数组个数的方法
    php 三种文件下载的实现
    win10激活
    deepin/linux安装exe
    deepin连接windows
    deepin升级微信
    deepin安装.net core
    在Deepin 15.9下安装Wine 4.0
  • 原文地址:https://www.cnblogs.com/Howe-Young/p/4731070.html
Copyright © 2011-2022 走看看