zoukankan      html  css  js  c++  java
  • HDU3045 Picnic Cows (斜率DP优化)(数形结合)

    转自PomeCat

    “DP的斜率优化——对不必要的状态量进行抛弃,对不优的状态量进行搁置,使得在常数时间内找到最优解成为可能。斜率优化依靠的是数形结合的思想,通过将每个阶段和状态的答案反映在坐标系上寻找解答的单调性,来在一个单调的答案(下标)队列中O(1)得到最优解。”

    https://wenku.baidu.com/view/b97cd22d0066f5335a8121a3.html

    “一些试题中繁杂的代数关系身后往往隐藏着丰富的几何背景,而借助背景图形的性质,可以使那些原本复杂的数量关系和抽象的概念,显得直观,从而找到设计算法的捷径。”—— 周源《浅谈数形结合思想在信息学竞赛中的应用》

    斜率优化的核心即为数形结合,具体来说,就是以DP方程为基础,通过变形来使得原方程形如一个函数解析式,再通过建立坐标系的方式,将每一个DP方程代表的状态表示在坐标系中,在确定“斜率”单调性的前提下,进行形如单调队列操作的舍解和维护操作。

    一个算法总是用于解决实际问题的,所以结合例题来说是最好的:

    Picnic Cows(HDU3045)

    题目大意: 
    给出一个有N (1<= N <=400000)个正数的序列,要求把序列分成若干组(可以打乱顺序),每组的元素个数不能小于T (1 < T <= N)。每一组的代价是每个元素与最小元素的差之和,总代价是每个组的代价之和,求总代价的最小值。

    样例输入包含: 
    第一行 N 
    第二行 N个数,如题意

    样例输出包含: 
    第一行 最小的总代价

    分析: 
    首先,审题。可以打乱序列顺序,又知道代价为组内每个元素与最小值差之和,故想到贪心,先将序列排序(用STL sort)。 
    先从最简单的DP方程想起: 
    容易想到:

    f[i] = min( f[j] + (a[j + 1 -> i] - Min k) ) (0 <= j < i)

    – –> f[i] = min( f[j] + sum[i] - sum[j] - a[j + 1] * ( i - j ) )

    Min k 代表序列 j + 1 -> i 内的最小值,排序后可以简化为a[j + 1]。提取相似项合并成前缀和sum。这个方程的思路就是枚举 j 不断地计算状态值更新答案。但是数据规模达到了 40000 ,这种以O(n ^ 2)为绝对上界方法明显行不通。所以接下来我们要引入“斜率”来优化。

    首先要对方程进行变形: 
    f[i] = f[j] + sum[i] - sum[j] - a[j + 1] * ( i - j ) 
    – –> f[i] = (f[j] - sum[j] + a[j + 1] * j) - i * a[j + 1] + sum[i] 
    (此步将只由i决定的量与只由j决定的量分开) 
    由于 sum[i] 在当前枚举到 i 的状态下是一个不变量,所以在分析时可以忽略(因为对决策优不优没有影响)(当然写的时候肯定不能忽略)

    令 i = k 
    a[j + 1] = x 
    f[j] - sum[j] + a[j + 1] * j = y 
    f[i] = b 
    原方程变为 
    – –> b = y - k * x 
    移项 
    – –> y = k * x + b

    是不是很眼熟? 没错,这就是直线的解析式。观察这个式子,我们可以发现,当我们吧许许多多的答案放在坐标系上构成点集,且枚举到 i 时,过每一个点的斜率是一样的!! 很抽象? 看图

    图1

    可以看出,我们要求的f[i]就是截距,明显,延长最右边的直线交于坐标轴可得最小值。难道只要每次提取最靠近 i 的状态就行了嘛?现实没有那么美好。

    图2

    像这样的情况,过2直线的截距明显比过3直线的截距要小, 意味着更优(在找求解最小值问题时),这种情况下我们之前的猜想便行不通。

    那怎么办呢?这时就要用到斜率优化的核心思想——维护凸包。 
    何为凸包? 
    不懂得同学还是戳这里:http://baike.baidu.com/link?url=OWX7haiZHtuKD2hjbEBVouUGxKXIMvXDnKra0xDhxuz9zGttTg_JoRwmUcbrGD9Xp2BnbCJDF8BblaQffDBvblg0xNqgIOXCVZpQ7ZNBkWG

    其实我们要维护的凸包与这个凸包并无关系,只是在图中长得像罢了。 
    那为什么要维护凸包呢? 
    还要看图: 
    图3

    这就是一个下凸包,由图可见,最前面的那个点的截距最小,也诠释了维护凸包的真正含义(想一想优先队列,是不是队首最优?)。那还是有人会提出疑问,为什么非要维护这样的凸包呢? 答案就是,f[i]明显是递增的(相比于f[j] 加上一个代价),所以会在图中自然而然地显现出 Y 随着 X增加而增加的情况,呈现出凸包的模样。

    可能这个过程比较晦涩难懂,没懂的同学可以多看几遍。

    现在我们回到对  的分析

    现在假设我们正在枚举 j 来更新答案,有一个数 k, j < k < i 
    再来假设 k 比 j 优(之所以要假设正是要推出具体情况方便舍解)

    则有

    f[k] + sum[i] - sum[k] - a[k + 1] * (i - k) <= 
    f[j] + sum[i] - sum[j] - a[j + 1] * (i - j) (k > j)

    移项消项得 

    f[k] - sum[k] + a[k+ 1] * k - (f[j] - sum[j] + a[j + 1] * j) <= i * (a[k + 1] - a[j+ 1])

    将只与 i 有关的元素留下,剩下的除过去, 得到

    f[k] - sum[k] + a[k+ 1] * k - (f[j] - sum[j] + a[j + 1] * j) / (a[k + 1] - a[j + 1])<= i 

    (这里注意判断除数是否为负, 要变号,当然这里排序过后对于 k > j a[k] 肯定大于 a[j])

    这个式子什么意思呢?我用人类的语言解释一下。 
    设 Xi = a[i],            Yi = f[i] - sum[i] + a[i + 1] * i 
    那么原式即为如下形式:

    (Yk - Yj) / (Xk - Xj) <= i

    意思就是当有k 和 j 满足 j < k 的前提下 满足此不等式 
    则证明 j 没有 k 优

    而这个式子的左边数学意义是斜率, 而右边是一个递增的变量, 所以递增的 i 会淘汰队列里的元素, 而为了高效的淘汰, 我们会(在这道题里)选用斜率递增的单调队列,也就是上凸包。(再看看前面的图,是不是斜率在递增)

    我们还可以从另一个角度理解维护上凸包的理由

    仔细观察下面的图:

    一开始,1 号点的截距比2号点更优

    这里写图片描述

    随着斜率的变化,两个点的截距变得一样了

    然后,斜率接着变化,1号点开始没有2号点优了,所以要舍弃

    这里写图片描述

    后面的过程,3号点会渐渐超过2号点,并淘汰掉2号点

    这里写图片描述

    分析整个过程,最优点依次是 1 -> 2 -> 3,满足单调的要求

    但是如果是一个上凸包会怎样呢?

    这里只给出最终图(有兴趣的同学可以自己推一推),可以预见的是,在1赶超2前,3先赶超了2就破坏了顺序,因此不行

    这里写图片描述

    思路大概是清晰了,现在只剩下代码实现方面的问题了

    下面就看看单调队列的操作

    先将推出的X, Y用函数表示方便计算: 
    X:

    dnt X( int i, int j )
    {
        return a[j + 1] - a[i + 1];
    }
     
    • 1
    • 2
    • 3
    • 4

    (dnt 是 typedef 的 long long)

    Y:

    dnt Y( int i, int j )
    {
        return f[j] - sum[j] + j * a[j + 1] - (f[i] - sum[i] + i * a[i + 1]);
    }
     
    • 1
    • 2
    • 3
    • 4

    处理队首:

    for(; h + 1 < t && Y(Q[h + 1], Q[h + 2]) <= i * X(Q[h + 1], Q[h + 2]); h++);
     
    • 1

    从队尾维护单调性: 
    (这里是一个下凸包所以两点之间的斜率要递增,即 斜率(1, 2) < 斜率(2, 3), 前一个斜率比后一个小)

    for(; h + 1 < t && Y(Q[t - 1], Q[t]) * X(Q[t], cur) >= X(Q[t - 1], Q[t]) * Y(Q[t], cur); t--);
     
    • 1

    (注意,要把除法写成乘的形式,不然精度可能会出问题)

    斜率优化部分已经完结(说起来挺复杂其实代码很短),接下来就放出AC代码:

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <iostream>
    using namespace std;
    
    typedef long long dnt;
    
    int n, T, Q[405005];
    dnt sum[405005], f[405005], a[405005];
    
    dnt Y( int i, int j )
    {
        return f[j] - sum[j] + j * a[j + 1] - (f[i] - sum[i] + i * a[i + 1]);
    }
    
    dnt X( int i, int j )
    {
        return a[j + 1] - a[i + 1];
    }
    
    dnt DP( int i, int j )
    {
        return f[j] + (sum[i] - sum[j]) - (i - j) * a[j + 1];
    }
    
    inline dnt R()
    {
        static char ch;
        register dnt res, T = 1;
        while( ( ch = getchar() ) < '0'  || ch > '9' )if( ch == '-' )T = -1; 
            res = ch - 48;
        while( ( ch = getchar() ) <= '9' && ch >= '0')
            res = res * 10 + ch - 48;
        return res*T;
    }
    
    int main()
    {
        sum[0] = 0;
        while(~scanf( "%d%d", &n, &T ))
        {
            a[0] = 0, f[0] = 0;
            for(int i = 1; i <= n; i++)
                scanf( "%I64d", &a[i] );
            sort(a + 1, a + n + 1);
            for(int i = 1; i <= n; i++)
                sum[i] = sum[i - 1] + a[i];
            int h = 0, t = 0;
            Q[++t] = 0;
            for(int i = 1; i <= n; i++)
            {
                int cur = i - T + 1;
                for(; h + 1 < t && Y(Q[h + 1], Q[h + 2]) <= i * X(Q[h + 1], Q[h + 2]); h++);
                f[i] = DP(i, Q[h + 1]);
                if(cur < T) continue;
                for(; h + 1 < t && Y(Q[t - 1], Q[t]) * X(Q[t], cur) >= X(Q[t - 1], Q[t]) * Y(Q[t], cur); t--);
                Q[++t] = cur;
            }
            printf( "%I64d
    ", f[n] );
        }   
        return 0;
    }

     我自己的版本:

    #include<cstdio>
    #include<cstdlib>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    using namespace std;
    const int maxn=800010;
    long long  dp[maxn],q[maxn];
    long long  a[maxn],sum[maxn];
    long long getdp(long long i,long long j)
    {
        return dp[j]+(sum[i]-sum[j])-a[j+1]*(i-j);
    }
    long long getdy(long long j,long long k)//得到 yj-yk  k<j
    {
        return dp[j]-sum[j]+j*a[j+1]-(dp[k]-sum[k]+k*a[k+1]);
    }
    long long getdx(long long j,long long k)//得到 xj-xk  k<j
    {
        return a[j+1]-a[k+1];
    }
    int main()
    {
        long long i,j,n,k,head,tail,m;
        while(~scanf("%lld%lld",&n,&m)){
            head=tail=0;
            sum[0]=q[0]=dp[0]=q[1]=0;
            for(i=1;i<=n;i++) scanf("%lld",&a[i]);        
            sort(a+1,a+n+1);
            for(i=1;i<=n;i++) sum[i]=sum[i-1]+a[i];
            for(i=1;i<=n;i++){
                           //删去队首斜率小于目前斜率的点 
                while(head<tail&&(getdy(q[head+1],q[head])<=i*getdx(q[head+1],q[head]))) head++;
                dp[i]=getdp(i,q[head]);
                j=i-m+1;
                if(j<m) continue;
                //接下来是对j而不是i进行处理 ,保证了间隔大于m-1的要求 
                while(head<tail&&(getdy(j,q[tail])*getdx(j,q[tail-1])<=getdy(j,q[tail-1])*getdx(j,q[tail]))) tail--;
                q[++tail]=j;
            }
            printf("%lld
    ",dp[n]);
        }
        return 0;
    }
    View Code
  • 相关阅读:
    二分图匹配(匈牙利算法)
    最长共公子序列(LCS)
    网页常用Js代码
    linux 服务器常用命令整理
    阿里云学生服务器搭建网站-Ubuntu16.04安装php开发环境
    BAT批处理中的字符串处理详解(字符串截取)
    DOS批处理高级教程(还不错)(转)
    EntityFramework的linq扩展where
    RestSharp发送请求得到Json数据
    socket
  • 原文地址:https://www.cnblogs.com/hua-dong/p/7818231.html
Copyright © 2011-2022 走看看