首页 > 代码库 > 【NOI2005】维护数列

【NOI2005】维护数列

描述

请写一个程序,要求维护一个数列,支持以下6种操作:(请注意,格式栏中的下划线‘ _ ’表示实际输入文件中的空格)
技术分享

格式

输入格式

输入的第1 行包含两个数N 和M(M ≤20 000),N 表示初始时数列中数的个数,M表示要进行的操作数目。
第2行包含N个数字,描述初始时的数列。
以下M行,每行一条命令,格式参见问题描述中的表格。
任何时刻数列中最多含有500 000个数,数列中任何一个数字均在[-1 000, 1 000]内。
插入的数字总数不超过4 000 000个,输入文件大小不超过20MBytes。

输出格式

对于输入数据中的GET-SUM 和MAX-SUM 操作,依次输出结果,每个答案(数字)占一行。

样例1

样例输入1

9 8
2 -6 3 5 1 -5 -3 6 3
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM

样例输出1

-1
10
1
10

限制

每个测试点3s。

提示

样例说明:
技术分享
技术分享

来源

NOI 2005 Day 1

 

题解

伸展树维护数列的终极题目,共要维护两个标记和两个数列信息,为了维护MAX-SUM还要维护从左端开始的数列的最大和及到右端结束的数列的最大和。

按照伸展树的套路,给数列左右两边加上不存在的边界节点,给每个子树的空儿子指向哨兵节点。

维护最大子数列和

题目说的子数列其实要求至少包含一个元素,这要很恶心的维护。

(其实让max_sum可以不含元素也能过90%)

每个节点定义max_sum:该节点的最大数列和(至少包含一个元素)

max_lsum:该节点的从左端开始的最大数列和(可以不包含元素)

max_rsum:该节点的到右端结束的最大数列和(可以不包含元素)

按照分冶法,max_sum=max{左儿子max_sum,右儿子max_sum,左儿子max_rsum+该节点的值+右儿子max_lsum}。

如果它和它的左右儿子都是普通节点,这个转移保证至少有一个元素。

如果它是普通节点或边界节点,它的左或右儿子是哨兵节点,则左儿子max_sum或右儿子max_sum是不可取的。故令哨兵节点的max_sum=-inf。

如果它是边界节点,它必定至多有一个儿子,令它的max_sum等于它的唯一儿子的max_sum,max_lsum与max_rsum同理。

覆盖子数列和翻转子数列

每个节点定义两个标记replaced和reversed。

replaced:这个节点及它的所有后代都应该修改为一个特定的值,但实际上只有这个节点的值已经修改。

reversed:这个节点及它的所有后代都应该交换左右子树(max_lsum和max_rsum也应该跟着交换),但实际上只有这个节点的左右子树已经交换。

可见这两个标记是互斥的,且replaced标记的优先级显然大于reversed标记。

打标记的时候注意维护每个结点的标记至多有一个就可以了。

 

309行的不压行代码,7.85KB,调了近8小时才AC:

 

#include <algorithm>
#include <iostream>
#include <string>
using namespace std;
namespace splay
{
const int inf = 0x7fffffff;
enum direction
{
    l = 0,
    r
};
struct node;
node *nil = 0, *l_edge, *r_edge;
struct node
{
    int val, size;
    node *ch[2];
    int sum;
    int max_sum, max_lsum, max_rsum;
    // max_sum 定义为最少包含一个元素的最大子数列和
    // max_lsum 定义为从左端开始的可以不包含元素的最大子数列和
    // max_lsum 定义为到右端结束的可以不包含元素的最大子数列和

    bool replaced, reversed;
    // 当replaced为true,表示它的所有后代的val应该与这个节点的val相同,但实际上后代节点并没有更新
    // 当reversed为true,表示它已经交换了左右节点和左右最大值,且它的所有后代都应该交换左右子树和左右最大值,但实际上后代节点并没有更新

    node(int v) : val(v), size(1), sum(v), replaced(false), reversed(false)
    {
        ch[l] = ch[r] = nil;
        if (v >= 0)
            max_sum = max_lsum = max_rsum = sum;
        else
        {
            max_sum = v;
            max_lsum = max_rsum = 0;
        }
    }
    int cmp(int k)
    {
        if (k == ch[l]->size + 1 || this == nil)
            return -1;
        else
            return k <= ch[l]->size ? l : r;
    }

    void reverse()
    {
        if (!replaced)
        {
            reversed ^= 1;
            swap(ch[l], ch[r]);
            swap(max_lsum, max_rsum);
        }
    }
    void replace(int v)
    {
        reversed = false;
        replaced = true;
        val = v;
        sum = v * size;
        if (v > 0)
            max_sum = max_lsum = max_rsum = sum;
        else
        {
            max_sum = v; // 由于子数列要求至少有一个元素,故当 val < 0 ,只有一个元素时和最大
            max_lsum = max_rsum = 0;
        }
    }

