zoukankan      html  css  js  c++  java
  • CodeForces 424D: ...(二分)

    题意:给出一个n*m的矩阵,内有一些数字。当你从一个方格走到另一个方格时,按这两个方格数字的大小,有(升,平,降)三种费用。你需要在矩阵中找到边长大于2的一个矩形,使得按这个矩形顺时针行走一圈的费用,与给定费用最接近。3<=n,m<=300。

    思路:O(1)计算一个矩形的费用不是什么难事,因为考虑到有前缀性质(前缀性质:[l,r] = [0,r] - [0,l-1]),只要预处理好各行各个方向行走的费用,就容易计算。

    直接枚举容易得到O(n^4)的算法。难以过。这时就应当想到优化。实际上,经过优化,可以得到O(n^3 *log(n))的算法。优化的方法如下:只枚举上下两层位置和右边界位置,正常思路是再枚举左边界位置,如果我们能二分得到左边界位置,就完美了。可惜直接二分并不满足性质。[本题关键点]这时需要构造一个前缀性质。如图

    细了就不说了。思考一下吧~。然后边扫边插入前面的前缀到set里面,然后用lower_bound就可以了。不过注意边界问题。

    代码:

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <cmath>
    #include <set>
    #include <vector>
    
    using namespace std;
    
    #define R 0
    #define L 1
    #define U 2
    #define D 3
    #define N 400
    int sum[4][N][N];
    int t[3];
    int mat[N][N];
    int n, m, goalt;
    
    void init() {
        for (int i = 0; i < n; i++) {
            sum[R][i][0] = sum[L][i][0] = 0;
            for (int j = 1; j < m; j++) {
                sum[R][i][j] = sum[R][i][j-1];
                sum[L][i][j] = sum[L][i][j-1];
                if (mat[i][j] == mat[i][j-1]) {
                    sum[R][i][j] += t[0];
                    sum[L][i][j] += t[0];
                    continue;
                }
                if (mat[i][j] > mat[i][j-1]) {
                    sum[R][i][j] += t[1];
                    sum[L][i][j] += t[2];
                    continue;
                }
                if (mat[i][j] < mat[i][j-1]) {
                    sum[R][i][j] += t[2];
                    sum[L][i][j] += t[1];
                    continue;
                }
            }
        }
    
        for (int j = 0; j < m; j++) {
            sum[U][0][j] = sum[D][0][j] = 0;
            for (int i = 1; i < n; i++) {
                sum[U][i][j] = sum[U][i-1][j];
                sum[D][i][j] = sum[D][i-1][j];
                if (mat[i][j] == mat[i-1][j]) {
                    sum[U][i][j] += t[0];
                    sum[D][i][j] += t[0];
                    continue;
                }
                if (mat[i-1][j] > mat[i][j]) {
                    sum[U][i][j] += t[1];
                    sum[D][i][j] += t[2];
                    continue;
                }
                if (mat[i-1][j] < mat[i][j]) {
                    sum[U][i][j] += t[2];
                    sum[D][i][j] += t[1];
                    continue;
                }
            }
        }
    }
    
    int ans[4];
    int minabsdis;
    typedef pair<int,int> pii;
    
    //#define FIX 10000000
    inline int getLeftVal(int iup, int idn, int j) {
        //int res = -(sum[U][idn][j] - sum[U][iup][j]) + sum[R][iup][j] + sum[L][idn][j];
        //printf("(%d,%d,%d) = %d
    ", iup+1, idn+1, j, res);
        return -(sum[U][idn][j] - sum[U][iup][j]) + sum[R][iup][j] + sum[L][idn][j];
    }
    void find() {
        minabsdis = 999999999;
        for (int iup = 0; iup < n; iup++) {
            for (int idn = iup+2; idn < n; idn++) {
                set< pair<int,int> > leftsum;
                leftsum.clear();
                set< pair<int,int> >::iterator spi;
                //printf("(%d,%d)
    ", iup+1, idn+1);
                leftsum.insert(pii(getLeftVal(iup,idn,0), 0));
                //printf("first = (%d,%d)
    ", (*leftsum.begin()).first,(*leftsum.begin()).second);
                for (int j = 2; j < m; j++) {
                    int now = sum[R][iup][j] + sum[L][idn][j] + sum[D][idn][j]-sum[D][iup][j];
                    int should = now - goalt;
                    //printf("(%d) should = %d
    ", j, should);
                    spi = leftsum.lower_bound(pii(should, 0));
                    if (spi == leftsum.end()) {
                        //puts("meet end");
                        spi--;
                    }
                    else if (spi != leftsum.begin()){
                        int rnum = now-(*spi).first;
                        spi--;
                        int lnum = now-(*spi).first;
                        spi++;
                        if (fabs(lnum-goalt) < fabs(rnum-goalt)) {
                            //puts("minus");
                            spi--;
                        }
                    }
                    pii findpair = *spi;
                    //printf("find (%d,%d)
    ", findpair.first, findpair.second);
                    int final = now - findpair.first;
                    if ((int)fabs(final-goalt) < minabsdis) {
                        //puts("lala");
                        minabsdis = fabs(final-goalt);
                        ans[0] = iup;
                        ans[1] = findpair.second;
                        ans[2] = idn;
                        ans[3] = j;
                    }
                    leftsum.insert(pii(getLeftVal(iup,idn,j-1), j-1));
                }
            }
        }
    }
    
    int gettype(int l, int r, bool rev) {
        if (rev) l^=r^=l^=r;
        if (l==r) return 0;
        if (l<r) return 1;
        if (l>r) return 2;
    }
    void checkAns() {
        int res = 0;
        for (int j = ans[1]+1; j <= ans[3]; j++) {
            res += t[gettype(mat[ans[0]][j-1], mat[ans[0]][j], false)];
            res += t[gettype(mat[ans[2]][j-1], mat[ans[2]][j], true)];
        }
        for (int i = ans[0]+1; i <= ans[2]; i++) {
            res += t[gettype(mat[i-1][ans[1]], mat[i][ans[1]], true)];
            res += t[gettype(mat[i-1][ans[3]], mat[i][ans[3]], false)];
        }
        if ((int)fabs(res-goalt) != minabsdis) printf("error!: check:%d, output:%d
    ", (int)fabs(res-goalt), minabsdis);
    }
    
    int main() {
        while (scanf("%d%d%d", &n, &m, &goalt) != EOF) {
            for (int i = 0; i < 3; i++) {
                scanf("%d", &t[i]);
            }
    
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    scanf("%d", &mat[i][j]);
                }
            }
    
            init();
    
            //for (int i = 0; i < n; i++) {
            //    printf("%d
    ", sum[U][i][0]);
            //}
    
    
            find();
    
            for (int i = 0; i < 4; i++) {
                printf("%d ", ans[i]+1);
            }puts("");
            //printf("minabsdis = %d
    ", minabsdis);
            checkAns();
        }
        return 0;
    }
  • 相关阅读:
    2019.10.07题解
    2019.10.06题解
    2019.10.05'题解
    2019.10.05题解
    java邮件发送
    注释类型 XmlType
    Spring 注解
    @SuppressWarnings(unchecked)作用解释
    vm文件
    Apache Shiro 使用手册(一)Shiro架构介绍
  • 原文地址:https://www.cnblogs.com/shinecheng/p/3773875.html
Copyright © 2011-2022 走看看