首页 > 代码库 > bzoj1558 [JSOI2009]等差数列

bzoj1558 [JSOI2009]等差数列

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=1558

【题解】

这题恶心死人了啊。。

网络上题解很多都是看代码看代码。。真是太不负责任了。。我这里详细说一下吧。。

题解在代码下面。

技术分享
# include <stdio.h>
# include <string.h>
# include <iostream>
# include <algorithm>
// # include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
const int M = 5e5 + 10;
const int mod = 1e9+7;

# define RG register
# define ST static

int n, a[M], b[M];

struct pa {
    int s, ll, rr, sz;
    // 连续(包括中间散的),左散,右散 
    long long l, r;
    pa() {}
    pa(int s, int ll, int rr, int sz, long long l, long long r) : s(s), ll(ll), rr(rr), sz(sz), l(l), r(r) {}
    friend pa operator + (pa a, pa b) {
        pa c; int f = (a.r == b.l);
        c.s = a.s + b.s; c.sz = a.sz + b.sz;
        c.l = a.l, c.r = b.r;
        if(a.s == 0 && b.s == 0) {
            if(!f) c.ll = c.rr = c.sz;
            else {
                c.ll = a.ll-1;
                c.rr = b.rr-1;
                ++c.s;
            }
            return c;
        }
        if(a.s == 0) {
            c.rr = b.rr;
            if(!f) c.ll = a.sz + b.ll;
            else {
                c.ll = a.ll-1;
                if(b.ll > 0) c.s += (b.ll-1)/2 + 1;
            }
            return c;
        }
        if(b.s == 0) {
            c.ll = a.ll;
            if(!f) c.rr = a.rr + b.sz;
            else {
                c.rr = b.rr-1;
                if(a.rr > 0) c.s += (a.rr-1)/2 + 1;
            }
            return c;
        }
        
        c.ll = a.ll, c.rr = b.rr;
        
        if(a.rr == 0 && b.ll == 0) {
            if(f) --c.s; 
            return c;
        }        
        if(a.rr == 0) {
            if(f) c.s += (b.ll-1)/2;
            else c.s += b.ll/2;
            return c;
        }        
        if(b.ll == 0) {
            if(f) c.s += (a.rr-1)/2;
            else c.s += a.rr/2; 
            return c;
        }
        
        int d = (a.rr + b.ll)/2;
        if(f) d = min(d, 1 + (a.rr-1)/2 + (b.ll-1)/2);
        
        c.s += d;
        
        return c;
    }
};

namespace SMT {
    pa w[M];
    ll tag[M];
    # define ls (x<<1)
    # define rs (x<<1|1)
    inline void up(int x) {
        if(!x) return;
        if(!rs) {w[x] = w[ls]; return;}
        if(!ls) {w[x] = w[rs]; return;}
        w[x] = w[ls] + w[rs];
    }
    inline void pushtag(int x, ll d) {
        w[x].l += d, w[x].r += d;
        tag[x] += d;
    }
    inline void down(int x) {
        if(!x) return;
        if(!tag[x]) return;
        pushtag(ls, tag[x]);
        pushtag(rs, tag[x]);
        tag[x] = 0;
    }
    inline void build(int x, int l, int r) {
        tag[x] = 0;
        if(l == r) {
            w[x] = pa(0, 1, 1, 1, b[l], b[l]);
            return ;
        }
        int mid = l+r>>1;
        build(ls, l, mid);
        build(rs, mid+1, r);
        up(x);
    }
    inline void edt(int x, int l, int r, int L, int R, ll d) {
        if(L <= l && r <= R) {
            pushtag(x, d);
            return ;
        }
        down(x);
        int mid = l+r>>1;
        if(L <= mid) edt(ls, l, mid, L, R, d);
        if(R > mid) edt(rs, mid+1, r, L, R, d);
        up(x);
    }
    inline pa query(int x, int l, int r, int L, int R) {
        if(L <= l && r <= R) return w[x];
        down(x);
        int mid = l+r>>1;
        if(R <= mid) return query(ls, l, mid, L, R);
        else if(L > mid) return query(rs, mid+1, r, L, R);
        else return query(ls, l, mid, L, mid) + query(rs, mid+1, r, mid+1, R);
    }
    inline void debug(int x, int l, int r) {
        printf("x=%d, l=%d, r=%d:  sum = %d, size = %d, left = %d, right = %d, lnum = %lld, rnum = %lld\n", x, l, r, w[x].s, w[x].sz, w[x].ll, w[x].rr, w[x].l, w[x].r);
        if(l==r) return;
        down(x);
        int mid = l+r>>1;
        debug(ls, l, mid);
        debug(rs, mid+1, r);
    }
}