    void push_down()
    {
        if (replaced)
        {
            ch[l]->replace(val);
            ch[r]->replace(val);
            replaced = false;
        }
        else if (reversed)
        {
            ch[l]->reverse();
            ch[r]->reverse();
            reversed = false;
        }
    }
    void pull_up()
    {
        if (this != nil)
        {
            size = ch[l]->size + ch[r]->size + 1;

            if (!replaced)
                sum = ch[l]->sum + ch[r]->sum + val;
            else
                sum = val * size;

            if (this != l_edge && this != r_edge)
            {
                max_sum = max(ch[l]->max_rsum + val + ch[r]->max_lsum, max(ch[l]->max_sum, ch[r]->max_sum)); // 更新后 max_sum 至少包含一个元素
                max_lsum = max(ch[l]->max_lsum, ch[l]->sum + val + ch[r]->max_lsum);                         // 更新后 max_lsum / max_rsum 可以不包含元素
                max_rsum = max(ch[r]->max_rsum, ch[l]->max_rsum + val + ch[r]->sum);
            }
            else if (this == l_edge) // 注意特判左右边界节点
            {
                // 若不特判,当左边界节点为根且整个数列的从左开始的最大值为0时
                // 就会出现 max_sum = ch[l]->max_rsum + val + ch[r]->max_lsum
                // 即 max_sum = 0,这显然不合法
                max_sum = ch[r]->max_sum;
                max_lsum = ch[r]->max_lsum;
                max_rsum = ch[r]->max_rsum;
            }
            else
            {
                // 右边界同理
                max_sum = ch[l]->max_sum;
                max_lsum = ch[l]->max_lsum;
                max_rsum = ch[l]->max_rsum;
            }
        }
    }

    void remove()
    {
        if (this != nil)
        {
            ch[l]->remove();
            ch[r]->remove();
            delete this;
        }
    }
} * root;
void init()
{
    if (!nil)
        nil = new node(0);
    l_edge = new node(0), r_edge = new node(0);
    nil->size = 0;
    nil->ch[l] = nil->ch[r] = nil;
    // 注意维持哨兵节点,边界节点的 max_sum 为负无穷,保证普通节点的 max_sum 合法
    nil->max_sum = -inf;
    l_edge->max_sum = -inf;
    r_edge->max_sum = -inf;
    root = nil;
}
void rotate(node *&t, int d)
{
    t->push_down();
    t->ch[l]->push_down();
    t->ch[r]->push_down();
    node *k = t->ch[d ^ 1];
    t->ch[d ^ 1] = k->ch[d];
    k->ch[d] = t;
    t->pull_up();
    k->pull_up();
    t = k;
}
void splay(node *&t, int k)
{
    t->push_down();
    int d = t->cmp(k);
    if (d == r)
        k = k - t->ch[l]->size - 1;
    if (d != -1)
    {
        t->ch[d]->push_down();
        int d2 = t->ch[d]->cmp(k);
        int k2 = (d2 == r) ? k - t->ch[d]->ch[l]->size - 1 : k;
        if (d2 != -1)
        {
            splay(t->ch[d]->ch[d2], k2);
            if (d == d2)
            {
                rotate(t, d ^ 1);
                rotate(t, d ^ 1);
            }
            else
            {
                rotate(t->ch[d], d2 ^ 1);
                rotate(t, d ^ 1);
            }
        }
        else
            rotate(t, d ^ 1);
    }
}
void join(node *&t1, node *&t2)
{
    if (t1 == nil)
        swap(t1, t2);
    splay(t1, t1->size);
    t1->ch[r] = t2;
    t2 = nil;
    t1->pull_up();
}
node *split(node *&t, int k)
{
    if (k == 0)
    {
        node *subtree = t;
        t = nil;
        return subtree;
    }
    splay(t, k);
    node *subtree = t->ch[r];
    t->ch[r] = nil;
    t->pull_up();
    return subtree;
}
node *build_tree(int *p, int n)
{
    if (n == 0)
        return nil;
    node *fa;
    node *ch = new node(p[1]);
    for (int i = 2; i <= n; i++)
    {
        fa = new node(p[i]);
        fa->ch[l] = ch;
        fa->pull_up();
        ch = fa;
    }
    return fa;
}
node *select(int p, int tot)
{
    int ln = p, rn = ln + tot - 1;
    splay(root, rn + 1);
    splay(root->ch[l], ln - 1);
    return root->ch[l]->ch[r];
}
}
int n, m;
int num[500005];
int main()
{
    using namespace splay;
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> num[i];
    init();

    node *t1, *t2; //tmp
    root = l_edge;
    t1 = build_tree(num, n);
    join(root, t1);
    t1 = r_edge;
    join(root, t1);

    string opt;
    int posi, tot, c;
    while (m--)
    {
        cin >> opt;
        switch (opt[0])
        {
        case I: //INSERT
            cin >> posi >> tot;
            posi++;
            for (int i = 1; i <= tot; i++)
                cin >> num[i];
            t1 = build_tree(num, tot);
            t2 = split(root, posi);
            join(root, t1);
            join(root, t2);
            break;
        case D: //DELETE
            cin >> posi >> tot;
            posi++;
            t1 = split(root, posi - 1);
            t2 = split(t1, tot);
            join(root, t2);
            t1->remove();
            break;
        case R: //REVERSE
            cin >> posi >> tot;
            posi++;
            t1 = select(posi, tot);
            t1->reverse();
            root->ch[l]->pull_up();
            root->pull_up();
            break;
        case G: //GET-SUM
            cin >> posi >> tot;
            posi++;
            t1 = select(posi, tot);
            cout << t1->sum << endl;
            break;
        case M:
            if (opt[2] == K) //MAKE_SAME
            {
                cin >> posi >> tot >> c;
                posi++;
                t1 = select(posi, tot);
                t1->replace(c);
                root->ch[l]->pull_up();
                root->pull_up();
            }
            else //MAX_SUM
                cout << root->max_sum << endl;
            break;
        case S:
            cin >> posi;
            splay::splay(root, posi);
            break;
        }
    }
    return 0;
}

 

【NOI2005】维护数列