zoukankan      html  css  js  c++  java
  • 签到题 [换根dp]

    签到题


    color{red}{正解部分}

    这道题 每个子树看成一个子问题, 求出每个子树的答案, 然后往上合并得到总答案 .

    设当前节点有 22 个子树, 权值和节点数量 分别是 sum1,size1,sum2,size2sum_1, size_1, sum_2, size_2,子树内的答案为 ans1,ans2ans_1, ans_2

    则先往 11 儿子走对答案的贡献为:
    ans1+sum1+ans2+(size1+1)×sum2ans_1+ sum_1+ ans_2 + (size_1+1) imes sum_2
    22 儿子走对答案的贡献为:
    ans2+sum2+ans1+(size2+1)×sum1ans_2+ sum_2+ ans_1 + (size_2+1) imes sum_1,

    当 走11儿子 比 走22儿子 更优时,

    ans1+sum1+ans2+(size1+1)×sum2<ans2+sum2+ans1+(size2+1)×sum1ans_1+ sum_1+ ans_2 + (size_1+1) imes sum_2 < ans_2+ sum_2+ ans_1 + (size_2+1) imes sum_1

    化简得 size1×sum2<size2×sum1size_1 imes sum_2 < size_2 imes sum_1 .

    所以以 sizex×sumysize_x imes sum_y 从小到大排序后, 从小到大按顺序 dfsdfs 即可实现答案最优 .


    现在已经解决了当根固定时的答案, 考虑如何计算 所有节点作为根的 最优值,

    可以想到 先求出以 11 为根 的答案, 然后进行 换根,


    现在已经计算出了 ansxans_x, 且要将 根的位置xyx ightarrow y, 要求 yy 为根的答案,
    首先观察 树的信息 哪里发生了变化,

    1. yy为根 的子树 从 xx 的子树中移除掉了, sizex=sizey,sumx=sumysize_x -=size_y,sum_x-=sum_y
    2. xx为根 的子树 成为了 yy 的新子树, sizey+=sizex,sumy+=sumxsize_y += size_x, sum_y += sum_x .

    ansxans_x 的影响为 ansx=ansy+sizey×sumy+sizey×sumyans_x -= ans_y + size_{y前子树} imes sum_y + size_y imes sum_{y后子树},
    其中 ansyans_y 在往下递归的时候使用子树信息计算即可 .


    color{red}{实现部分}

    #include<bits/stdc++.h>
    #define reg register
    #define pb push_back
    typedef long long ll;
    
    int read(){
            char c;
            int s = 0, flag = 1;
            while((c=getchar()) && !isdigit(c))
                    if(c == '-'){ flag = -1, c = getchar(); break ; }
            while(isdigit(c)) s = s*10 + c-'0', c = getchar();
            return s * flag;
    }
    
    const int maxn = 200005;
    
    int N;
    int num0;
    int A[maxn];
    int size[maxn];
    int head[maxn];
    
    ll tot;
    ll Ans;
    ll sum[maxn];
    ll ans[maxn];
    
    struct Edge{ int nxt, to; } edge[maxn << 1];
    
    void Add(int from, int to){
            edge[++ num0] = (Edge){ head[from], to };
            head[from] = num0;
    }
    
    bool cmp(int a, int b){ return size[a]*sum[b] < size[b]*sum[a]; }
    
    void DFS_1(int k, int fa){
            std::vector <int> B;
            sum[k] = A[k], size[k] = 1;
            for(reg int i = head[k]; i; i = edge[i].nxt){ 
                    int to = edge[i].to; 
                    if(to == fa) continue ; B.pb(to); 
                    DFS_1(to, k);
                    sum[k] += sum[to], size[k] += size[to];
            }
            std::sort(B.begin(), B.end(), cmp); 
            ans[k] = A[k]; ll last = 1;
            for(reg int i = 0; i < B.size(); i ++){
                    int to = B[i]; 
                    ans[k] += ans[to] + last * sum[to], last += size[to];
            }
    }
    
    void DFS_2(int k, int fa){
            std::vector <int> B;
            for(reg int i = head[k]; i; i = edge[i].nxt) B.pb(edge[i].to);
            std::sort(B.begin(), B.end(), cmp);
            ans[k] = A[k]; ll last = 1;
            for(reg int i = 0; i < B.size(); i ++){
                    int to = B[i];
                    ans[k] += ans[to] + last * sum[to];
                    last += size[to];
            }
            Ans = std::min(Ans, ans[k]);
            last = 1; ll suf = tot - A[k];
            for(reg int i = 0; i < B.size(); i ++){
                    int to = B[i]; suf -= sum[to];
                    if(to != fa){
                            ll t1 = ans[k], t2 = ans[to];
                            ans[k] -= ans[to] + last*sum[to] + size[to]*suf;
                            size[k] -= size[to], sum[k] -= sum[to];
                            size[to] += size[k], sum[to] += sum[k];
                            DFS_2(to, k);
                            size[to] -= size[k], sum[to] -= sum[k];
                            size[k] += size[to], sum[k] += sum[to];
                            ans[k] = t1, ans[to] = t2;
                    }
                    last += size[to];
            }
    }
    
    int main(){
            N = read();
            for(reg int i = 1; i < N; i ++){ int u = read(), v = read(); Add(u, v), Add(v, u); }
            for(reg int i = 1; i <= N; i ++) A[i] = read(), tot += A[i];
            DFS_1(1, 1); 
            Ans = ans[1]; DFS_2(1, 1);
            printf("%lld
    ", Ans);
            return 0;
    }
    
  • 相关阅读:
    解决跨域问题 cors~ JSONP~
    session,cookie,sessionStorage,localStorage的区别~~~前端面试
    数据库索引的理解
    script的按需加载
    es6 笔记
    JS 工具函数
    JS Error
    数组方法重写:forEach, map, filter, every, some, reduce
    JS: GO 和 AO
    立即执行函数
  • 原文地址:https://www.cnblogs.com/zbr162/p/11822467.html
Copyright © 2011-2022 走看看