首页 > 代码库 > CodeForces 424D: ...(二分)

CodeForces 424D: ...(二分)

题意:给出一个n*m的矩阵,内有一些数字。当你从一个方格走到另一个方格时,按这两个方格数字的大小,有(升,平,降)三种费用。你需要在矩阵中找到边长大于2的一个矩形,使得按这个矩形顺时针行走一圈的费用,与给定费用最接近。3<=n,m<=300。

思路:O(1)计算一个矩形的费用不是什么难事,因为考虑到有前缀性质(前缀性质:[l,r] = [0,r] - [0,l-1]),只要预处理好各行各个方向行走的费用,就容易计算。

直接枚举容易得到O(n^4)的算法。难以过。这时就应当想到优化。实际上,经过优化,可以得到O(n^3 *log(n))的算法。优化的方法如下:只枚举上下两层位置和右边界位置,正常思路是再枚举左边界位置,如果我们能二分得到左边界位置,就完美了。可惜直接二分并不满足性质。[本题关键点]这时需要构造一个前缀性质。如图

细了就不说了。思考一下吧~。然后边扫边插入前面的前缀到set里面,然后用lower_bound就可以了。不过注意边界问题。

代码:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <set>
#include <vector>

using namespace std;

#define R 0
#define L 1
#define U 2
#define D 3
#define N 400
int sum[4][N][N];
int t[3];
int mat[N][N];
int n, m, goalt;

void init() {
    for (int i = 0; i < n; i++) {
        sum[R][i][0] = sum[L][i][0] = 0;
        for (int j = 1; j < m; j++) {
            sum[R][i][j] = sum[R][i][j-1];
            sum[L][i][j] = sum[L][i][j-1];
            if (mat[i][j] == mat[i][j-1]) {
                sum[R][i][j] += t[0];
                sum[L][i][j] += t[0];
                continue;
            }
            if (mat[i][j] > mat[i][j-1]) {
                sum[R][i][j] += t[1];
                sum[L][i][j] += t[2];
                continue;
            }
            if (mat[i][j] < mat[i][j-1]) {
                sum[R][i][j] += t[2];
                sum[L][i][j] += t[1];
                continue;
            }
        }
    }

    for (int j = 0; j < m; j++) {
        sum[U][0][j] = sum[D][0][j] = 0;
        for (int i = 1; i < n; i++) {
            sum[U][i][j] = sum[U][i-1][j];
            sum[D][i][j] = sum[D][i-1][j];
            if (mat[i][j] == mat[i-1][j]) {
                sum[U][i][j] += t[0];
                sum[D][i][j] += t[0];
                continue;
            }
            if (mat[i-1][j] > mat[i][j]) {
                sum[U][i][j] += t[1];
                sum[D][i][j] += t[2];
                continue;
            }
            if (mat[i-1][j] < mat[i][j]) {
                sum[U][i][j] += t[2];
                sum[D][i][j] += t[1];
                continue;
            }
        }
    }
}

int ans[4];
int minabsdis;
typedef pair<int,int> pii;

//#define FIX 10000000
inline int getLeftVal(int iup, int idn, int j) {
    //int res = -(sum[U][idn][j] - sum[U][iup][j]) + sum[R][iup][j] + sum[L][idn][j];
    //printf("(%d,%d,%d) = %d\n", iup+1, idn+1, j, res);
    return -(sum[U][idn][j] - sum[U][iup][j]) + sum[R][iup][j] + sum[L][idn][j];
}
void find() {
    minabsdis = 999999999;
    for (int iup = 0; iup < n; iup++) {
        for (int idn = iup+2; idn < n; idn++) {
            set< pair<int,int> > leftsum;
            leftsum.clear();
            set< pair<int,int> >::iterator spi;
            //printf("(%d,%d)\n", iup+1, idn+1);
            leftsum.insert(pii(getLeftVal(iup,idn,0), 0));
            //printf("first = (%d,%d)\n", (*leftsum.begin()).first,(*leftsum.begin()).second);
            for (int j = 2; j < m; j++) {
                int now = sum[R][iup][j] + sum[L][idn][j] + sum[D][idn][j]-sum[D][iup][j];
                int should = now - goalt;
                //printf("(%d) should = %d\n", j, should);
                spi = leftsum.lower_bound(pii(should, 0));
                if (spi == leftsum.end()) {
                    //puts("meet end");
                    spi--;
                }
                else if (spi != leftsum.begin()){
                    int rnum = now-(*spi).first;
                    spi--;
                    int lnum = now-(*spi).first;
                    spi++;
                    if (fabs(lnum-goalt) < fabs(rnum-goalt)) {
                        //puts("minus");
                        spi--;
                    }
                }
                pii findpair = *spi;
                //printf("find (%d,%d)\n", findpair.first, findpair.second);
                int final = now - findpair.first;
                if ((int)fabs(final-goalt) < minabsdis) {
                    //puts("lala");
                    minabsdis = fabs(final-goalt);
                    ans[0] = iup;
                    ans[1] = findpair.second;
                    ans[2] = idn;
                    ans[3] = j;
                }
                leftsum.insert(pii(getLeftVal(iup,idn,j-1), j-1));
            }
        }
    }
}

int gettype(int l, int r, bool rev) {
    if (rev) l^=r^=l^=r;
    if (l==r) return 0;
    if (l<r) return 1;
    if (l>r) return 2;
}
void checkAns() {
    int res = 0;
    for (int j = ans[1]+1; j <= ans[3]; j++) {
        res += t[gettype(mat[ans[0]][j-1], mat[ans[0]][j], false)];
        res += t[gettype(mat[ans[2]][j-1], mat[ans[2]][j], true)];
    }
    for (int i = ans[0]+1; i <= ans[2]; i++) {
        res += t[gettype(mat[i-1][ans[1]], mat[i][ans[1]], true)];
        res += t[gettype(mat[i-1][ans[3]], mat[i][ans[3]], false)];
    }
    if ((int)fabs(res-goalt) != minabsdis) printf("error!: check:%d, output:%d\n", (int)fabs(res-goalt), minabsdis);
}

int main() {
    while (scanf("%d%d%d", &n, &m, &goalt) != EOF) {
        for (int i = 0; i < 3; i++) {
            scanf("%d", &t[i]);
        }

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                scanf("%d", &mat[i][j]);
            }
        }

        init();

        //for (int i = 0; i < n; i++) {
        //    printf("%d\n", sum[U][i][0]);
        //}


        find();

        for (int i = 0; i < 4; i++) {
            printf("%d ", ans[i]+1);
        }puts("");
        //printf("minabsdis = %d\n", minabsdis);
        checkAns();
    }
    return 0;
}