problem
已知一个数列，你需要进行下面三种操作（所有结果均对给定的模数 p 取模）：
1. 将某区间内每一个数乘上 x
2. 将某区间内每一个数加上 x
3. 求出某区间内所有数的和
solution
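区间修改 + 区间查询的线段树。
维护两个懒标记（LazyTag）：乘法标记 mulmark 和加法标记 addmark。约定按“先乘后加”的顺序更新节点值，即 val → val·mul + add·len；因此打乘法标记时，要把节点上已有的加法标记也一并乘上。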
codes
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long LL;          // 64-bit type for sums, lazy tags and the modulus
const int maxn = 100010;       // capacity of the element array; elements live at a[1..n]
int n, m;                      // n: element count, m: operation count (both read in main)
LL a[maxn], mod;               // a: initial values; mod: modulus applied to every result
// One segment-tree node. Node p has children p*2 and p*2+1, so the tree is
// stored in an array of 4*maxn nodes.
struct node{
    int l, r;                  // closed interval [l, r] covered by this node
    LL val;                    // sum of the interval, reduced modulo mod
    LL addmark;                // pending additive tag not yet pushed to the children (0 = none)
    LL mulmark;                // pending multiplicative tag not yet pushed to the children (1 = none)
}sgt[maxn<<2];
void build(int p, int l, int r){
sgt[p].l = l, sgt[p].r = r;
sgt[p].mulmark=1, sgt[p].addmark=0;
if(l == r){
sgt[p].val = a[l];
}else{
int m = (l+r)/2;
build(p*2,l,m);
build(p*2+1,m+1,r);
sgt[p].val = sgt[p*2].val+sgt[p*2+1].val;
}
sgt[p].val %= mod;
}
void pushdown(int p){
if(sgt[p].addmark==0&&sgt[p].mulmark==1)return ;
//初始化父节点
LL t1 = sgt[p].addmark, t2 = sgt[p].mulmark;
sgt[p].addmark = 0, sgt[p].mulmark = 1;
//维护标记
sgt[p*2].mulmark = (sgt[p*2].mulmark*t2)%mod;
sgt[p*2+1].mulmark = (sgt[p*2+1].mulmark*t2)%mod;
sgt[p*2].addmark = (sgt[p*2].addmark*t2+t1)%mod;
sgt[p*2+1].addmark = (sgt[p*2+1].addmark*t2+t1)%mod;
//更新当前值,我们规定乘法优先更新(加法优先会损失精度)
int l = sgt[p].l, r = sgt[p].r, m = (l+r)/2;
sgt[p*2].val=(sgt[p*2].val*t2+t1*(m-l+1))%mod;//先乘以乘法标记再加上已用乘法标记更新过的加法标记。
sgt[p*2+1].val=(sgt[p*2+1].val*t2+t1*(r-m))%mod;
}
// Adds v to every element of [l, r] (1-based, inclusive) in the subtree of
// node p; all stored values stay reduced modulo mod.
void add(int p, int l, int r, LL v){
    // Reduce v into [0, mod) first. A raw v >= mod could overflow len*v below,
    // and a negative v would leave negative residues in the tree (C++ % keeps
    // the sign of the dividend).
    v %= mod;
    if (v < 0) v += mod;
    if (l <= sgt[p].l && sgt[p].r <= r){
        // Fully covered: update the sum and stack the additive tag.
        sgt[p].val = (sgt[p].val + (sgt[p].r - sgt[p].l + 1) * v) % mod;
        sgt[p].addmark = (sgt[p].addmark + v) % mod;
        return;
    }
    // Partially covered: settle pending tags before recursing into children.
    pushdown(p);
    int m = (sgt[p].l + sgt[p].r) / 2;
    if (l <= m) add(p * 2, l, r, v);
    if (r > m) add(p * 2 + 1, l, r, v);
    sgt[p].val = (sgt[p * 2].val + sgt[p * 2 + 1].val) % mod;
}
// Multiplies every element of [l, r] (1-based, inclusive) in the subtree of
// node p by v; all stored values stay reduced modulo mod.
void times(int p, int l, int r, LL v){
    // Reduce v into [0, mod) first. With v < mod every product below stays
    // under mod*mod, avoiding signed overflow for large inputs. Negative v
    // would otherwise leave negative residues in the tree.
    v %= mod;
    if (v < 0) v += mod;
    if (l <= sgt[p].l && sgt[p].r <= r){
        sgt[p].val = (sgt[p].val * v) % mod;
        sgt[p].mulmark = (sgt[p].mulmark * v) % mod;
        // The pending add tag is scaled as well: (x*m + a)*v = x*(m*v) + a*v.
        sgt[p].addmark = (sgt[p].addmark * v) % mod;
        return;
    }
    // Partially covered: settle pending tags before recursing into children.
    pushdown(p);
    int m = (sgt[p].l + sgt[p].r) / 2;
    if (l <= m) times(p * 2, l, r, v);
    if (r > m) times(p * 2 + 1, l, r, v);
    sgt[p].val = (sgt[p * 2].val + sgt[p * 2 + 1].val) % mod;
}
// Returns the sum of elements in [l, r] (1-based, inclusive) within node p's
// subtree, reduced modulo mod.
LL query(int p, int l, int r){
    // Node lies entirely inside the query range: its stored sum is final.
    if (l <= sgt[p].l && sgt[p].r <= r) return sgt[p].val;
    // Children must reflect pending tags before they are read.
    pushdown(p);
    const LL mid = (sgt[p].l + sgt[p].r) / 2;
    const LL leftSum = (l <= mid) ? query(p * 2, l, r) : 0;
    const LL rightSum = (r > mid) ? query(p * 2 + 1, l, r) : 0;
    return (leftSum + rightSum) % mod;
}
// Reads n and the modulus, then the n initial values, then m operations:
//   1 x y z : multiply a[x..y] by z
//   2 x y z : add z to a[x..y]
//   otherwise x y : print the sum of a[x..y] modulo mod
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);  // untie cin from cout so output is not flushed before every read
    cin >> n >> mod;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    cin >> m;
    for (int i = 1; i <= m; i++){
        int op; cin >> op;
        if (op == 1){
            LL x, y, z; cin >> x >> y >> z;
            times(1, x, y, z);
        } else if (op == 2){
            LL x, y, z; cin >> x >> y >> z;
            add(1, x, y, z);
        } else {
            LL x, y; cin >> x >> y;
            // query() already returns a value reduced modulo mod.
            // '\n' fixes the original string literal that was split by a raw
            // newline (ill-formed C++), and avoids the flush of std::endl.
            cout << query(1, x, y) << '\n';
        }
    }
    return 0;
}