首页 > 代码库 > HDU 1540 && POJ 2892 Tunnel Warfare (线段树,区间合并).

HDU 1540 && POJ 2892 Tunnel Warfare (线段树,区间合并).

~~~~

第一次遇到线段树合并的题,又被律爷教做人。TAT.

~~~~

线段树的题意都很好理解吧。。

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=1540

http://poj.org/problem?id=2892

~~~~

我的代码:200ms

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 55555
#define lson rt<<1,s,m
#define rson rt<<1|1,m+1,e
using namespace std;

int st[N];
struct node
{
    int l,r;
    int lm,rm;  //维护一个区间从左右端点开始的最长区间
}tre[N<<2];
void build(int rt,int s,int e)
{
    tre[rt].l=s;
    tre[rt].r=e;
    tre[rt].lm=tre[rt].rm=e-s+1;
    if(s==e)
        return ;
    int m=(s+e)>>1;
    build(lson);
    build(rson);
}
void update(int p,int v,int rt,int s,int e)
{
    if(s==e)
    {
        if(v) tre[rt].lm=tre[rt].rm=1; //rebuild
        else tre[rt].lm=tre[rt].rm=0;  //destroy
        return ;
    }
    int m=(s+e)>>1;
    if(p<=m) update(p,v,lson);
    else update(p,v,rson);
    tre[rt].lm=tre[rt<<1].lm;
    tre[rt].rm=tre[rt<<1|1].rm;
    
    //若lm==区间长,还要加上右孩子的lm。
    if(tre[rt<<1].lm==tre[rt<<1].r-tre[rt<<1].l+1)  
        tre[rt].lm=tre[rt<<1].lm+tre[rt<<1|1].lm;
    if(tre[rt<<1|1].rm==tre[rt<<1|1].r-tre[rt<<1|1].l+1)   //同理~
        tre[rt].rm=tre[rt<<1|1].rm+tre[rt<<1].rm;
}
int query(int q,int rt,int s,int e)
{
    //总是把查询操作归结到一个父亲节点下的两个孩子节点的中间区域的最长连续区间。
    if(s==e) return tre[rt].lm;
    int m=(s+e)>>1;
    int l=m-tre[rt<<1].rm;
    int r=m+1+tre[rt<<1|1].lm;
    if(q>l && q<r) //要查找的端点就在该区间的连续区间中,则返回。
        return tre[rt<<1].rm+tre[rt<<1|1].lm;
    else if(q<=l) query(q,lson);
    else query(q,rson);
}
int main()
{
    int n,m;
    while(~scanf("%d%d",&n,&m))
    {
        int top=-1;
        build(1,1,n);
        for(int i=0;i<m;i++)
        {
            int x;
            char str[5];
            scanf("%s",str);
            if(str[0]=='D')
            {
                scanf("%d",&x);
                update(x,0,1,1,n);
                st[++top]=x;
            }
            else if(str[0]=='R')
            {
                int y=st[top--];
                if(top>=-1)
                    update(y,1,1,1,n);
            }
            else if(str[0]=='Q')
            {
                scanf("%d",&x);
                int k=query(x,1,1,n);
                printf("%d\n",k);
            }
        }
    }
    return 0;
}


~~~~

之前参考网上代码所写:300ms;


#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 55555
#define lson rt<<1,s,m
#define rson rt<<1|1,m+1,e
using namespace std;

int st[N];
struct node
{
    int l,r;
    int lm,rm,sm; //还要维护一个区间的最长连续区间
}tre[N<<2];
void build(int rt,int s,int e)
{
    tre[rt].l=s;
    tre[rt].r=e;
    tre[rt].lm=tre[rt].rm=tre[rt].sm=e-s+1;
    if(s==e)
        return ;
    int m=(s+e)>>1;
    build(lson);
    build(rson);
}
void update(int p,int v,int rt,int s,int e)
{
    if(s==e)
    {
        if(v) tre[rt].lm=tre[rt].rm=tre[rt].sm=1;
        else tre[rt].lm=tre[rt].rm=tre[rt].sm=0;
        return ;
    }
    int m=(s+e)>>1;
    if(p<=m) update(p,v,lson);
    else update(p,v,rson);
    tre[rt].lm=tre[rt<<1].lm;
    tre[rt].rm=tre[rt<<1|1].rm;
    tre[rt].sm=max(max(tre[rt<<1].sm,tre[rt<<1|1].sm),tre[rt<<1].rm+tre[rt<<1|1].lm);
    if(tre[rt<<1].lm==tre[rt<<1].r-tre[rt<<1].l+1)
        tre[rt].lm=tre[rt<<1].lm+tre[rt<<1|1].lm;
    if(tre[rt<<1|1].rm==tre[rt<<1|1].r-tre[rt<<1|1].l+1)
        tre[rt].rm=tre[rt<<1|1].rm+tre[rt<<1].rm;
}
int query(int q,int rt,int s,int e)
{
    //到达叶子节点或是当前节点为空或为满的情况,直接返回。
    if(s==e || tre[rt].sm==0 || tre[rt].sm==tre[rt].r-tre[rt].l+1)
        return tre[rt].sm;
    int m=(s+e)>>1;
    if(q<=m)
    {
        //若是在做孩子的右连续区间,那么还要看右孩子的左连续区间
        if(q>=tre[rt<<1].r-tre[rt<<1].rm+1)
            return query(q,lson)+query(m+1,rson);
        else
            return query(q,lson);
    }
    else
    {
        //同理~
        if(q<=tre[rt<<1|1].lm+tre[rt<<1|1].l-1)
            return query(m,lson)+query(q,rson);
        else
            return query(q,rson);
    }
}
int main()
{
    int n,m;
    while(~scanf("%d%d",&n,&m))
    {
        int top=-1;
        build(1,1,n);
        for(int i=0;i<m;i++)
        {
            int x;
            char str[5];
            scanf("%s",str);
            if(str[0]=='D')
            {
                scanf("%d",&x);
                update(x,0,1,1,n);
                st[++top]=x;
            }
            else if(str[0]=='R')
            {
                int y=st[top--];
                if(top>=-1)
                    update(y,1,1,1,n);
            }
            else if(str[0]=='Q')
            {
                scanf("%d",&x);
                int k=query(x,1,1,n);
                printf("%d\n",k);
            }
        }
    }
    return 0;
}