首页 > 代码库 > 线段树模板(结构体)

线段树模板(结构体)

线段树研究了两天了,总算有了点眉目,今天也把落下的题,补了一下。 贴一份线段树模板


线段树的特点:
1. 每一层都是区间[a, b]的一个划分,记 L = b - a

2. 一共有log2L层
3. 给定一个点p,从根到叶子p上的所有区间都包含点p,且其他区间都不包含点p。
4. 给定一个区间[l; r],可以把它分解为不超过2log2 L条不相交线段的并。


总结来说:线段树最近本的应用是4点:

1.单点更新:单点替换、单点增减

2.单点询问

3.区间询问:区间之和、区间最值

4.区间更新:区间替换、区间增减


下面是 这4个基本操作的模板:(有点儿挫)

单点替换  区间求最大

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
#define N 200004
#define MAX INT_MAX
#define MIN INT_MIN
struct node
{
	int left,right;
	int num; //
}T[4*N];
int ans=0;
void Creat(int left,int right,int id)//建树
{
	T[id].left =left;
	T[id].right =right;
	T[id].num =0;
	if(T[id].left ==T[id].right )
		return ;
	Creat(left,(left+right)/2,2*id);
	Creat((left+right)/2+1,right,2*id+1);
}
void UPdata(int id,int i,int j)

{
	if(T[id].left<=i&&T[id].right >=i)
		T[id].num = j;
	if(T[id].left ==T[id].right )
		return;
	if(i>T[id].right )
		return;
	if(i<T[id].left )
		return;
	int mid=(T[id].left +T[id].right )/2;
	if(i<=mid)
		UPdata(id*2,i,j);
	else
		UPdata(id*2+1,i,j);
	T[id].num = max(T[id*2].num,T[id*2+1].num);
}
void query(int id,int l,int r)//区间&&单点查询,l-r 区间内的所有人
{
	int mid=(T[id].left +T[id].right)/2;
	if(T[id].left ==l&&T[id].right ==r)
	{
		if(T[id].num >ans)
            ans = T[id].num;
		return;
	}
	if(r<=mid)
		query(2*id,l,r);
	else if(l>mid)
		query(2*id+1,l,r);
	else
	{
		query(2*id,l,mid);
		query(2*id+1,mid+1,r);
	}

}

int main()
{
    int n,m,x,l,r;
    char str[5];
    while(scanf("%d%d",&n,&m)!=EOF)
    {
        Creat(1,n,1);
        for(int i = 1;i<=n;i++)
        {
            scanf("%d",&x);
            UPdata(1,i,x);
        }
        for(int i = 1;i<=m;i++)
        {
            scanf("%s",str);
            if(str[0]=='Q')
            {
                scanf("%d%d",&l,&r);
                ans = -9999999;
               query(1,l,r);
                printf("%d\n",ans);
                ans = -9999999;
            }
            else if(str[0]=='U')
            {
                scanf("%d%d",&l,&r);
                UPdata(1,l,r);
            }
        }
    }
    return 0;
}


单点增减  区间求和

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
#define N 50004
#define MAX INT_MAX
#define MIN INT_MIN
struct node
{
	int left,right;
	int num; //
}T[4*N];
int ans=0;
void Creat(int left,int right,int id)//建树
{
	T[id].left =left;
	T[id].right =right;
	T[id].num =0;
	if(T[id].left ==T[id].right )
		return ;
	Creat(left,(left+right)/2,2*id);
	Creat((left+right)/2+1,right,2*id+1);
}
void UPdata(int id,int i,int j)//单点更新
{
	if(T[id].left<=i&&T[id].right >=i)
		T[id].num +=j;
	if(T[id].left ==T[id].right )
		return;
	if(i>T[id].right )
		return;
	if(i<T[id].left )
		return;
	int mid=(T[id].left +T[id].right )/2;
	if(i<=mid)
		UPdata(id*2,i,j);
	else
		UPdata(id*2+1,i,j);
}
void query(int id,int l,int r)//区间&&单点查询
{
	int mid=(T[id].left +T[id].right)/2;
	if(T[id].left ==l&&T[id].right ==r)
	{
		ans+=T[id].num ;
		return;
	}
	if(r<=mid)
		query(2*id,l,r);
	else if(l>mid)
		query(2*id+1,l,r);
	else
	{
		query(2*id,l,mid);
		query(2*id+1,mid+1,r);
	}

}
int main()
{
	int t,n,num,l,r,C=1;
	char str[20];
	scanf("%d",&t);
	while(t--)
	{
		printf("Case %d:\n",C++);
		scanf("%d",&n);
		Creat(1,n,1);
		for(int i=1;i<=n;i++)
		{
			scanf("%d",&num);
			UPdata(1,i,num);
		}

		while(scanf("%s",str))
		{
			if(str[0]=='E')
				break;
		else if(str[0]=='Q')
			{
				scanf("%d%d",&l,&r);
				query(1,l,r);
				printf("%d\n",ans);
				ans=0;
			}
		else if(str[0]=='A')
			{
				scanf("%d%d",&l,&r);
				UPdata(1,l,r);
			}
		else if(str[0]=='S')
			{
				scanf("%d%d",&l,&r);
				UPdata(1,l,-r);
			}
       /* else if(str[0]=='D')//单点查询
            {
                scanf("%d",&l);
                query(1,l,l);
                printf("%d\n",ans);
                ans = 0;
            }*/

		}

	}
	return 0;
}


