首页 > 代码库 > hdu1024 Max Sum Plus Plus

hdu1024 Max Sum Plus Plus

传送门:http://acm.hdu.edu.cn/showproblem.php?pid=1024

http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1053

【题解】

本题也是51nod 1053

最大m子段和

和上题很像

如果非负数的个数小于段数 m,那么直接输出最大的 m 个数之和。

否则

考虑线段树,做法和 bzoj3638 相同:每次取出当前的最大子段和累加进答案,再把这一段整体取反,重复 m 次。

如果贪心过程中途跳出(即当前最大子段和已经小于 0),那么已经选取的这些部分一定可以被拆分成恰好 m 段。

技术分享
# include <stdio.h>
# include <string.h>
# include <iostream>
# include <algorithm>
// # include <bits/stdc++.h>

using namespace std;

// Integer/floating aliases; only `ll` is referenced by the code visible in this file.
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// Size of the input array `a` below. Because `a` is filled from index 1,
// at most M - 1 elements fit.
const int M = 5e5 + 10;
// Not referenced by the code visible in this file.
const int mod = 1e9+7;

// Shorthand keyword macros; neither is used by the code visible in this file.
# define RG register
# define ST static

// n: number of elements in the current test case (read in main).
// a: the input sequence, stored 1-indexed in a[1..n].
int n, a[M];

// A contiguous segment [l, r] of the sequence together with its sum x.
// Comparisons look at the sum only; the endpoints are carried along so the
// caller can tell which segment produced a given value.
struct pa {
    int l, r;
    ll x;

    // Leaves all fields unset; callers assign them afterwards.
    pa() {}
    pa(int left, int right, ll sum) : l(left), r(right), x(sum) {}

    // Concatenation: spans from the start of `a` to the end of `b`,
    // with the two sums added.
    friend pa operator + (pa a, pa b) {
        pa joined;
        joined.l = a.l;
        joined.r = b.r;
        joined.x = a.x + b.x;
        return joined;
    }

    // Ordering by sum only.
    friend bool operator < (pa a, pa b) { return a.x < b.x; }
    friend bool operator > (pa a, pa b) { return a.x > b.x; }
};

struct querys {
    pa lmx, rmx, mx, s;
    querys() {}
    querys(pa lmx, pa rmx, pa mx, pa s) : lmx(lmx), rmx(rmx), mx(mx), s(s) {}
};

// Segment tree over positions 1..n. Each node stores, for its range, the
// maximum and minimum prefix / suffix / sub-segment (with their endpoints)
// and the whole-range sum. Supported operations:
//   change  - assign a value to one position,
//   change2 - negate every value in a range (lazy),
//   query   - maximum-sum sub-segment of a range, with its endpoints.
// The minimum variants exist so that a negation can be applied to a node in
// O(1): negating a range turns its minima into maxima and vice versa.
namespace SMT {
    // A tree built by midpoint splitting over n leaves has depth
    // ceil(log2 n), so every node index is below 2^(ceil(log2 n) + 1).
    // The input array holds fewer than 2^19 elements, hence indices stay
    // below 2^20. (The previous bound of 1e6 + 10 was smaller than that and
    // was overrun for large n.) Inputs with n > 524288 need a larger bound
    // here as well as a larger input array.
    const int Ms = (1 << 20) + 10;
    pa lmx[Ms], rmx[Ms], lmi[Ms], rmi[Ms], mx[Ms], mi[Ms], s[Ms];
    // tag[x] == true: node x itself is already negated, but the negation
    // has not yet been pushed to its children.
    bool tag[Ms];
    # define ls (x<<1)
    # define rs (x<<1|1)

    // Recomputes node x from its two children.
    // Argument order of max/min is significant: on equal sums it decides
    // which segment (and therefore which endpoints) is kept.
    inline void up(int x) {
        if(!x) return ;
        // Best prefix: either entirely in the left child, or the whole left
        // child followed by the right child's best prefix. Suffix is symmetric.
        lmx[x] = std::max(lmx[ls], s[ls] + lmx[rs]);
        lmi[x] = std::min(lmi[ls], s[ls] + lmi[rs]);
        rmx[x] = std::max(rmx[rs], rmx[ls] + s[rs]);
        rmi[x] = std::min(rmi[rs], rmi[ls] + s[rs]);
        // Best sub-segment: inside one child, or straddling the midpoint.
        mx[x] = std::max(mx[ls], mx[rs]);
        mx[x] = std::max(mx[x], rmx[ls] + lmx[rs]);
        mi[x] = std::min(mi[ls], mi[rs]);
        mi[x] = std::min(mi[x], rmi[ls] + lmi[rs]);
        s[x] = s[ls] + s[rs];
    }

