zoukankan      html  css  js  c++  java
  • Shopping(树形背包+点分治)

    Solution

    我们发现要解决一个树上的连通块问题,解决这种问题的时候我们不妨先随便选一个根,如果要选某两个点则他们到n的路径上的点都会被选就变成了一个树形背包问题。

    注意这里是多重背包,所以我们可以用单调队列优化,时间复杂度$O(N^2M)$。

    考虑暴力选根的时候会把很多重复的情况算进去,所以我们可以用点分治,只计算根的孩子之间的贡献,递归子树时其余兄弟节点就不用管了。

    因为每次选的是重心,所以子树大小必然减一半,时间复杂度$O(NMlog{N})$。

    点分治+树形背包,这是一种常见的处理树上连通块的方法。

    Code

    因为加了单调队列优化,所以要注意树形背包时倒着做(即从叶节点开始)。

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    const int N = 5010, M = 40010, inf = 0x3f3f3f3f;
    struct node{
        int pre, to;
    }edge[N << 1];
    int head[N], tot;
    int T;
    int n, m;
    int sz[N], mx[N], rt;
    int w[N], v[N], d[N];
    int dfn[N], bl[N], dep;
    int dp[N][M];
    int ans;
    bool vis[N];
    void get_root(int x, int tot_size, int fa) {
        sz[x] = 1;
        mx[x] = 0;
        for (int i = head[x]; i; i = edge[i].pre) {
            int y = edge[i].to;
            if (y == fa || vis[y]) continue;
            get_root(y, tot_size, x);
            sz[x] += sz[y];
            if (mx[x] < sz[y]) {
                mx[x] = sz[y];
            }
        }
        mx[x] = max(mx[x], tot_size - sz[x]);
        if (mx[x] < mx[rt]) {
            rt = x;
        }
    }
    void dfs(int x, int fa) {
        sz[x] = 1;
        dfn[++dep] = x;
        bl[x] = dep;
        for (int i = head[x]; i; i = edge[i].pre) {
            int y = edge[i].to;
            if (y == fa || vis[y]) continue;
            dfs(y, x);
            sz[x] += sz[y];
        }
    }
    void cmax(int &x, int y) {
        x = max(x, y);
    }
    int q[N];
    void solve(int x) {
        vis[x] = 1;
        dep = 0;
        dfs(x, 0);
        for (int i = 0; i <= dep + 1; i++) {
            for (int j = 0; j <= m; j++) {
                dp[i][j] = 0;
            }
        }
        for (int i = dep; i >= 1; i--) {
            for (int j = 0; j <= m; j++) {
                cmax(dp[i][j], dp[i + sz[dfn[i]]][j]);
            }
            int a = v[dfn[i]];
            int b = w[dfn[i]];
            int c = d[dfn[i]];
            for (int j = 0; j < a; j++) {
                int heead = 1, tail = 0;
                for (int k = 0; j + k * a <= m; k++) {
                    while (heead <= tail && q[heead] < k - c) heead++;
                    if (heead <= tail) cmax(dp[i][j + k * a], dp[i + 1][j + q[heead] * a] - q[heead] * b + k * b);
                    while (heead <= tail && dp[i + 1][j + a * q[tail]] - q[tail] * b <= dp[i + 1][j + a * k] - k * b) tail--;
                    q[++tail] = k;
                }
            }
        }
        for (int i = 1; i <= m; i++) {
            ans = max(ans, dp[1][i]);
        }
        for (int i = head[x]; i; i = edge[i].pre) {
            int y = edge[i].to;
            if (vis[y]) continue;
            rt = 0;
            get_root(y, sz[y], x);
            solve(rt);
        }
    }
    int read() {
        int ret = 0, f = 1;
        char ch = getchar();
        while (!isdigit(ch)) {
            if (ch == '-') f = -1;
            ch = getchar();
        }
        while (isdigit(ch)) {
            ret = (ret << 1) + (ret << 3) + ch - '0';
            ch = getchar();
        }
        return ret * f;
    }
    void write(int x) {
        if (x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
    void print(int x) {
        if (x < 0) {
            x = -x;
            putchar('-');
        }
        write(x);
        putchar('
    ');
    }
    void init() {
        ans = 0;
        memset(vis, 0, sizeof(vis));
        memset(head, 0, sizeof(head));
        tot = 0;
    }
    void add(int u, int vv) {
        edge[++tot] = node{head[u], vv};
        head[u] = tot;
    } 
    int main() {
        mx[0] = inf;
        T = read();
        while (T--) {
            init();
            n = read(); m = read();
            for (int i = 1; i <= n; i++) {
                w[i] = read();
            }
            for (int i = 1; i <= n; i++) {
                v[i] = read();
            }
            for (int i = 1; i <= n; i++) {
                d[i] = read();
            }
            for (int i = 1, u, vv; i < n; i++) {
                u = read();
                vv = read();
                add(u, vv);
                add(vv, u);
            }
            rt = 0;
            get_root(1, n, 0);
            solve(rt);
            print(ans);
        }
        return 0;
    }
  • 相关阅读:
    ACE 的一些词汇
    odbc连接不上,初步猜想是myodbc安装有问题
    1分钟 当数据库管理员
    硬件申请
    编译删除
    ASP.NET之数据绑定
    发布、订阅、复制、同步SQL Server 2000 数据库
    SQL——添加约束的语句
    SQL——规则
    十大著名黑客—— 凯文米特尼克
  • 原文地址:https://www.cnblogs.com/zcr-blog/p/12989029.html
Copyright © 2011-2022 走看看