int main() {
    int Q, l, r, a1, d;
    char opt[23];
    pa t;
    cin >> n;
    for (int i=1; i<=n; ++i) scanf("%d", a+i);
    for (int i=1; i<n; ++i) b[i] = a[i+1] - a[i];
//    for (int i=1; i<n; ++i) printf("%d ", b[i]); puts("");
    cin >> Q;
    if(n == 1) {
        while(Q--) {
            scanf("%s", opt);
            if(opt[0] == A) scanf("%*d%*d%*d%*d");
            if(opt[0] == B) {scanf("%*d%*d"); puts("1");}
        }
        return 0;
    }
    SMT::build(1, 1, n-1);
//    SMT::debug(1, 1, n-1); 
    while(Q--) {
        scanf("%s", opt);
        if(opt[0] == A) {
            scanf("%d%d%d%d", &l, &r, &a1, &d);
            if(l != 1) SMT::edt(1, 1, n-1, l-1, l-1, a1);
            if(l <= r-1) SMT::edt(1, 1, n-1, l, r-1, d);
            if(r != n) SMT::edt(1, 1, n-1, r, r, -1ll * (r-l)*d - a1); 
        }
        if(opt[0] == B) {
            scanf("%d%d", &l, &r);
            if(l == r) puts("1");
            else {
                t = SMT::query(1, 1, n-1, l, r-1);
                int ans = (r-l+1+1)/2;
                if(t.s == 0) printf("%d\n", ans);
                else {
                    ans = min(ans, t.s + (t.ll+1)/2 + (t.rr+1)/2);
                    printf("%d\n", ans);
                }
            }
        }
//        SMT::debug(1, 1, n-1); 
    }
    return 0;
}
View Code

就是这里这里是题解啦!

首先一段加等差数列。这个可以用线段树直接维护,但是为了询问方便,我们选择线段树维护差分数组

比如:

n=6, a[]={1,2,4,7,11,16}

那么维护的就是b[]={1,2,3,4,5},b[i] = a[i+1] - a[i]

对于b建立线段树进行维护。

由于我们发现n=1没有差分。。要特判下。

那么现在我们考虑修改:在[a,b]上加首项a1公差d的等差数列。由于我们维护的是差分数组,这个就很好办了。

假装[a,b]很长,那么中间的部分差分增加的是公差d(前一个+x+d,后一个+x+d+d)。

稍微推导下(这部分可以自己举例子,自行算算,特别是边界条件),即可得到:

修改:b[l-1] += a1, b[l...r-1] += d, b[r] -= (r-l)*d+a1

这里需要判断:如果我们修改了区间[1,?],那么b[l-1]不需要改,因为第一个地方没有差分。

如果我们修改了区间[x,x],就不需要进行第二个修改。等等。

至于第三个减,是为了维护后面差分稳定,容易看出a[r] += (r-l)*d+a1,所以a[r+1]-a[r]要减少(r-1)*d+a1

我们讨论完了修改我们开始讨论询问吧。

询问求的是最少分成多少个等差数列。由于是段,连续的,就比序列好办多了。

这不是简单求[l,r]区间的差分数组有多少个连续的段的问题。

因为,比如很简单的例子:

差分数组为b[]={1,2,3,4},a[]={x,x+1,x+3,x+6,x+10}

那么答案不是4段,而是3段,为什么?因为{x,x+1}和{x+3,x+6}和{x+10}都是等差,所以我们要判断2个数构成等差这个情况。

所以我们要多维护东西。

