zoukankan      html  css  js  c++  java
  • [CEOI2004]锯木厂选址 斜率优化DP

    斜率优化DP

    先考虑朴素DP方程,

    f[i][k]代表第k个厂建在i棵树那里的最小代价,最后答案为f[n+1][3];

    f[i][k]=min(f[j][k-1] + 把j+1~i的树都运到i的代价)

    首先注意到“把j+1~i的树都运到i的代价”不太方便表达,每次都暴力计算显然是无法承受的,

    于是考虑前缀和优化,观察到先运到下一棵树那里,等一会再运下去,和直接运下去是等效的。

    设sum[i]代表1 ~ i的树都运到i的代价,

    于是根据前缀和思想,猜想我们可以用1 ~ r 的代价与 1 ~ l-1的代价获取l ~ r的代价,

    所以要做的就是吧1 ~ l-1 对 1 ~ r产生的贡献给算出来,然后减掉,

    考虑先把1 ~ l-1的树都运到l-1,所以这部分的代价是sum[l-1],

    然后再把树一次性运到r,那么代价是sum_weight[l-1] * (sum_len[r] - sum_len[l-1]);

                                        总的重量 * 现在要再次运的路程

    这里为了表示方便,用$sw$代表sum_weight,用$sl$代表sum_len;


    于是用$sum[r]$ 减去这两部分代价就可以得到$l ~ r$ 的代价(把$l ~ r$的树都运到$r$)

    代价(l ~ r)$ =  sum[r] - sum[l-1] - sw[l-1] * (sl[r] - sl[l-1]);$

    那么如何计算sum ?

    也是一样的思想,用前面的推后面的,先得到前面的代价,再加上新增的代价即可

    $sum[i]=sum[i-1] + swt[i-1] * len[i-1];$//len[i-1]代表i-1到i的距离

    于是我们就得到了DP方程:

    当$k==1$时,$f[i][k]=sum[i]$;

    else

            $f[i][k]=min(f[j][k-1] + sum[i] - sum[j] - sw[j] * (sl[i] - sl[j]));$

    但是可以发现,由于k最大就是3,而且3必须是n+1才可以取,

    而且当$k==1$时,$f[i][k]$就等于$sum[i]$,

    所以考虑优化维数:

    当$k==1$时,不用求,因为有$sum$了

    当$k==2$时,调用的$f[j][k-1]$替换为$sum[j]$,并且还可以发现由于后面有一个$-sum[j]$,所以可以直接消掉

    当$k==3$时,由于只有$n+1$可以取,所以直接在外面多写一个循环,相当于最后统计答案即可

    转移方式同朴素方程

    但是这样是$n^2$的DP,而$n$有20000,那怎么办呢?

    考虑斜率优化。

    首先我们用暴力打表可以发现,决策是单调的,

    打表代码(朴素DP):

     1 #include<bits/stdc++.h>
     2 using namespace std;
     3 #define R register int
     4 #define AC 20100
     5 int n, ans;
     6 int sum[AC], sum_weight[AC], sum_len[AC], f[AC];
     7 int weight[AC], len[AC];
     8 inline int read()
     9 {
    10     int x = 0; char c = getchar();
    11     while(c < '0' || c > '9') c = getchar();
    12     while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    13     return x; 
    14 }
    15 
    16 void pre()
    17 {
    18     n = read();
    19     for(R i = 1; i <= n; i ++)
    20         weight[i] = read(), len[i] = read();
    21 }
    22 
    23 void getsum()
    24 {
    25     for(R i = 1; i <= n + 1; i ++)//山脚的也要求
    26     {
    27         sum_len[i] = sum_len[i - 1] + len[i - 1];
    28         sum_weight[i] = sum_weight[i - 1] + weight[i];
    29         sum[i] = sum[i - 1] + sum_weight[i - 1] * len[i - 1];
    30     //    printf("%d : %d
    ",i,sum[i]);
    31     }
    32 }
    33 
    34 void work()
    35 {
    36     for(R i = 1; i <= n; i ++)
    37     {
    38         int tmp = 0;
    39         f[i] = INT_MAX;
    40         for(R j = 1;j < i;j ++)
    41         {
    42             if(sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]) < f[i])
    43             {
    44                 f[i] = sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]);
    45                 tmp = j;
    46             } 
    47         }
    48         printf("%d --- > %d
    ", tmp, i);//打表验证决策单调性
    49     }
    50     ans = INT_MAX;
    51     for(R i = 2; i <= n; i ++)//注意应该是n+1,因为山脚是在下面
    52         ans = min(ans, f[i] + sum[n + 1] - sum[i] - sum_weight[i] * (sum_len[n + 1] - sum_len[i]));
    53     for(R i = 2; i <= n; i ++) printf("%d : %d
    ", i, f[i]);
    54     printf("%d
    ", ans);
    55 }
    56 
    57 int main()
    58 {
    59     freopen("in.in", "r", stdin);
    60     freopen("out.out", "w", stdout);
    61     pre();
    62     getsum();
    63     work();
    64     fclose(stdin);
    65     fclose(stdout);
    66     return 0;
    67 }
    View Code

    于是我们推斜率优化方程:

    设有 $k < j < i$,且$j$优于$k$(相当于$j$是后面来的),则有:

    $sum[i] - sw[j] * (sl[i] - sl[j]) < sum[i] - sw[k] * (sl[i] - sl[k])$

    $sw[j] * (sl[i] - sl[j]) > sw[k] * (sl[i] - sl[k])$

    $sw[j] * sl[i] - sw[j] * sl[j] >  sw[k] * sl[i] - sw[k] * sl[k]$

    $sw[k] * sl[k] - sw[j] * sl[j] > sw[k] * sl[i] - sw[j] * sl[i]$

    $sw[k] * sl[k] - sw[j] * sl[j] > sl[i] * (sw[k] - sw[j])$

    $frac{(sw[k] * sl[k] - sw[j] * sl[j])} {(sw[k] - sw[j])} < sl[i]$ //注意sw[k] - sw[j]小于0,要变号

    所以令$K = frac{(sw[k] * sl[k] - sw[j] * sl[j])}{(sw[k] - sw[j])}$;

    则    while(head < tail && k(q[head],q[head+1]) < sum_len[i])  ++head;

    while(head < tail && k(q[tail-1],q[tail]) > k(q[tail],i)) --tail;

    最后上代码:

     1 #include<bits/stdc++.h>
     2 using namespace std;
     3 #define R register int
     4 #define AC 20100
     5 int n, ans;
     6 int sum[AC], sum_weight[AC], sum_len[AC], f[AC];
     7 int weight[AC], len[AC];
     8 int q[AC], head, tail;
     9 inline int read()
    10 {
    11     int x = 0; char c = getchar();
    12     while(c < '0' || c > '9') c = getchar();
    13     while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    14     return x; 
    15 }
    16 
    17 inline double k(int x, int y)
    18 {
    19     double a = sum_weight[x] * sum_len[x] - sum_weight[y] * sum_len[y];
    20     double b = sum_weight[x] - sum_weight[y];
    21     return a / b;
    22 }
    23 
    24 void pre()
    25 {
    26     n = read();
    27     for(R i = 1; i <= n; i ++) weight[i] = read(), len[i] = read();
    28 }
    29 
    30 void getsum()
    31 {
    32     for(R i = 1; i <= n + 1; i ++)//山脚的也要求
    33     {
    34         sum_len[i] = sum_len[i - 1] + len[i - 1];
    35         sum_weight[i] = sum_weight[i - 1] + weight[i];
    36         sum[i] = sum[i - 1] + sum_weight[i - 1] * len[i - 1];
    37     //    printf("%d : %d
    ",i,sum[i]);
    38     }
    39 }
    40 
    41 void work()
    42 {
    43     head=1;
    44     for(R i = 1; i <= n; i ++)
    45     {
    46         f[i] = INT_MAX;
    47         
    48         /*int tmp = 0;
    49         f[i] = INT_MAX;
    50         for(R j = 1; j < i; j ++)
    51         {
    52             if(sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]) < f[i])
    53             {
    54                 f[i] = sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]);
    55                 tmp = j;
    56             } 
    57         }
    58         printf("%d --- > %d
    ", tmp, i);//打表验证决策单调性*/
    59         
    60         while(head < tail && k(q[head], q[head + 1]) < sum_len[i]) ++ head;
    61         int now = q[head];
    62     //    printf("%d --- > %d
    ",now,i);
    63         f[i] = sum[i] - sum_weight[now] * (sum_len[i] - sum_len[now]);
    64         while(head < tail && k(q[tail - 1], q[tail]) > k(q[tail], i)) -- tail;
    65         q[++tail] = i;
    66     }
    67     ans = INT_MAX;
    68     for(R i = 2; i <= n; i ++)//注意应该是n+1,因为山脚是在下面,注意要从2开始,因为这是在枚举第2个厂在哪
    69         ans = min(ans, f[i] + sum[n + 1] - sum[i] - sum_weight[i] * (sum_len[n + 1] - sum_len[i]));
    70     printf("%d
    ", ans);
    71 }
    72 
    73 int main()
    74 {
    75 //    freopen("in.in", "r", stdin);
    76     pre();
    77     getsum();
    78     work();
    79 //    fclose(stdin);
    80     return 0;
    81 }
    View Code

     ---------------2018.10.12--------------优化了代码格式

  • 相关阅读:
    洛谷P2762 太空飞行计划问题
    网络流24题 gay题报告
    洛谷P1712 区间
    洛谷P2480 古代猪文
    10.9zuoye
    面向对象类编程,计算分数
    请输入验证码优化版
    面向对象式开发程序
    直接选择排序与反转排序
    随机数产生原理
  • 原文地址:https://www.cnblogs.com/ww3113306/p/8906890.html
Copyright © 2011-2022 走看看