http://uoj.ac/problem/228 (题目链接)
题意
给出一个序列,维护区间加法,区间开根,区间求和
Solution
线段树。考虑区间开根怎么做。当区间的最大值与最小值相等时,我们直接对整个区间开根。最坏情况下,一次开根的复杂度最坏是${O(n)}$的,然而每次开根可以迅速拉近两个数之间的大小差距,最坏复杂度的开根不会超过${5}$次。
但是考虑这样一种情况:${sqrt{x+1}=sqrt{x}+1}$,如果序列长成这样:${65535,65536,65535,65536······}$,那么对它开根${3}$次,每次都是最坏情况下的复杂度,最后变成了${3,4,3,4······}$,如果此时我们对它进行区间加法,又加回${65535,65536,65535,65536······}$,不断循环,复杂度就炸裂了。所以当出现这种情况时,我们也对它进行区间开根。
细节
LL
代码
// uoj228 #include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<cstdio> #include<vector> #include<cmath> #include<queue> #include<map> #define LL long long #define inf 1ll<<30 #define Pi acos(-1.0) #define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout); using namespace std; const int maxn=100010; int n,m,a[maxn]; struct segtree {int l,r;LL mn,mx,tag,s;}tr[maxn<<2]; void update(int k) { tr[k].mn=min(tr[k<<1].mn,tr[k<<1|1].mn)+tr[k].tag; tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx)+tr[k].tag; tr[k].s=tr[k<<1].s+tr[k<<1|1].s+tr[k].tag*(tr[k].r-tr[k].l+1); } void build(int k,int s,int t) { tr[k].l=s;tr[k].r=t; int mid=(s+t)>>1; if (s==t) {tr[k].mn=tr[k].mx=tr[k].s=a[s];return;} build(k<<1,s,mid); build(k<<1|1,mid+1,t); update(k); } void add(int k,int s,int t,int val) { int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1; if (l==s && r==t) {tr[k].s+=(LL)val*(tr[k].r-tr[k].l+1);tr[k].mn+=val,tr[k].mx+=val;tr[k].tag+=val;return;} if (t<=mid) add(k<<1,s,t,val); else if (s>mid) add(k<<1|1,s,t,val); else add(k<<1,s,mid,val),add(k<<1|1,mid+1,t,val); update(k); } void Sqrt(int k,int s,int t,LL tag) { int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1; if (l==s && r==t) { if ((tr[k].mx==tr[k].mn) || (tr[k].mn+1==tr[k].mx && floor(sqrt(tr[k].mn+tag))+1==floor(sqrt(tr[k].mx+tag)))) { LL tmp=floor(sqrt(tr[k].mn+tag))-tr[k].mn-tag; tr[k].tag+=tmp;tr[k].mn+=tmp;tr[k].mx+=tmp; tr[k].s+=(tr[k].r-tr[k].l+1)*tmp; return; } } if (t<=mid) Sqrt(k<<1,s,t,tag+tr[k].tag); else if (s>mid) Sqrt(k<<1|1,s,t,tag+tr[k].tag); else Sqrt(k<<1,s,mid,tag+tr[k].tag),Sqrt(k<<1|1,mid+1,t,tag+tr[k].tag); update(k); } LL query(int k,int s,int t) { int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1; if (l==s && r==t) return tr[k].s; if (t<=mid) return query(k<<1,s,t)+tr[k].tag*(t-s+1); else if (s>mid) return query(k<<1|1,s,t)+tr[k].tag*(t-s+1); else return query(k<<1,s,mid)+query(k<<1|1,mid+1,t)+tr[k].tag*(t-s+1); } int main() { scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&a[i]); build(1,1,n); for (int op,l,r,val,i=1;i<=m;i++) { scanf("%d%d%d",&op,&l,&r); if (op==1) scanf("%d",&val),add(1,l,r,val); if (op==2) Sqrt(1,l,r,0); if (op==3) printf("%lld ",query(1,l,r)); } return 0; }