区间增减 


#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>

#define max(a,b) (a>b)?a:b
#define min(a,b) (a>b)?b:a
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
#define LL __int64
const int maxn = 500100;
using namespace std;
#define MAX INT_MAX
#define MIN INT_MIN
struct node
{
    int l,r;
    LL add,sum;  //add作为一个数的累加和,同时起标记的作用,即lazy数组的作用
}T[300010];
int a[100005];   //add必须是__int64;
void putup(int id)
{
    T[id].sum=T[2*id].sum+T[2*id+1].sum;
}
void putdown(int id)
{
    if(T[id].add) //更新左右孩子
    {
        T[2*id].add+=T[id].add;
        T[2*id].sum += (T[2*id].r-T[2*id].l+1)*T[id].add;
        T[2*id+1].add+=T[id].add;
        T[2*id+1].sum += (T[2*id+1].r-T[2*id+1].l+1)*T[id].add;
        T[id].add=0;  //取消标记
    }
}
void creat(int l,int r,int id)
{
    T[id].l=l;
    T[id].r=r;
    T[id].add=0;
    if(l==r)
    {
        T[id].sum=a[r];
        return;
    }
    int mid=(l+r)>>1;
    creat(l,mid,2*id);
    creat(mid+1,r,2*id+1);
    putup(id);
}
void update(int from,int to,LL add,int id)
{
    if(from<=T[id].l&&to>=T[id].r)
    {
        T[id].add +=add;
        T[id].sum += (T[id].r-T[id].l+1)*add;
        return;
    }

    putdown(id);
    if(from<=T[2*id].r)
        update(from,to,add,2*id);
    if(to>=T[2*id+1].l)
        update(from,to,add,2*id+1);
    putup(id);
}
LL query(int from,int to,int id)
{
    if(from==T[id].l&&to==T[id].r)
        return T[id].sum;

    putdown(id);
    if(from>=T[2*id+1].l)
        return query(from,to,2*id+1);
    else if(to<=T[2*id].r)
        return query(from,to,2*id);
    else
    return query(from,T[2*id].r,2*id) + query(T[2*id+1].l,to,2*id+1);
}
int main()
{
    int n,m,A,B;
    LL add;
    char str[5];
    while(scanf("%d%d",&n,&m)!=EOF)
    {
        for(int i=1; i<=n; i++)
            scanf("%d",&a[i]);
        creat(1,n,1);
        while(m--)
        {
            LL ans = 0;
            scanf("%s",str);
            if(str[0]=='C')
            {
                scanf("%d%d%I64d",&A,&B,&add);
                update(A,B,add,1);
            }
            else
            {
                scanf("%d%d",&A,&B);
                ans=query(A,B,1);
                printf("%I64d\n",ans);
            }
        }
    }
}

区间替换

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>

#define max(a,b) (a>b)?a:b
#define min(a,b) (a>b)?b:a
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
#define LL __int64
const int maxn = 500100;
using namespace std;
#define MAX INT_MAX
#define MIN INT_MIN
struct node
{
    int l,r;
    LL add,sum;  
}T[400010];
int a[100005];   
void putup(int id)
{
    T[id].sum=T[2*id].sum+T[2*id+1].sum;
}
void putdown(int id)
{
    if(T[id].add)
    {
        T[2*id].add= T[2*id+1].add= T[id].add;
        T[2*id].sum = (T[2*id].r-T[2*id].l+1)*T[id].add;
        T[2*id+1].sum = (T[2*id+1].r-T[2*id+1].l+1)*T[id].add;
        T[id].add=0;
    }
}
void creat(int l,int r,int id)
{
    T[id].l=l;
    T[id].r=r;
    T[id].add=0;
  //  T[id].sum = 1;
    if(l==r)
    {
        T[id].sum=a[r];
        return;
    }
    int mid=(l+r)>>1;
    creat(l,mid,2*id);
    creat(mid+1,r,2*id+1);
    putup(id);
}
void update(int from,int to,LL add,int id)
{
    if(from<=T[id].l&&to>=T[id].r)
    {
        T[id].add = add;
        T[id].sum = (T[id].r-T[id].l+1)*add;
        return;
    }

    putdown(id);
    if(from<=T[2*id].r)
        update(from,to,add,2*id);
    if(to>=T[2*id+1].l)
        update(from,to,add,2*id+1);
    putup(id);
}
LL query(int from,int to,int id)
{
    if(from==T[id].l&&to==T[id].r)
        return T[id].sum;

    putdown(id);
    if(from>=T[2*id+1].l)
        return query(from,to,2*id+1);
    else if(to<=T[2*id].r)
        return query(from,to,2*id);
    else
    return query(from,T[2*id].r,2*id) + query(T[2*id+1].l,to,2*id+1);
}
int main()
{
    int n,m,A,B;
    LL add;
    char str[5];
    while(~scanf("%d%d",&n,&m))
      {
       // C++;
        for(int i=1; i<=n; i++)
            scanf("%d",&a[i]);
        creat(1,n,1);
        while(m--)
        {
            scanf("%s",str);
            if(str[0]=='T')
            {
                 scanf("%d%d%I64d",&A,&B,&add);
            update(A,B,add,1);
            }
           else if(str[0]=='Q')
           {
               scanf("%d%d",&A,&B);
               cout<<query(A,B,1)<<endl;
           }
        }

    }

}