1简介
线段树是一种比较实用的数据结构,支持区间修改,区间查询,单点修改,单点查询,维护区间最值。
2讲解
2.1基本值
(注:本博客只是设计线段树区间求和)
在结构体中我们维护一个区间值,区间长度,和该节点的标记
这个标记非常重要,是线段树的灵魂
代码:
struct node{
ll sum,len;
ll ad;
};
2.2建树
对于一个数组a来说,把它l到r的一个区间看做一个节点,根节点的区间是l到r,根节点两个儿子的区间分别为l到mid,和mid+1到r,其中mid=l+r>>1;由此不断往下递归,直至该区间内只剩一个节点。
上述说明不仅说明了线段树的构造原理,也可以得出线段树的建树代码:
inline void build(ll k,ll l,ll r)
{
if(l==r)
{
tree[k].len=1;
tree[k].sum=a[l];
return;
}
ll mid=l+r>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
pushup(k);
}
其中pushup函数是合并函数,根据代码,我们可以知道在建树中,是首先建立叶子结点,然后逐步回溯,合并,建立父节点,那么就有了合并函数。
2.3合并函数
合并函数比较好理解,一个节点对应一个区间,那么这个区间l到r我们所要维护的值一定取决于它的两个儿子,例如,如果我们所维护的是l到r的和,那么该节点所维护的l到r的区间和一定是它的两个儿子:l到mid和mid+1到r的和,长度也是如此,由此可以得出:
inline void pushup(ll k)
{
tree[k].sum=tree[k*2].sum+tree[k*2+1].sum;
tree[k].len=tree[k*2].len+tree[k*2+1].len;
}
2.4加法函数与标记下传
如果要给一个区间和为q,区间长度为len的区间进行把区间里的每个数都加x的操作,那么就是相当于整个区间的和加上x*len:
inline void A(ll k,ll x)
{
tree[k].sum+=tree[k].len*x;
tree[k].ad+=x;
}
但是我们并没有修改这个区间的每一个数,如果我们需要用到这个区间里的某些数但是又不是全部,这个时候因为我们个这个区间的标记加上了x,也就是说我们标记l到r的区间加上了x,接下来我们要做的就是让标记下传,把标记传给该节点的两个儿子,同时更新它的两个儿子所维护的区间和与标记,这样,我们需要用到什么程度就标记下传的什么程度,一边修改一边标记下传。
这就体现出了为什么线段树的时间复杂度更优:我们需要修改l到r的区间和就只看l到r的区间和,而不用给l到r的每一个值都进行修改,如果有需要用到一些更小的区间,那么标记下传
标记下传同时,不要忘了清零该节点标记:
inline void pushdown(ll k)
{
A(k*2,tree[k].ad);
A(k*2+1,tree[k].ad);
tree[k].ad=0;
}
2.5区间更改
如果让我们修改的区间已经查到了,就执行加法操作,如果没有,就看他的左右儿子,同时标记下传,最后不要忘记合并
有时我们所要修改的区间包含在多个线段树节点中,我们分治来考虑,如果mid把我们的目标区间分成两段,那就分别来处理
下面的代码中:z到y指的是目标区间,l到r指的是当前区间,k指的是当前区间所对应的线段树节点,x指的是要给区间修改的值
inline void change(ll k,ll l,ll r,ll z,ll y,ll x)
{
if(l==z&&r==y)
{
A(k,x);
return;
}
if(tree[k].ad) pushdown(k);
ll mid=l+r>>1;
if(y<=mid) change(k*2,l,mid,z,y,x);
else if(z>=mid+1) change(k*2+1,mid+1,r,z,y,x);
else change(k*2,l,mid,z,mid,x),change(k*2+1,mid+1,r,mid+1,y,x);
pushup(k);
}
2.6区间查询
区间查询的思路与2.5差不多,基本上还是一个分治的思想,同时别忘了标记下传:
inline ll ask_sum(ll k,ll l,ll r,ll z,ll y)
{
if(l==z&&r==y) return tree[k].sum;
if(tree[k].ad) pushdown(k);
ll mid=l+r>>1;
if(y<=mid) return ask_sum(k*2,l,mid,z,y);
else if(z>=mid+1) return ask_sum(k*2+1,mid+1,r,z,y);
else return ask_sum(k*2,l,mid,z,mid)+ask_sum(k*2+1,mid+1,r,mid+1,y);
}
总结
以上就是线段树大致内容,博主讲的说实话并不是很好,希望多多谅解,如果有不清楚的地方可以在评论区里提问。这里附上总代码,该代码可以过洛谷上的线段树模板1
本博客只是介绍了区间加法,实际上区间加减乘除,最值,线段树都可以维护。
#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define ll long long
#define ull unsigned long long
#define N 1000100
#define M number
using namespace std;
ll n,m,a[N];
struct Stree{
struct node{
ll sum,len;
ll ad;
};
node tree[N<<4];
inline void A(ll k,ll x)
{
tree[k].sum+=tree[k].len*x;
tree[k].ad+=x;
}
inline void pushdown(ll k)
{
A(k*2,tree[k].ad);
A(k*2+1,tree[k].ad);
tree[k].ad=0;
}
inline void pushup(ll k)
{
tree[k].sum=tree[k*2].sum+tree[k*2+1].sum;
tree[k].len=tree[k*2].len+tree[k*2+1].len;
}
inline void build(ll k,ll l,ll r)
{
if(l==r)
{
tree[k].len=1;
tree[k].sum=a[l];
return;
}
ll mid=l+r>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
pushup(k);
}
inline void change(ll k,ll l,ll r,ll z,ll y,ll x)
{
if(l==z&&r==y)
{
A(k,x);
return;
}
if(tree[k].ad) pushdown(k);
ll mid=l+r>>1;
if(y<=mid) change(k*2,l,mid,z,y,x);
else if(z>=mid+1) change(k*2+1,mid+1,r,z,y,x);
else change(k*2,l,mid,z,mid,x),change(k*2+1,mid+1,r,mid+1,y,x);
pushup(k);
}
inline ll ask_sum(ll k,ll l,ll r,ll z,ll y)
{
if(l==z&&r==y) return tree[k].sum;
if(tree[k].ad) pushdown(k);
ll mid=l+r>>1;
if(y<=mid) return ask_sum(k*2,l,mid,z,y);
else if(z>=mid+1) return ask_sum(k*2+1,mid+1,r,z,y);
else return ask_sum(k*2,l,mid,z,mid)+ask_sum(k*2+1,mid+1,r,mid+1,y);
}
};
Stree stree;
int main()
{
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
stree.build(1,1,n);
for(int i=1;i<=m;i++)
{
int op,l,r;
scanf("%d%d%d",&op,&l,&r);
if(op==1)
{
ll x;
scanf("%lld",&x);
stree.change(1,1,n,l,r,x);
}
else printf("%lld
",stree.ask_sum(1,1,n,l,r));
}
}
线段树2代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<vector>
#define N 1000010
#define ll long long
using namespace std;
ll n,m,p;
ll sum[N<<2],len[N<<2],a[N],ad_1[N<<2],ad_2[N<<2];
void pushup(ll k)
{
sum[k]=(sum[k*2]+sum[k*2+1])%p;
len[k]=len[k*2]+len[k*2+1];
}
void A(ll k,ll x)
{
sum[k]=(x*sum[k])%p;
ad_1[k]=(ad_1[k]*x)%p;
ad_2[k]=(ad_2[k]*x)%p;
if(ad_1[k]==0) ad_1[k]=p;
}
void B(ll k,ll x)
{
sum[k]=(sum[k]+(len[k]*x)%p)%p;
ad_2[k]=(ad_2[k]+x)%p;
}
void pushdown_2(ll k)
{
B(k*2,ad_2[k]);
B(k*2+1,ad_2[k]);
ad_2[k]=0;
}
void pushdown_1(ll k)
{
A(k*2,ad_1[k]);
A(k*2+1,ad_1[k]);
ad_1[k]=1;
}
void change_1(ll k,ll l,ll r,ll z,ll y,ll x)
{
if(l==z&&r==y)
{
A(k,x);
return;
}
if(ad_1[k]>1) pushdown_1(k);
if(ad_2[k]) pushdown_2(k);
ll mid=l+r>>1;
if(y<=mid) change_1(k*2,l,mid,z,y,x);
else if(mid<z) change_1(k*2+1,mid+1,r,z,y,x);
else
{
change_1(k*2,l,mid,z,mid,x);
change_1(k*2+1,mid+1,r,mid+1,y,x);
}
pushup(k);
}
int ask_sum(ll k,ll l,ll r,ll z,ll y)
{
if(l==z&&r==y)
{
return sum[k];
}
if(ad_1[k]>1) pushdown_1(k);
if(ad_2[k]) pushdown_2(k);
ll mid=l+r>>1;
if(y<=mid) return ask_sum(k*2,l,mid,z,y);
else if(mid<z) return ask_sum(k*2+1,mid+1,r,z,y);
else return (ask_sum(k*2,l,mid,z,mid)%p+ask_sum(k*2+1,mid+1,r,mid+1,y)%p)%p;
}
void change_2(ll k,ll l,ll r,ll z,ll y,ll x)
{
if(l==z&&r==y)
{
B(k,x);
return;
}
if(ad_1[k]>1) pushdown_1(k);
if(ad_2[k]) pushdown_2(k);
ll mid=l+r>>1;
if(y<=mid) change_2(k*2,l,mid,z,y,x);
else if(mid<z) change_2(k*2+1,mid+1,r,z,y,x);
else
{
change_2(k*2,l,mid,z,mid,x);
change_2(k*2+1,mid+1,r,mid+1,y,x);
}
pushup(k);
}
void build(ll k,ll l,ll r)
{
ad_1[k]=1;
if(l==r)
{
sum[k]=a[l];
len[k]=1;
return;
}
ll mid=l+r>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
pushup(k);
}
void print()
{
cout<<endl<<endl;
int i=0;
do
{
cout<<sum[++i]<<" ";
}while(sum[i]);
cout<<endl<<endl;
}
int main()
{
// freopen("he.in","r",stdin);
// freopen("he.out","w",stdout);
ios::sync_with_stdio(false);
cin>>n>>m>>p;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
//print();
while(m--)
{
int op,a,b,x;
cin>>op>>a>>b;
if(op==1)
{
cin>>x;
change_1(1,1,n,a,b,x);
//print();
}
else if(op==2)
{
cin>>x;
change_2(1,1,n,a,b,x);
//print();
}
else
{
int ans=ask_sum(1,1,n,a,b);
cout/*<<"ans="*/<<ans<<endl;
//print();
}
}
return 0;
}