zoukankan      html  css  js  c++  java
  • CCPC-Wannafly Winter Camp Day1 流流流动 (树形dp)

    题目描述

     

    喜欢数学的wlswls最近被萎住了。

    现在他一共有1...n1...n这么多数字,取数字ii会得到f[i]f[i]的收益。数字之间有些边,对于所有的i(i != 1)i(i!=1),若ii为奇数,则ii与3i+13i+1之间有边,否则ii与i/2i/2之间有边。如果一条边的两个顶点xyxy都被取了,那么会失去d[min(x, y)]d[min(x,y)]的价值。请问wlswls怎么取,才能使得收益最大?

     
     

    输入描述

     

    第一行一个整数nn。

    接下来一行nn个整数表示ff。

    接下来一行nn个整数表示dd。

    1 leq n leq 1001n100

    1 leq f[i], d[i] leq 10001f[i],d[i]1000

    输出描述

     

    输出一个整数表示答案。

    样例输入 1 

    2
    10 10 
    1 2

    样例输出 1

    19


    思路:
    根据题目给的建边条件,建边后会形成一个森林,然后把森林转化为一个0为根节点的树,随后进行树形dp。
    定义状态:
    dp[u][0/1] 0为 第u个节点的子树中不取第u个节点的最多利益值,
           1为第u个节点的子树中取第u个节点的最多利益值,
    常规的树形dp套路,
    dp[u][0]+=max(dp[v][0],dp[v][1]);
    dp[u][1]+=max(dp[v][0],dp[v][1]-d[min(u,v)]);// 一个边上两个节点都取的话,要减去对应的值。
    最后max(dp[0][0],dp[0][1])就是我们的答案值。
    细节见代码:
    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <cmath>
    #include <queue>
    #include <stack>
    #include <map>
    #include <set>
    #include <vector>
    #include <iomanip>
    #define ALL(x) (x).begin(), (x).end()
    #define rt return
    #define dll(x) scanf("%I64d",&x)
    #define xll(x) printf("%I64d
    ",x)
    #define sz(a) int(a.size())
    #define all(a) a.begin(), a.end()
    #define rep(i,x,n) for(int i=x;i<n;i++)
    #define repd(i,x,n) for(int i=x;i<=n;i++)
    #define pii pair<int,int>
    #define pll pair<long long ,long long>
    #define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
    #define MS0(X) memset((X), 0, sizeof((X)))
    #define MSC0(X) memset((X), '', sizeof((X)))
    #define pb push_back
    #define mp make_pair
    #define fi first
    #define se second
    #define eps 1e-6
    #define gg(x) getInt(&x)
    #define db(x) cout<<"== [ "<<x<<" ] =="<<endl;
    using namespace std;
    typedef long long ll;
    ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
    ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
    ll powmod(ll a, ll b, ll MOD) {ll ans = 1; while (b) {if (b % 2)ans = ans * a % MOD; a = a * a % MOD; b /= 2;} return ans;}
    inline void getInt(int* p);
    const int maxn = 1010;
    const int inf = 0x3f3f3f3f;
    /*** TEMPLATE CODE * * STARTS HERE ***/
    
    int pre[maxn];
    int f[maxn];
    int d[maxn];
    int dp[maxn][2];
    int n;
    int findpar(int x)
    {
        return pre[x] == 0 ? x : pre[x] = findpar(pre[x]);
    }
    void mer(int x, int y)
    {
        x = findpar(x);
        y = findpar(y);
        if (x != y)
        {
            pre[x] = y;
        }
    }
    std::vector<int> v[maxn];
    int w;
    void dfs(int u, int pre)
    {
        // cout<<u<<" "<<pre<<endl;
        dp[u][0] = 0;
        dp[u][1] = f[u];
        for (auto x : v[u])
        {
            if (x != pre)
            {
                dfs(x, u);
                dp[u][0] += max(dp[x][0], dp[x][1]);
                dp[u][1] += max(dp[x][0], dp[x][1] - d[min(u, x)]);
            }
        }
    }
    int main()
    {
        //freopen("D:\common_text\code_stream\in.txt","r",stdin);
        //freopen("D:\common_text\code_stream\out.txt","w",stdout);
        gbtb;
        cin >> n;
        repd(i, 1, n)
        {
            cin >> f[i];
        }
        repd(i, 1, n)
        {
            cin >> d[i];
        }
        repd(i, 2, n)
        {
            if (i & 1)
            {
                if (3 * i + 1 <= n)
                {
                    v[i].push_back(3 * i + 1);
                    v[3 * i + 1].push_back(i);
                    mer(i, 3 * i + 1);
                }
            } else
            {
                v[i].push_back(i / 2);
                v[i / 2].push_back(i);
                mer(i, i / 2);
            }
        }
        repd(i, 1, n)
        {
            if (pre[i] == 0)
            {
                v[0].push_back(i);
                v[i].push_back(0);
            }
        }
        dfs(0, 0);
        cout << max(dp[0][0], dp[0][1]) << endl;
    
        return 0;
    }
    
    inline void getInt(int* p) {
        char ch;
        do {
            ch = getchar();
        } while (ch == ' ' || ch == '
    ');
        if (ch == '-') {
            *p = -(getchar() - '0');
            while ((ch = getchar()) >= '0' && ch <= '9') {
                *p = *p * 10 - ch + '0';
            }
        }
        else {
            *p = ch - '0';
            while ((ch = getchar()) >= '0' && ch <= '9') {
                *p = *p * 10 + ch - '0';
            }
        }
    }




    本博客为本人原创,如需转载,请必须声明博客的源地址。 本人博客地址为:www.cnblogs.com/qieqiemin/ 希望所写的文章对您有帮助。
  • 相关阅读:
    mitm iptables ssltrip set ferret hamster
    SQL注入的常用函数和语句
    SQL注入的字符串连接函数
    SQL注入的分类
    DNS配置详解
    Linux的任务计划--cron入门
    Linux文件系统层次结构标准
    Linux的awk命令
    Linux的sed命令
    Linux的find命令
  • 原文地址:https://www.cnblogs.com/qieqiemin/p/10907090.html
Copyright © 2011-2022 走看看