看了一下题解里的zkw线段树,感觉讲的不是很清楚啊(可能有清楚的但是我没翻到,望大佬勿怪)。
决定自己写一篇。。。希望大家能看明白。。。
zkw线段树是一种优秀的非递归线段树,速度比普通线段树快两道三倍,同时代码量不大。
(当然,存在很多线段树可做zkw不可做的题)
zkw线段树的核心思路就是先修改叶子,然后从底向上沿着路径修改。
如果画一张图出来整个过程有点像逐渐两条交回在根节点的链。
注意:对于需要维护的区间$[1,n]$,zkw线段树维护的实际上是$[0,n+1]$。
建树
1 inline void build(ll n){ 2 bit=1; 3 while(bit<n+2)bit<<=1; 4 for(ll i=1;i<=n;++i)tree[bit+i]=a[i]; 5 for(ll i=bit-1;i>=1;--i)tree[i]=tree[i<<1]+tree[i<<1|1],tag[i]=0; 6 }
bit表示的底层的大小,我们需要先预处理出这个全局变量。
然后我们就可以先把叶子的值全部读入。
读入之后就顺着叶子向上走,更新上面的节点。
这一段代码没有什么复杂的地方。
更新
1 inline void update(ll l,ll r,ll val){ 2 ll s,t,ln=0,rn=0,x=1; 3 for(s=bit+l-1,t=bit+r+1;s^t^1;s>>=1,t>>=1,x<<=1){ 4 tree[s]+=val*ln,tree[t]+=val*rn; 5 if(~s&1)tag[s^1]+=val,tree[s^1]+=val*x,ln+=x; 6 if(t&1)tag[t^1]+=val,tree[t^1]+=val*x,rn+=x; 7 } 8 for(;s;s>>=1,t>>=1)tree[s]+=val*ln,tree[t]+=val*rn; 9 }
更新操作稍微比建树复杂一点。
s和t就是先前提到的两条链,当然准确地说,它们的轨迹才是那两条链。
ln,rn表示的是当前节点的长度(也就是s,t的长度)。
x表示的是s和t中间这一坨的长度。
然后也是一样的自底向上,每一次先更新两边,然后再判断该更新左儿子还是右儿子。
查询
1 inline ll query(ll l,ll r){ 2 ll s,t,ln=0,rn=0,x=1,ans=0; 3 for(s=bit+l-1,t=bit+r+1;s^t^1;s>>=1,t>>=1,x<<=1){ 4 if(tag[s])ans+=tag[s]*ln; 5 if(tag[t])ans+=tag[t]*rn; 6 if(~s&1)ans+=tree[s^1],ln+=x; 7 if(t&1)ans+=tree[t^1],rn+=x; 8 } 9 for(;s;s>>=1,t>>=1)ans+=tag[s]*ln,ans+=tag[t]*rn; 10 return ans; 11 }
查询操作和更新一样,没什么好讲的。
不开O2跑了511ms,比普通线段树的760+ms快很多(可能是我写丑了)
完整代码如下:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef unsigned long long ll; 4 const ll N=100100; 5 ll n,m; 6 ll op,x,y,z; 7 ll a[N]; 8 ll bit; 9 ll tree[N<<2],tag[N<<2]; 10 inline void build(ll n){ 11 bit=1; 12 while(bit<n+2)bit<<=1; 13 for(ll i=1;i<=n;++i)tree[bit+i]=a[i]; 14 for(ll i=bit-1;i>=1;--i)tree[i]=tree[i<<1]+tree[i<<1|1],tag[i]=0; 15 } 16 inline void update(ll l,ll r,ll val){ 17 ll s,t,ln=0,rn=0,x=1; 18 for(s=bit+l-1,t=bit+r+1;s^t^1;s>>=1,t>>=1,x<<=1){ 19 tree[s]+=val*ln,tree[t]+=val*rn; 20 if(~s&1)tag[s^1]+=val,tree[s^1]+=val*x,ln+=x; 21 if(t&1)tag[t^1]+=val,tree[t^1]+=val*x,rn+=x; 22 } 23 for(;s;s>>=1,t>>=1)tree[s]+=val*ln,tree[t]+=val*rn; 24 } 25 inline ll query(ll l,ll r){ 26 ll s,t,ln=0,rn=0,x=1,ans=0; 27 for(s=bit+l-1,t=bit+r+1;s^t^1;s>>=1,t>>=1,x<<=1){ 28 if(tag[s])ans+=tag[s]*ln; 29 if(tag[t])ans+=tag[t]*rn; 30 if(~s&1)ans+=tree[s^1],ln+=x; 31 if(t&1)ans+=tree[t^1],rn+=x; 32 } 33 for(;s;s>>=1,t>>=1)ans+=tag[s]*ln,ans+=tag[t]*rn; 34 return ans; 35 } 36 int main(){ 37 scanf("%lld%lld",&n,&m); 38 for(ll i=1;i<=n;++i)scanf("%lld",&a[i]); 39 build(n); 40 while(m--){ 41 scanf("%lld%lld%lld",&op,&x,&y); 42 if(op==1)scanf("%lld",&z),update(x,y,z); 43 else cout<<query(x,y)<<endl; 44 } 45 }