zoukankan      html  css  js  c++  java
  • [做题笔记] 浅谈势能线段树在特殊区间问题上的应用

    区间最值操作

    题目描述

    点此看题

    维护一个数据结构支持区间取最小值,查询区间最大值,查询区间和。

    解法

    线段树上每个节点维护 (mx) 表示区间最大值,(cx) 表示区间严格次大值,对于修改我们这样做:

    • 如果 (mxleq t),那么忽略这次取最小值的操作。
    • 如果 (mx>t>cx),设区间中 (mx)(num) 个,那么打上标记,把区间和减去 (numcdot(mx-t))
    • 如果 (tleq cx),暴力往下递归。

    可以设计势能函数 (h(x)) 表示线段树上节点 (x) 的代表区间中互不相同的元素个数,考虑无论是打标记还是往下递归都是花费 (O(1)) 的时间将势能减少 (1),初始势能是 (nlog n),所以时间复杂度 (O(nlog n))

    总结

    对于一些奇怪的区间操作,可以考虑势能线段树。

    我们可以先尽量多想一些剪枝,然后用势能函数证明时间复杂度。

    关于势能函数的定义,可以考虑关键操作会让什么量减少,尝试把它定义成势能函数。

    #include <cstdio>
    #include <iostream>
    using namespace std;
    const int M = 1000005;
    #define ll long long
    int read()
    {
        int x=0,f=1;char c;
        while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
        while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
        return x*f;
    }
    int T,n,m,mx[4*M],cx[4*M],num[4*M];ll s[4*M];
    void up(int i)
    {
        num[i]=0;
        mx[i]=max(mx[i<<1],mx[i<<1|1]);
        cx[i]=max(cx[i<<1],cx[i<<1|1]);
        if(mx[i<<1]!=mx[i<<1|1])
            cx[i]=max(cx[i],min(mx[i<<1],mx[i<<1|1]));
        if(mx[i]==mx[i<<1]) num[i]+=num[i<<1];
        if(mx[i]==mx[i<<1|1]) num[i]+=num[i<<1|1];
        s[i]=s[i<<1]+s[i<<1|1];
    }
    void fuck(int i,int c)
    {
        if(mx[i]<=c) return ;
        s[i]-=1ll*(mx[i]-c)*num[i];
        mx[i]=c;
    }
    void down(int i)
    {
        fuck(i<<1,mx[i]);
        fuck(i<<1|1,mx[i]);
    }
    void build(int i,int l,int r)
    {
        if(l==r)
        {
            s[i]=mx[i]=read();
            num[i]=1;cx[i]=-1;
            return ;
        }
        int mid=(l+r)>>1;
        build(i<<1,l,mid);
        build(i<<1|1,mid+1,r);
        up(i);
    }
    void zxy(int i,int l,int r,int c)
    {
        if(mx[i]<=c) return ;
        if(mx[i]>c && c>cx[i])
        {
            fuck(i,c);
            return ;
        }
        if(l==r)
        {
            mx[i]=s[i]=min(c,mx[i]);
            return ;
        }
        int mid=(l+r)>>1;down(i);
        zxy(i<<1,l,mid,c);
        zxy(i<<1|1,mid+1,r,c);
        up(i);
    }
    void upd(int i,int l,int r,int L,int R,int c)
    {
        if(L>r || l>R) return ;
        if(L<=l && r<=R)
        {
            zxy(i,l,r,c);
            return ;
        }
        int mid=(l+r)>>1;down(i);
        upd(i<<1,l,mid,L,R,c);
        upd(i<<1|1,mid+1,r,L,R,c);
        up(i);
    }
    int askmax(int i,int l,int r,int L,int R)
    {
        if(L>r || l>R) return 0;
        if(L<=l && r<=R) return mx[i];
        int mid=(l+r)>>1;down(i);
        return max(askmax(i<<1,l,mid,L,R), 
        askmax(i<<1|1,mid+1,r,L,R));
    }
    ll asksum(int i,int l,int r,int L,int R)
    {
        if(L>r || l>R) return 0;
        if(L<=l && r<=R) return s[i];
        int mid=(l+r)>>1;down(i);
        return asksum(i<<1,l,mid,L,R)+
        asksum(i<<1|1,mid+1,r,L,R);
    }
    void write(ll x)
    {
        if (x < 0) x = ~x + 1, putchar('-');
        if (x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
    signed main()
    {
        T=read();
        while(T--)
        {
            n=read();m=read();
            build(1,1,n);
            while(m--)
            {
                int op=read(),l=read(),r=read();
                if(op==0) upd(1,1,n,l,r,read());
                if(op==1) write(askmax(1,1,n,l,r)),puts("");
                if(op==2) write(asksum(1,1,n,l,r)),puts("");
            }
        }
    }
    

    带区间加法的区间除法

    题目描述

    点此看题

    解法

    考虑除法减少得是很快的,而加法只是把区间权值整体抬升,所以我们考虑定义势能函数 (h(x)=lg (mx-mi)),也就是达到状态 (mx-mileq 1) 需要被除的次数。

    显然初始时势能总和是 (nlog nlog c),对于整个被除的区间,如果我们向下递归,那么势能一定减少 (1),这说明我们花费了 (O(1)) 的时间让势能减少 (1)

    再考虑操作中带来的势能增加,考虑一次加法操作部分影响的区间有 (log n) 个,单个区间增加的势能不超过 (log c),所以总势能增加不超过 (qlog nlog c);考虑除法只除到了一个区间的部分,这部分的增量也是类似的 (qlog nlog c)

    所以总时间复杂度 (O((n+q)log nlog c)),具体实现中我们判断 (mx-frac{mx}{d}=mi-frac{mi}{d}) 就打减法标记。

    #include <cstdio>
    #include <iostream>
    #include <cmath>
    using namespace std;
    const int M = 100005;
    #define int long long
    const int inf = 1e18;
    int read()
    {
    	int x=0,f=1;char c;
    	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
    	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
    	return x*f;
    }
    int n,q,fl[4*M],s[4*M],mi[4*M],mx[4*M];
    void work(int x,int y,int len)
    {
    	fl[x]+=y;s[x]+=y*len;
    	mi[x]+=y;mx[x]+=y;
    }
    void up(int i)
    {
    	s[i]=s[i<<1]+s[i<<1|1];
    	mi[i]=min(mi[i<<1],mi[i<<1|1]);
    	mx[i]=max(mx[i<<1],mx[i<<1|1]);
    }
    void down(int i,int l,int r)
    {
    	int mid=(l+r)>>1;
    	if(!fl[i]) return ;
    	work(i<<1,fl[i],mid-l+1);
    	work(i<<1|1,fl[i],r-mid);
    	fl[i]=0;
    }
    void add(int i,int l,int r,int L,int R,int x)
    {
    	if(L>r || l>R) return ;
    	if(L<=l && r<=R)
    	{
    		work(i,x,r-l+1);
    		return ;
    	}
    	int mid=(l+r)>>1;down(i,l,r);
    	add(i<<1,l,mid,L,R,x);
    	add(i<<1|1,mid+1,r,L,R,x);
    	up(i);
    }
    int asksum(int i,int l,int r,int L,int R)
    {
    	if(L>r || l>R) return 0;
    	if(L<=l && r<=R) return s[i];
    	int mid=(l+r)>>1;down(i,l,r);
    	return asksum(i<<1,l,mid,L,R)
    	+asksum(i<<1|1,mid+1,r,L,R);
    }
    int askmin(int i,int l,int r,int L,int R)
    {
    	if(L>r || l>R) return inf;
    	if(L<=l && r<=R) return mi[i];
    	int mid=(l+r)>>1;down(i,l,r);
    	return min(askmin(i<<1,l,mid,L,R),
    	askmin(i<<1|1,mid+1,r,L,R));
    }
    int wxk(int x,int y)
    {
    	return (int)floor(1.0*x/y);
    }
    void zxy(int i,int l,int r,int c)
    {
    	if(l==r)
    	{
    		mx[i]=mi[i]=s[i]=wxk(mx[i],c);
    		return ;
    	}
    	if(mx[i]-wxk(mx[i],c)==mi[i]-wxk(mi[i],c))
    	{
    		work(i,wxk(mx[i],c)-mx[i],r-l+1);
    		return ;
    	}
    	int mid=(l+r)>>1;down(i,l,r);
    	zxy(i<<1,l,mid,c);
    	zxy(i<<1|1,mid+1,r,c);
    	up(i);
    }
    void div(int i,int l,int r,int L,int R,int c)
    {
    	if(L>r || l>R) return ;
    	if(L<=l && r<=R) {zxy(i,l,r,c);return ;}
    	int mid=(l+r)>>1;down(i,l,r);
    	div(i<<1,l,mid,L,R,c);
    	div(i<<1|1,mid+1,r,L,R,c);
    	up(i);
    }
    signed main()
    {
    	n=read();q=read();
    	for(int i=1;i<=n;i++)
    		add(1,1,n,i,i,read());
    	while(q--)
    	{
    		int op=read(),l=read()+1,r=read()+1;
    		if(op==1) add(1,1,n,l,r,read());
    		if(op==2) div(1,1,n,l,r,read());
    		if(op==3) printf("%lld
    ",askmin(1,1,n,l,r));
    		if(op==4) printf("%lld
    ",asksum(1,1,n,l,r)); 
    	}
    }
    
  • 相关阅读:
    CentOS7搭建SFTP服务
    MySQL主从异常恢复
    MySQL主从复制配置
    Docker安装MySQL8.0
    CentOS7安装JDK1.8
    RabbitMQ死信队列
    RabbitMQ重试机制
    RabbitMQ消息可靠性传输
    TCP/IP的Socket编程
    c#网络编程使用tcpListener和tcpClient
  • 原文地址:https://www.cnblogs.com/C202044zxy/p/15182641.html
Copyright © 2011-2022 走看看