首页 > 代码库 > hdu 1540 Tunnel Warfare【线段树】

hdu 1540 Tunnel Warfare【线段树】

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1540

题目大意:抗日战争时期,各村庄被一条地道连接着(村庄排在一条线上),有三种操作:

第一种:某村庄被敌军摧毁;

第二种:修复上一个被摧毁的村庄;

第三种:查询与该村庄直接或间接链接的村庄有多少个(包括自己);


此题用线段树做,每个节点包含该区间从左端开始有多大连续区间ls,从右端向左有多大连续区间rs,该区间的最大连续区间mas;

代码如下:

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#define N 50050
#define L(x) x<<1
#define R(x) (x<<1) | 1
using namespace std;
struct node
{
    int l,r;
    int ls,rs,mas;//ls表示该区间从左端点开始的连续区间长度,rs表示右边
}a[N<<2];
int sta[N];
void init(int l,int r,int inx)
{
    a[inx].l=l;
    a[inx].r=r;
    a[inx].ls=a[inx].rs=a[inx].mas=r-l+1;
    if(l!=r)
    {
        int mid=((l+r)>>1);
        init(l,mid,L(inx));
        init(mid+1,r,R(inx));
    }
}
void insert(int inx,int x,int flag)
{
    if(a[inx].l==a[inx].r)
    {
        if(flag)
            a[inx].ls=a[inx].rs=a[inx].mas=1;//修复
        else
            a[inx].ls=a[inx].rs=a[inx].mas=0;//破坏
        return ;
    }
    int mid=(a[inx].l+a[inx].r)>>1;
    if(x>mid)
        insert(R(inx),x,flag);
    else
        insert(L(inx),x,flag);
    a[inx].ls=a[L(inx)].ls;
    a[inx].rs=a[R(inx)].rs;
    a[inx].mas=max(max(a[L(inx)].mas,a[R(inx)].mas),a[L(inx)].rs+a[R(inx)].ls);
    if(a[L(inx)].ls==a[L(inx)].r-a[L(inx)].l+1)
        a[inx].ls+=a[R(inx)].ls;
    if(a[R(inx)].rs==a[R(inx)].r-a[R(inx)].l+1)
        a[inx].rs+=a[L(inx)].rs;
}
int query(int inx,int x)
{
    if(a[inx].l==a[inx].r||a[inx].mas==0||a[inx].mas==a[inx].r-a[inx].l+1)
        return a[inx].mas;
    int mid=(a[inx].l+a[inx].r)>>1;
    if(x<=mid)
    {
        if(x>=a[L(inx)].r-a[L(inx)].rs+1)
            //return query(L(inx),x)+query(R(inx),mid+1);
            return a[L(inx)].rs+a[R(inx)].ls;
        else
            return query(L(inx),x);
    }
    else
    {
        if(x<=a[R(inx)].l+a[R(inx)].ls-1)
            //return query(R(inx),x)+query(L(inx),mid);
            return a[L(inx)].rs+a[R(inx)].ls;
        else
            return query(R(inx),x);
    }
}
int main()
{
    int n,m;
    while(~scanf("%d%d",&n,&m))
    {
        int tail=0;
        init(1,n,1);
        for(int i=0;i<m;i++)
        {
            char temp[2];
            scanf("%s",temp);
            if(temp[0]=='D')
            {
                int x;
                scanf("%d",&x);
                insert(1,x,0);
                sta[tail++]=x;
            }
            else if(temp[0]=='R')
            {
                if(tail-1>=0)
                    insert(1,sta[--tail],1);
            }
            else if(temp[0]=='Q')
            {
                int x;
                scanf("%d",&x);
                printf("%d\n",query(1,x));
            }
        }
    }
    return 0;
}