题目描述
样例
样例输入
1
5 5
1 2 3 4 5
1 1 5
2 1 5
0 3 5 3
1 1 5
2 1 5
样例输出
5
15
3
12
分析
对于线段树的每一个节点,我们记录它的最大值、严格次大值、区间和以及最大值的个数
对于第一种操作,如果当前区间的最大值小于等于 (t),那么直接把这个区间剪掉
如果最大值大于 (t) ,但是次大值小于 (t),那么我们直接把最大值赋值成 (t)
其它的情况暴力递归
时间复杂度:挺对的
代码
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<vector>
#include<cstring>
#include<map>
#define rg register
inline int read(){
rg int x=0,fh=1;
rg char ch=getchar();
while(ch<'0' || ch>'9'){
if(ch=='-') fh=-1;
ch=getchar();
}
while(ch>='0' && ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*fh;
}
const int maxn=1e6+5;
int t,n,m,a[maxn];
struct trr{
int l,r,mmax,sec,cnt;
long long sum;
}tr[maxn<<2];
void push_up(rg int da){
tr[da].mmax=std::max(tr[da<<1].mmax,tr[da<<1|1].mmax);
if(tr[da<<1].mmax>tr[da<<1|1].mmax){
tr[da].sec=std::max(tr[da<<1|1].mmax,tr[da<<1].sec);
tr[da].cnt=tr[da<<1].cnt;
} else if(tr[da<<1].mmax<tr[da<<1|1].mmax){
tr[da].sec=std::max(tr[da<<1].mmax,tr[da<<1|1].sec);
tr[da].cnt=tr[da<<1|1].cnt;
} else {
tr[da].sec=std::max(tr[da<<1].sec,tr[da<<1|1].sec);
tr[da].cnt=tr[da<<1].cnt+tr[da<<1|1].cnt;
}
tr[da].sum=tr[da<<1].sum+tr[da<<1|1].sum;
}
void updat(rg int da,rg int val){
if(tr[da].mmax>val){
tr[da].sum-=1LL*tr[da].cnt*(tr[da].mmax-val);
tr[da].mmax=val;
}
}
void push_down(rg int da){
updat(da<<1,tr[da].mmax);
updat(da<<1|1,tr[da].mmax);
}
void build(rg int da,rg int l,rg int r){
tr[da].l=l,tr[da].r=r;
if(tr[da].l==tr[da].r){
tr[da].sum=tr[da].mmax=a[l];
tr[da].sec=-1;
tr[da].cnt=1;
return;
}
rg int mids=(tr[da].l+tr[da].r)>>1;
if(l<=mids) build(da<<1,l,mids);
if(r>mids) build(da<<1|1,mids+1,r);
push_up(da);
}
void xg(rg int da,rg int l,rg int r,rg int val){
if(tr[da].mmax<=val) return;
if(tr[da].l>=l && tr[da].r<=r && tr[da].sec<val){
updat(da,val);
return;
}
push_down(da);
rg int mids=(tr[da].l+tr[da].r)>>1;
if(l<=mids) xg(da<<1,l,r,val);
if(r>mids) xg(da<<1|1,l,r,val);
push_up(da);
}
long long cxsum(rg int da,rg int l,rg int r){
if(tr[da].l>=l && tr[da].r<=r)return tr[da].sum;
push_down(da);
rg int mids=(tr[da].l+tr[da].r)>>1;
rg long long nans=0;
if(l<=mids) nans+=cxsum(da<<1,l,r);
if(r>mids) nans+=cxsum(da<<1|1,l,r);
return nans;
}
int cxmax(rg int da,rg int l,rg int r){
if(tr[da].l>=l && tr[da].r<=r) return tr[da].mmax;
push_down(da);
rg int mids=(tr[da].l+tr[da].r)>>1,nans=-1;
if(l<=mids) nans=std::max(nans,cxmax(da<<1,l,r));
if(r>mids) nans=std::max(nans,cxmax(da<<1|1,l,r));
return nans;
}
int main(){
t=read();
while(t--){
n=read(),m=read();
for(rg int i=1;i<=n;i++) a[i]=read();
build(1,1,n);
rg int aa,bb,cc,dd;
for(rg int i=1;i<=m;i++){
aa=read(),bb=read(),cc=read();
if(aa==0){
dd=read();
xg(1,bb,cc,dd);
} else if(aa==1){
printf("%d
",cxmax(1,bb,cc));
} else {
printf("%lld
",cxsum(1,bb,cc));
}
}
}
return 0;
}