    // Negates the whole range of node x in place and toggles its lazy tag.
    inline void pushtag(int x) {
        if(!x) return ;
        lmx[x].x = -lmx[x].x;
        rmx[x].x = -rmx[x].x;
        lmi[x].x = -lmi[x].x;
        rmi[x].x = -rmi[x].x;
        mx[x].x = -mx[x].x;
        mi[x].x = -mi[x].x;
        s[x].x = -s[x].x;
        // After flipping signs the former minima are the maxima and vice versa.
        std::swap(mx[x], mi[x]);
        std::swap(lmx[x], lmi[x]);
        std::swap(rmx[x], rmi[x]);
        tag[x] ^= 1;
    }

    // Pushes a pending negation from node x to both children.
    inline void down(int x) {
        if(!x) return ;
        if(!tag[x]) return ;
        pushtag(ls); pushtag(rs);
        tag[x] = 0;
    }

    // Point assignment: sets position ps to value d.
    // x is the current node and [l, r] the range it covers.
    inline void change(int x, int l, int r, int ps, int d) {
        if(l == r) {
            // A leaf is a one-element segment: every statistic is [l, l] with sum d.
            s[x].l = s[x].r = lmx[x].l = lmx[x].r = rmx[x].l = rmx[x].r = lmi[x].l = lmi[x].r = rmi[x].l = rmi[x].r = l;
            mx[x].l = mx[x].r = mi[x].l = mi[x].r = l;
            s[x].x = mx[x].x = mi[x].x = lmx[x].x = lmi[x].x = rmx[x].x = rmi[x].x = d;
            tag[x] = 0;
            return ;
        }
        down(x);
        int mid = (l + r) >> 1;
        if(ps <= mid) change(ls, l, mid, ps, d);
        else change(rs, mid+1, r, ps, d);
        up(x);
    }

    // Range negation: multiplies every value in [L, R] by -1.
    // x is the current node and [l, r] the range it covers.
    inline void change2(int x, int l, int r, int L, int R) {
        if(L <= l && r <= R) {
            pushtag(x);
            return ;
        }
        down(x);
        int mid = (l + r) >> 1;
        if(L <= mid) change2(ls, l, mid, L, R);
        if(R > mid) change2(rs, mid+1, r, L, R);
        up(x);
    }

    // Combines the summaries of two adjacent ranges (a to the left of b).
    // Only the maximum statistics are needed by callers, so minima are not tracked.
    inline querys merge(querys a, querys b) {
        querys c;
        c.lmx = std::max(a.lmx, a.s+b.lmx);
        c.rmx = std::max(b.rmx, a.rmx+b.s);
        c.s = a.s + b.s;
        c.mx = std::max(a.mx, b.mx);
        c.mx = std::max(c.mx, a.rmx + b.lmx);
        return c;
    }

    // Returns the summary (best prefix / suffix / sub-segment and total) of [L, R].
    // x is the current node and [l, r] the range it covers.
    inline querys query(int x, int l, int r, int L, int R) {
        if(L <= l && r <= R) return querys(lmx[x], rmx[x], mx[x], s[x]);
        down(x);
        int mid = (l + r) >> 1;
        if(R <= mid) return query(ls, l, mid, L, R);
        else if(L > mid) return query(rs, mid+1, r, L, R);
        else return merge(query(ls, l, mid, L, mid), query(rs, mid+1, r, mid+1, R));
    }

    // Dumps every node's best sub-segment (sum and endpoints) in pre-order.
    // Pending tags are not pushed down first, so descendants of a tagged
    // node are printed in their not-yet-negated state.
    inline void debug(int x, int l, int r) {
        // mx[x].x is a long long, so it must be printed with %lld.
        printf("x = %d, l = %d, r = %d : mx = %lld, mx.l = %d, mx.r = %d\n", x, l, r, static_cast<long long>(mx[x].x), mx[x].l, mx[x].r);
        if(l == r) return ;
        int mid = (l + r) >> 1;
        debug(ls, l, mid);
        debug(rs, mid+1, r);
    }
}

// Number of segments to select in the current test case (read in main).
int m;

int main() {
    while(cin >> m >> n) {
        int su = 0;
        for (int i=1; i<=n; ++i) {
            scanf("%d", &a[i]);
            SMT::change(1, 1, n, i, a[i]);
            su += (a[i]>=0);
        }
        ll s = 0;
        if(su < m) {
            sort(a+1, a+n+1);
            for (int i=n, j=1; j<=m; j++, i--) s += a[i];
            cout << s << endl;
        } else {
            querys t;
            while(m--) {
                t = SMT::query(1, 1, n, 1, n);
                if(t.mx.x < 0) break;
                else s += t.mx.x;
                SMT::change2(1, 1, n, t.mx.l, t.mx.r);
            }
            cout << s << endl;
        }
    }
    return 0;
}
View Code

md,O(m log n) 的正解跑得比 O(nm) 还慢?这 hdu 的数据到底有多水?

upd:在 hdu 上提交时,代码里的 1e6+10 要改成 4e6+10(线段树的空间要开 4 倍),手癌晚期懒得改了

技术分享

 

hdu1024 Max Sum Plus Plus