没什么解释大概只有一些板子$qwq$
线段树是一种基于分治思想的二叉树结构,用于在区间上进行信息统计.比树状数组更加通用.
线段树在数组中的存储
线段树除了最后一层一定是一棵完全二叉树,树的深度为$O(log N)$.因此我们可以按照"父子二倍"给结点编号的方法:
1.根结点为$1$
2.对于编号为$p$的结点,它的左孩子为$2*p$,右孩子为$2*p+1$
$N$个叶结点的满二叉树的结点数:$N+N/2+N/4+...+2+1=2N-1$.因为在上述存储方式下,最后还有一层产生了空余,所以保存线段树的数组长度要不小于$4N$.
线段树的建树
下面的结点建立了一棵线段树并在每个结点上保存了对应区间的最大值
struct node{int l,r,dat;}t[N*4]; inline void build(int l,int r,int p) { t[p].l=l,t[p].r=r; if(l==r)return t[p].dat=a[l]; int mid=(l+r)>>1; build(l,mid,p*2); build(mid+1,r,p*2+1); t[p].dat=max(t[p*2].dat,t[p*2+1].dat); } build(1,n,1) //调用入口
线段树的单点修改
根结点是执行各种指令的入口.我们需要从根结点出发递归找到需要修改的叶结点,然后从下到上更新它的所有祖先结点上保留的信息.时间复杂度为$O(log N)$.$
将$a[x]$修改成$y$
inline void update(int p,int x,int y) { int l=t[p].l,r=t[p].r; if(l==r){t[p].dat=y;return;} int mid=(l+r)>>1; if(x<=mid)update(p*2,x,y); else update(p*2+1,x,y); t[p].dat=max(t[p*2].dat,t[p*2+1].dat); } change(1,x,y);
线段树的区间查询
查询[l,r]区间上的最大值.
inline int query(int p,int l,int r) { if(l<=t[p].l && r>=t[p].r)return t[p].dat; int mid=(t[p].l+t[p].r)>>1,ret=-inf; if(l<=mid)ret=max(ret,query(p*2,l,r)); if(r>mid)ret=max(ret,query(p*2+1,l,r)); return ret; } printf("%d ",query(1,l,r));
延迟标记
对于"区间修改"而言的.为了避免浪费时间去修改后面压根用不到的结点,我们给被需修改区间完全覆盖的结点打上延迟标记,标识"该结点曾经被修改过,但其子结点尚未被更新".如果后面的查询用到该结点,再修改,也叫"标记下放".
以$Poj3468$ / Luogu3372为例(两题的输入略有不同,以下$Code$是$Poj$版)
View Code#include<cstdio> #define il inline #define Rg register #define go(i,a,b) for(Rg int i=a;i<=b;i++) #define ll long long const int N=100010; struct node{int l,r;ll dat,add;}t[N*4]; int n,m,a[N]; il void build(int p,int l,int r) { t[p].l=l,t[p].r=r; if(l==r){t[p].dat=a[l];return;} int mid=(l+r)>>1; build(p*2,l,mid); build(p*2+1,mid+1,r); t[p].add=0,t[p].dat=t[p*2].dat+t[p*2+1].dat; } il void pushdown(int p) { if(t[p].add) { t[p*2].dat+=(t[p*2].r-t[p*2].l+1)*t[p].add; t[p*2+1].dat+=(t[p*2+1].r-t[p*2+1].l+1)*t[p].add; t[p*2].add+=t[p].add; t[p*2+1].add+=t[p].add; t[p].add=0; } } il void update(int p,int l,int r,int d) { if(t[p].l>=l && t[p].r<=r){t[p].dat+=1LL*(t[p].r-t[p].l+1)*d,t[p].add+=d;return;} pushdown(p); int mid=(t[p].l+t[p].r)>>1; if(l<=mid)update(p*2,l,r,d); if(r>mid)update(p*2+1,l,r,d); t[p].dat=t[p*2].dat+t[p*2+1].dat; } il ll query(int p,int l,int r) { if(t[p].l>=l && t[p].r<=r)return t[p].dat; pushdown(p);ll ret=0; int mid=(t[p].l+t[p].r)>>1; if(l<=mid)ret+=query(p*2,l,r); if(r>mid)ret+=query(p*2+1,l,r); return ret; } int main() { scanf("%d%d",&n,&m); go(i,1,n)scanf("%d",&a[i]); build(1,1,n); while(m--) { char tp=getchar();int l,r; while(tp<'A'||tp>'Z')tp=getchar(); scanf("%d%d",&l,&r); if(tp=='Q')printf("%lld ",query(1,l,r)); else{int d;scanf("%d",&d);update(1,l,r,d);} } return 0; }
扫描线
咕咕咕
板子题!
最简单的板子没延迟标记也能过(甚至更快???)
View Code#include<cstdio> #define il inline #define Rg register #define go(i,a,b) for(Rg int i=a;i<=b;i++) #define ll long long const int N=100010; struct node{int l,r;ll dat,add;}t[N*4]; int n,m,a[N]; il void build(int p,int l,int r) { t[p].l=l,t[p].r=r; if(l==r){t[p].dat=a[l];return;} int mid=(l+r)>>1; build(p*2,l,mid); build(p*2+1,mid+1,r); t[p].add=0,t[p].dat=t[p*2].dat+t[p*2+1].dat; } il void pushdown(int p) { if(t[p].add) { t[p*2].dat+=(t[p*2].r-t[p*2].l+1)*t[p].add; t[p*2+1].dat+=(t[p*2+1].r-t[p*2+1].l+1)*t[p].add; t[p*2].add+=t[p].add; t[p*2+1].add+=t[p].add; t[p].add=0; } } il void update(int p,int l,int r,int d) { if(t[p].l>=l && t[p].r<=r){t[p].dat+=1LL*(t[p].r-t[p].l+1)*d,t[p].add+=d;return;} pushdown(p); int mid=(t[p].l+t[p].r)>>1; if(l<=mid)update(p*2,l,r,d); if(r>mid)update(p*2+1,l,r,d); t[p].dat=t[p*2].dat+t[p*2+1].dat; } il ll query(int p,int l,int r) { if(t[p].l>=l && t[p].r<=r)return t[p].dat; pushdown(p);ll ret=0; int mid=(t[p].l+t[p].r)>>1; if(l<=mid)ret+=query(p*2,l,r); if(r>mid)ret+=query(p*2+1,l,r); return ret; } int main() { scanf("%d%d",&n,&m); go(i,1,n)scanf("%d",&a[i]); build(1,1,n); while(m--) { int tp,l,r; scanf("%d%d%d",&tp,&l,&r); if(tp==2)printf("%lld ",query(1,l,r)); else{int d;scanf("%d",&d);update(1,l,r,d);} } return 0; }难一点的板子.但是似乎只要知道"乘法优先于加法"$+$"乘法分配律"这题就$over$了
具体说一下叭.线段树有两个延迟标记t[p].add,t[p].mul,分别表示加法和乘法.
$1.p$结点代表的区间内的数都加上$d: t[p].dat=(t[p].r-t[p].l+1)*d,t[p].add+=d;$
$2.p$结点代表的区间内的数都乘上$d:t[p].dat*=d,t[p].add*=d,t[p]=mul*=d;$
关于$2$操作的理解:
$t[p*2].dat=(t[p*2].dat+t[p].add)*(t[p].mul*d)$
$=t[p*2].dat*(t[p].mul)+t[p].add*(t[p].mul*d)$
很久以前的 Code#include<iostream> #include<cstdio> #define ll long long using namespace std; int read() { int x=0,y=1;char c; c=getchar(); while(c<'0'||c>'9') {if(c=='-') y=-1;c=getchar();} while(c>='0'&&c<='9') {x=(x<<1)+(x<<3)+c-'0';c=getchar();} return x*y; } int n,m,mod; ll s[100010]; struct SegmentTree { int l,r; ll w,mul,add; }t[400010]; void build(int p,int l,int r) { t[p].mul=1;t[p].add=0; t[p].l=l;t[p].r=r; if(l==r) {t[p].w=s[l]%mod;return ;} int mid=(l+r)/2; build(p*2,l,mid); build(p*2+1,mid+1,r); t[p].w=(t[p*2].w+t[p*2+1].w)%mod; } void pushdown(int p) { t[p*2].w=(t[p*2].w*t[p].mul+(t[p*2].r-t[p*2].l+1)*t[p].add)%mod; t[p*2+1].w=(t[p*2+1].w*t[p].mul+(t[p*2+1].r-t[p*2+1].l+1)*t[p].add)%mod; t[p*2].mul=(t[p*2].mul*t[p].mul)%mod; t[p*2+1].mul=(t[p*2+1].mul*t[p].mul)%mod; t[p*2].add=(t[p*2].add*t[p].mul+t[p].add)%mod; t[p*2+1].add=(t[p*2+1].add*t[p].mul+t[p].add)%mod; t[p].mul=1;t[p].add=0; } void update_add(int p,int l,int r,int d) { if(t[p].l>=l&&t[p].r<=r){t[p].w+=d*(t[p].r-t[p].l+1);t[p].add+=d;t[p].w%=mod;t[p].add%=mod;return ;} pushdown(p); int mid=(t[p].l+t[p].r)/2; if(l<=mid) update_add(p*2,l,r,d); if(r>mid) update_add(p*2+1,l,r,d); t[p].w=t[p*2].w+t[p*2+1].w;t[p].w%=mod; } void update_mul(int p,int l,int r,int d) { if(t[p].l>=l&&t[p].r<=r){t[p].w*=d;t[p].mul*=d;t[p].add*=d;t[p].w%=mod;t[p].add%=mod;t[p].mul%=mod;return ;} pushdown(p); int mid=(t[p].l+t[p].r)/2; if(l<=mid) update_mul(p*2,l,r,d); if(r>mid) update_mul(p*2+1,l,r,d); t[p].w=t[p*2].w+t[p*2+1].w;t[p].w%=mod; } ll query(int p,int l,int r) { if(t[p].l>=l&&t[p].r<=r) return t[p].w%mod; pushdown(p); int mid=(t[p].l+t[p].r)/2; if(l>mid) return query(p*2+1,l,r)%mod; if(r<=mid) return query(p*2,l,r)%mod; return (query(p*2,l,mid)+query(p*2+1,mid+1,r))%mod; } int main() { n=read();m=read();mod=read(); for(int i=1;i<=n;i++) s[i]=read(); build(1,1,n); while(m--){ int type=read(); if(type==2){ int x=read(),y=read(),k=read(); update_add(1,x,y,k);} if(type==1){ int x=read(),y=read(),k=read(); update_mul(1,x,y,k);} if(type==3){ int x=read(),y=read(); printf("%lld ",query(1,x,y));} } return 0; }