首页 > 代码库 > 线段树复习
线段树复习
2017.3.24
T1 最大子段和 http://codevs.cn/problem/3981/
初做:2017.2.1 time:2576ms memory:22MB
http://www.cnblogs.com/TheRoadToTheGold/p/6360224.html
现在:2017.3.24 time:2991ms memory:29MB
#include<cstdio>#include<iostream>#include<algorithm>#define N 200001using namespace std;int n,m,tot;struct node{ int l,r,siz; long long lmax,rmax,maxx,sum;}tr[N*2];void up(int k){ int l=k+1,r=k+tr[k+1].siz*2; tr[k].lmax=max(tr[l].lmax,tr[l].sum+tr[r].lmax); tr[k].rmax=max(tr[r].rmax,tr[l].rmax+tr[r].sum); tr[k].sum=tr[l].sum+tr[r].sum; long long tmp1=max(tr[l].maxx,tr[r].maxx); long long tmp2=max(tr[k].lmax,tr[k].rmax); long long tmp3=max(tmp1,tmp2); tr[k].maxx=max(tmp3,tr[l].rmax+tr[r].lmax);}void build(int l,int r){ tr[++tot].l=l;tr[tot].r=r; tr[tot].siz=r-l+1; int k=tot; if(l==r) { cin>>tr[tot].maxx; tr[tot].lmax=tr[tot].rmax=tr[tot].sum=tr[tot].maxx; return; } int mid=l+r>>1; build(l,mid);build(mid+1,r); up(k);}void query(int k,int opl,int opr,long long & ans_l,long long &ans_r,long long &ans_sum,long long &ans){ if(tr[k].l==opl&&tr[k].r==opr) { ans_l=tr[k].lmax; ans_r=tr[k].rmax; ans_sum=tr[k].sum; ans=tr[k].maxx; return; } int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].siz*2; if(opr<=mid) query(l,opl,opr,ans_l,ans_r,ans_sum,ans); else if(opl>mid) query(r,opl,opr,ans_l,ans_r,ans_sum,ans); else { long long lch_lmax,lch_rmax,lch_sum,lch_maxx; long long rch_lmax,rch_rmax,rch_sum,rch_maxx; query(l,opl,mid,lch_lmax,lch_rmax,lch_sum,lch_maxx); query(r,mid+1,opr,rch_lmax,rch_rmax,rch_sum,rch_maxx); ans_l=max(lch_lmax,lch_sum+rch_lmax); ans_r=max(rch_rmax,rch_sum+lch_rmax); ans_sum=lch_sum+rch_sum; long long tmp1=max(lch_maxx,rch_maxx); long long tmp2=max(ans_l,ans_r); long long tmp3=max(tmp1,tmp2); ans=max(tmp3,lch_rmax+rch_lmax); }}int main(){ freopen("data","r",stdin); freopen("2.out","w",stdout); scanf("%d",&n); build(1,n); scanf("%d",&m); int x,y; long long ans_l,ans_r,ans_sum,ans; while(m--) { scanf("%d%d",&x,&y); query(1,x,y,ans_l,ans_r,ans_sum,ans); printf("%lld\n",ans); }}
画蛇添足:
tmp1=max(lch_max,rch_max)
tmp2=max(l_max,r_max)
tmp3=max(tmp1,tmp2)
ans=max(tmp3,l_rmax+r_lmax)
优化:ans=max(tmp1,l_rmax+r_lmax)
原因:如果全部的左子区间+右子区间左半部分最优,那么左子区间的右半部分=全部的左子区间
加深理解:
ans_l=max(lch_lmax,lch_sum+rch_lmax);
ans_r=max(rch_rmax,rch_sum+lch_rmax);
ans_l != tr[l].lmax
ans_r != tr[r].rmax
ans_l是自下一层开始递归,直至找到符合要求的tr[].lmax
这个符合要求的tr[].lmax不一定就是当前左子区间的lmax
疑问:
既然如此,那lch_sum也应该是符合要求的tr[].sum,而不一定当前左子区间的sum
但如若直接用tr[l].sum代替lch_sum也能AC,且对拍无误
why?
T2 https://www.luogu.org/problem/show?pid=2894
只有0和1,找到连续0的个数超过x的位置,输出最左端,支持区间修改操作
初做:2017.2.1 time:354ms memory:68.18MB
多开了10倍的结构体,重测:21.24MB
http://www.cnblogs.com/TheRoadToTheGold/p/6360248.html
现在:2017.3.24 time:1030ms memory:18.63MB
#include<cstdio>#include<algorithm>#define N 50001using namespace std;int n,m,tot,ans;struct node{ int l,r,lmax,rmax,maxx,siz; int f;}tr[N*2];void build(int l,int r){ int k=++tot; tr[k].l=l;tr[k].r=r; tr[k].lmax=tr[k].rmax=tr[k].maxx=tr[k].siz=r-l+1; if(l==r) return; int mid=l+r>>1; build(l,mid);build(mid+1,r);}void down(int k){ int l=k+1,r=k+tr[k+1].siz*2; if(tr[k].f==1) { tr[l].lmax=tr[l].rmax=tr[l].maxx=0; tr[r].lmax=tr[r].rmax=tr[r].maxx=0; } else { tr[l].lmax=tr[l].rmax=tr[l].maxx=tr[l].siz; tr[r].lmax=tr[r].rmax=tr[r].maxx=tr[r].siz; } tr[l].f=tr[r].f=tr[k].f; tr[k].f=0;}int query(int k,int l,int x){ if(tr[k].lmax>=x) return l; if(tr[k].f) down(k); int mid=tr[k].l+tr[k].r>>1,lc=k+1,rc=k+tr[k+1].siz*2; if(tr[lc].maxx>=x) return query(lc,l,x); if(tr[lc].rmax+tr[rc].lmax>=x) return mid-tr[lc].rmax+1; return query(rc,mid+1,x);}void up(int k){ int l=k+1,r=k+tr[k+1].siz*2; if(tr[l].lmax==tr[l].siz) tr[k].lmax=tr[l].siz+tr[r].lmax; else tr[k].lmax=tr[l].lmax; if(tr[r].rmax==tr[r].siz) tr[k].rmax=tr[l].rmax+tr[r].siz; else tr[k].rmax=tr[r].rmax; tr[k].maxx=max(max(tr[l].maxx,tr[r].maxx),tr[l].rmax+tr[r].lmax);}void change(int k,int opl,int opr,int w){ if(tr[k].l>=opl&&tr[k].r<=opr) { if(w==1) tr[k].lmax=tr[k].rmax=tr[k].maxx=0; else tr[k].lmax=tr[k].rmax=tr[k].maxx=tr[k].siz; tr[k].f=w; return; } if(tr[k].f) down(k); int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].siz*2; if(opl<=mid) change(l,opl,opr,w); if(opr>mid) change(r,opl,opr,w); up(k);}int main(){ scanf("%d%d",&n,&m); build(1,n); int x,y,z; while(m--) { scanf("%d",&z); if(z==1) { scanf("%d",&x); if(tr[1].maxx<x) { printf("0\n"); continue; } ans=query(1,1,x); printf("%d\n",ans); change(1,ans,ans+x-1,1); } else { scanf("%d%d",&x,&y); change(1,x,x+y-1,2); } }}
3个错误:
① 父节点编号k,左子节点为k+1,右子节点为k+tr[k+1].siz*2
没有+1,因为左子树节点数为2*tr[k+1].siz-1
② 区间修改为1时,实际操作区间opl,opr与当前递归区间l,r混淆
既然记录了siz信息,为啥不直接用呢
③ query时忘了下传标记
思路卡壳处:
要求输出最左端,而线段树里没有记录这一信息。
可以在调用函数时设一实参表示,lmax满足要求时直接返回这一参数,这一点后来想到了
因为输出最左端,所以跨左右区间的优于在右子区间的
这一点在没过样例之后想到了,但如何处理
一开始的想法是递归左子区间,找>=x-tr[r].lmax的位置
真是脑抽了,明显不对
没忍住看了以前做的,是直接返回mid-tr[l].rmax+1
因为如果跨区间,那么最左端位置就固定了
以前写的跟这次的不大一样,仔细想想,他主要集中在了这一句
if(e[(k<<1)+1].max_l+e[k<<1].max_r>=x) return mid-e[k<<1].max_r+1;
(以前的代码)
线段树复习