首先可以确定的是,一个区间的中间部分如果确定是哪些等差数列,以后是不会改变的。

所以我们维护:

1. 区间中的等差数列段数。

2. 区间左边剩下的零散数的个数,以及右边剩下的零散数的个数

3. 区间左边的数是什么,区间右边的数是什么

4. 区间大小。

假设区间为[l,r],那么我们最后的答案就是线段树查询[l,r-1](因为差分过了所以要-1)出来后:

1. 要么是(r-l+1+1)/2(r-l+1为区间长度,这种分法是两个数两个数分)

2. 要么是询问出来的区间中的等差数列段数加上区间左边剩下的零散数的个数,以及右边剩下的零散数的个数按照1中方法分成等差数列的个数。

至于两边的个数怎么对应到段上,可以通过画几个例子来解决。

好了现在只要考虑线段树内和查询时,数据合并了。

 1 struct pa {
 2     int s, ll, rr, sz;
 3     // 连续(包括中间散的),左散,右散 
 4     long long l, r;
 5     pa() {}
 6     pa(int s, int ll, int rr, int sz, long long l, long long r) : s(s), ll(ll), rr(rr), sz(sz), l(l), r(r) {}
 7     friend pa operator + (pa a, pa b) {
 8         pa c; int f = (a.r == b.l);
 9         c.s = a.s + b.s; c.sz = a.sz + b.sz;
10         c.l = a.l, c.r = b.r;
11         if(a.s == 0 && b.s == 0) {
12             if(!f) c.ll = c.rr = c.sz;
13             else {
14                 c.ll = a.ll-1;
15                 c.rr = b.rr-1;
16                 ++c.s;
17             }
18             return c;
19         }
20         if(a.s == 0) {
21             c.rr = b.rr;
22             if(!f) c.ll = a.sz + b.ll;
23             else {
24                 c.ll = a.ll-1;
25                 if(b.ll > 0) c.s += (b.ll-1)/2 + 1;
26             }
27             return c;
28         }
29         if(b.s == 0) {
30             c.ll = a.ll;
31             if(!f) c.rr = a.rr + b.sz;
32             else {
33                 c.rr = b.rr-1;
34                 if(a.rr > 0) c.s += (a.rr-1)/2 + 1;
35             }
36             return c;
37         }
38         
39         c.ll = a.ll, c.rr = b.rr;
40         
41         if(a.rr == 0 && b.ll == 0) {
42             if(f) --c.s; 
43             return c;
44         }        
45         if(a.rr == 0) {
46             if(f) c.s += (b.ll-1)/2;
47             else c.s += b.ll/2;
48             return c;
49         }        
50         if(b.ll == 0) {
51             if(f) c.s += (a.rr-1)/2;
52             else c.s += a.rr/2; 
53             return c;
54         }
55         
56         int d = (a.rr + b.ll)/2;
57         if(f) d = min(d, 1 + (a.rr-1)/2 + (b.ll-1)/2);
58         
59         c.s += d;
60         
61         return c;
62     }
63 };

我们看上述代码就是线段树内结构体以及怎么合并。

首先sz, l,r都直接合并。

然后要考虑的就是ll和rr还有s要怎么合并的问题了。

先处理左边/右边没有连续段的问题(都是散的)

分三类:都没有,左没有,右没有

这部分直接看,是11~37行

现在左边/右边都有连续段了,要处理的是:

左边段紧挨着边界这种情况,也就是ll=0或rr=0这种情况。

分三类:a.rr=0&&b.ll=0,a.rr=0,b.ll=0

第一类显然两个段如果a.r=b.l的话就能直接合并成一个段了,否则就正常合并,在第41~44行。

后两类,如果a.r=b.l的时候,剩下的一个(大于0的b.ll或a.rr)就可以少一个数进行划分了。在第45~54行。

否则,就按照正常的划分方法:要么全部划分成2个2个的,要么把中间相等的提出来,左右划分。

懒得讨论提出来是不是一定优了。。就取min呗

然后……呼终于说完了

 

这题挺好的。。细节贼多

 

bzoj1558 [JSOI2009]等差数列