zoukankan      html  css  js  c++  java
  • splay和lct

    1. 查询前驱,后继,排名

    splay基本操作

    #include <cstdio>
    const int N = 1e6+10, INF = 0x3f3f3f3f;
    int tot, rt;
    struct {
        int cnt,sz,fa,ch[2],v;
    } tr[N];
    void pu(int x) {
        tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt;
    }
    void rot(int x) {
        int y=tr[x].fa,z=tr[y].fa;
        int f=tr[y].ch[1]==x;
        tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
        tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y;
        tr[x].ch[f^1]=y,tr[y].fa=x,pu(y),pu(x);
    }
    //s=0时将x旋转到根, 否则将x旋转到s的儿子
    void splay(int x, int s=0) {
        for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
            rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
        }
        if (!s) rt=x;
    }
    //若splay中存在权值x,那么把x旋转到根,否则把x的前驱或后继旋转到根
    //(get函数只是为了简化求前驱和后继操作, 其他地方不要使用get)
    void get(int x) {
        int cur=rt;
        while (x!=tr[cur].v&&tr[cur].ch[x>tr[cur].v]) cur=tr[cur].ch[x>tr[cur].v];
        splay(cur);
    }
    //插入权值x
    void insert(int x) {
        int cur=rt,p=0;
        while (cur&&x!=tr[cur].v) p=cur,cur=tr[cur].ch[x>tr[cur].v];
        if (cur) ++tr[cur].cnt;
        else {
            cur=++tot;
            if (p) tr[p].ch[x>tr[p].v]=cur,tr[cur].fa=p;
            tr[cur].v=x,tr[cur].sz=tr[cur].cnt=1;
        }
        splay(cur);
    }
    //返回<=x的节点编号
    int pre(int x) {
        get(x);
        if (tr[rt].v<=x) return rt;
        int cur=tr[rt].ch[0];
        while (tr[cur].ch[1]) cur=tr[cur].ch[1];
        splay(cur);
        return cur;
    }
    //返回>=x的节点编号
    int nxt(int x) {
        get(x);
        if (tr[rt].v>=x) return rt;
        int cur=tr[rt].ch[1];
        while (tr[cur].ch[0]) cur=tr[cur].ch[0];
        splay(cur);
        return cur;
    }
    //若权值x存在删除x对应节点, 否则无影响
    void erase(int x) {
        int s1=pre(x-1),s2=nxt(x+1);
        splay(s1),splay(s2,s1);
        int &cur=tr[s2].ch[0];
        if (tr[cur].cnt>1) --tr[cur].cnt,splay(cur);
        else cur=0;
    }
    //返回权值x的排名(<x的数的个数+1)
    int rk(int x) {
        int t = pre(x-1);
        return tr[tr[t].ch[0]].sz+tr[t].cnt;
    }
    //返回排名为k-1的数
    int kth(int x, int k) {
        int s=tr[tr[x].ch[0]].sz;
        if (k<=s) return kth(tr[x].ch[0],k);
        if (k>s+tr[x].cnt) return kth(tr[x].ch[1],k-s-tr[x].cnt);
        return splay(x),x;
    }
    int main() {
        int n;
        scanf("%d", &n);
        //初始化插入INF和-INF
        insert(INF),insert(-INF);
        while (n--) {
            int op, x;
            scanf("%d%d", &op, &x);
            //插入
            if (op==1) insert(x);
            //删除
            else if (op==2) erase(x);
            //查询排名(比x小的数的个数+1)
            else if (op==3) printf("%d
    ",rk(x));
            //查询排名为x的数
            else if (op==4) printf("%d
    ",tr[kth(rt,x+1)].v);
            //求x的前驱(小于x且最大的数)
            else if (op==5) printf("%d
    ",tr[pre(x-1)].v);
            //求x的后继(大于x且最小的数)
            else printf("%d
    ",tr[nxt(x+1)].v);
        }
    }
    P3369

    2. 区间翻转

    维护splay中序遍历得到的序列, 查询下标相当于查询排名

    #include <cstdio>
    #include <algorithm>
    using namespace std;
    const int N = 1e6+10;
    int rt, tot;
    struct {
        int sz,v,ch[2],fa,rev;
        void upd() {swap(ch[0],ch[1]);rev^=1;}
    } tr[N];
    void pu(int o) {
        tr[o].sz=tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1;
    }
    void pd(int o) {
        if (tr[o].rev) {
            tr[tr[o].ch[0]].upd();
            tr[tr[o].ch[1]].upd();
            tr[o].rev=0;
        }
    }
    void rot(int x) {
        int y=tr[x].fa,z=tr[y].fa;
        int f=tr[y].ch[1]==x,w=tr[x].ch[f^1];
        tr[z].ch[tr[z].ch[1]==y]=x;
        tr[x].ch[f^1]=y,tr[y].ch[f]=w;
        tr[w].fa=y;
        tr[y].fa=x,tr[x].fa=z;
        pu(y),pu(x);
    }
    void splay(int x, int s=0) {
        for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
            rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
        }
        if (!s) rt=x;
    }
    int find(int x, int k) {
        pd(x); int s=tr[tr[x].ch[0]].sz;
        if (k==s+1) return splay(x),x;
        if (k<=s) return find(tr[x].ch[0],k);
        return find(tr[x].ch[1],k-s-1);
    }
    void reverse(int x, int y) {
        int s1=find(rt,x), s2=find(rt,y+2);
        splay(s1),splay(s2,s1);
        tr[tr[s2].ch[0]].upd();
    }
    void build(int f, int &o, int l, int r) {
        if (l>r) return;
        o = ++tot; int mid = (l+r)/2;
        tr[o].v = mid-1, tr[o].fa = f;
        build(o,tr[o].ch[0],l,mid-1);
        build(o,tr[o].ch[1],mid+1,r);
        pu(o);
    }
    int n,m;
    void dfs(int x) {
        if (!x) return;
        pd(x);
        dfs(tr[x].ch[0]);
        if (1<=tr[x].v&&tr[x].v<=n) printf("%d ",tr[x].v);
        dfs(tr[x].ch[1]);
    }
    int main() {
        scanf("%d%d", &n, &m);
        build(0,rt,1,n+2);
        while (m--) {
            int x, y;
            scanf("%d%d", &x, &y);
            reverse(x,y);
        }
        dfs(rt),puts("");
    }
    P3391

    3. 区间加区间求和

    用splay提取区间, 转化为子树加和子树求和, 打标记即可

    #include <iostream>
    using namespace std;
    const int N = 1e5+10, INF = 0x3f3f3f3f;
    int tot, rt, a[N];
    struct {
        int sz,fa,ch[2];
        long long sum, tag, v;
        void add(long long x) {sum+=sz*x;tag+=x;v+=x;}
    } tr[N];
    void pu(int x) {
        tr[x].sz = tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+1;
        tr[x].sum = tr[tr[x].ch[0]].sum+tr[tr[x].ch[1]].sum+tr[x].v;
    }
    void pd(int x) {
        if (tr[x].tag) {
            tr[tr[x].ch[0]].add(tr[x].tag);
            tr[tr[x].ch[1]].add(tr[x].tag);
            tr[x].tag = 0;
        }
    }
    void rot(int x) {
        int y=tr[x].fa,z=tr[y].fa;
        int f=tr[y].ch[1]==x;
        tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
        tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y;
        tr[x].ch[f^1]=y,tr[y].fa=x,pu(y),pu(x);
    }
    void splay(int x, int s=0) {
        for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
            rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
        }
        if (!s) rt=x;
    }
    int find(int x, int k) {
        pd(x); int s = tr[tr[x].ch[0]].sz;
        if (k==s+1) return x;
        if (k<=s) return find(tr[x].ch[0],k);
        return find(tr[x].ch[1],k-s-1);
    }
    void build(int f, int &o, int l, int r) {
        if (l>r) return;
        o = ++tot; int mid = (l+r)/2;
        tr[o].v = a[mid], tr[o].fa = f;
        build(o,tr[o].ch[0],l,mid-1);
        build(o,tr[o].ch[1],mid+1,r);
        pu(o);
    }
    int split(int x, int y) {
        int s1 = find(rt,x), s2 = find(rt,y+2);
        splay(s1), splay(s2,s1);
        return tr[s2].ch[0];
    }
    int main() {
        int n, m;
        scanf("%d%d", &n, &m);
        for (int i=2; i<=n+1; ++i) scanf("%d", &a[i]);
        build(0,rt,1,n+2);
        while (m--) {
            int op, x, y, z;
            scanf("%d%d%d", &op, &x, &y);
            if (op==1) {
                scanf("%d", &z);
                tr[split(x,y)].add(z);
            }
            else {
                printf("%lld
    ", tr[split(x,y)].sum);
            }
        }
    }
    P3372

    用$lct$也可以做, 可以差分一下避免换根操作. 常数比splay略大

    #include <bits/stdc++.h>
    using namespace std;
    const int N = 1e5+10;
    int a[N];
    struct {
        int ch[2],fa,sz;
        long long sum, v, tag;
        void add(long long x) {sum+=sz*x,v+=x,tag+=x;}
    } tr[N];
    void pd(int o) {
        if (tr[o].tag) {
            tr[tr[o].ch[0]].add(tr[o].tag);
            tr[tr[o].ch[1]].add(tr[o].tag);
            tr[o].tag = 0;
        }
    }
    void pu(int o) {
        tr[o].sz = tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1;
        tr[o].sum = tr[tr[o].ch[0]].sum+tr[tr[o].ch[1]].sum+tr[o].v;
    }
    int nroot(int x) {
        return tr[tr[x].fa].ch[0]==x||tr[tr[x].fa].ch[1]==x;
    }
    void rot(int x) {
        int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,w=tr[x].ch[!k];
        if (nroot(y)) tr[z].ch[tr[z].ch[1]==y]=x;
        tr[x].ch[!k]=y,tr[y].ch[k]=w;
        if (w) tr[w].fa=y;
        tr[y].fa=x,tr[x].fa=z;
        pu(x),pu(y);
    }
    void repush(int x) {
        if (nroot(x)) repush(tr[x].fa);
        pd(x);
    }
    void splay(int x) {
        repush(x);
        while (nroot(x)) {
            int y=tr[x].fa,z=tr[y].fa;
            if (nroot(y)) rot((tr[y].ch[0]==x)==(tr[z].ch[0]==y)?y:x);
            rot(x);
        }
        pu(x);
    }
    void access(int x) {
        for (int y=0,t=x; t; t=tr[y=t].fa) splay(t),tr[t].ch[1]=y,pu(t);
        splay(x);
    }
    int main() {
        int n, m;
        scanf("%d%d", &n, &m);
        for (int i=1; i<=n; ++i) { 
            scanf("%d", &a[i]);
            tr[i].v = a[i];
            if (i>1) tr[i].fa = i-1;
        }
        while (m--) {
            int op, x, y, z;
            scanf("%d%d%d", &op, &x, &y);
            if (op==1) {
                scanf("%d", &z);
                access(y);
                tr[y].add(z);
                if (x>1) access(x-1), tr[x-1].add(-z);
            }
            else {
                access(y);
                long long ans = tr[y].sum;
                if (x>1) access(x-1), ans -= tr[x-1].sum;
                printf("%lld
    ", ans);
            }
        }
    }
    View Code

    4. 动态添边维护最小生成树

    $nle 10^3$的话, 可以先$prim$求出最小生成树, 添一条边$(u,v)$, 那么就暴力求出当前最小生成树中路径$(u,v)$上的最大边, 如果添的边更小就替换掉, 复杂度是$O(n(n+q))$

    #include <bits/stdc++.h>
    using namespace std;
    const int N = 1e3+10, M = 1e5+10, INF = 0x3f3f3f3f;
    int n,m,q,ret,a[N][N],dis[N],pos[N],vis[N];
    int ans[M],u,v,w;
    struct {int op,u,v,w;} e[M];
    vector<int> g[N];
    int dfs(int x, int f, int s) {
        if (x==s) return 1;
        for (int y:g[x]) if (y!=f) { 
            if (dfs(y,x,s)) { 
                if (a[x][y]>w) {
                    w = a[x][y];
                    u = x, v = y;
                }
                ret=max(ret,a[y][x]);
                return 1;
            }
        }
        return 0;
    }
    int main() {
        scanf("%d%d%d", &n, &m, &q);
        memset(a,0x3f,sizeof a);
        for (int i=1; i<=m; ++i) {
            int u, v, w;
            scanf("%d%d%d", &u, &v, &w);
            a[u][v] = a[v][u] = w;
        }
        for (int i=1; i<=q; ++i) {
            scanf("%d%d%d", &e[i].op, &e[i].u, &e[i].v);
            if (e[i].op==2) { 
                e[i].w = a[e[i].u][e[i].v];
                a[e[i].u][e[i].v] = a[e[i].v][e[i].u] = INF;
            }
        }
        memset(dis,0x3f,sizeof dis);
        dis[1] = 0;
        for (int i=1; i<=n; ++i) {
            int mi = INF, x, y;
            for (int j=1; j<=n; ++j) {
                if (!vis[j]&&dis[j]<mi) {
                    mi = dis[j];
                    x = j, y = pos[j];
                }
            }
            vis[x] = 1;
            if (x!=1) { 
                g[x].push_back(y);
                g[y].push_back(x);
            }
            for (int j=1; j<=n; ++j) {
                if (dis[j]>a[x][j]) dis[j] = a[x][j], pos[j] = x;
            }
        }
        for (int i=q; i; --i) {
            w = 0;
            if (e[i].op==1) { 
                dfs(e[i].u,0,e[i].v);
                ans[i] = w;
            }
            else {
                dfs(e[i].u,0,e[i].v);
                a[e[i].u][e[i].v] = a[e[i].v][e[i].u] = e[i].w;
                auto del = [&](int u, int v) {
                    for (int i=0; i<g[u].size(); ++i) {
                        if (g[u][i]==v) {
                            swap(g[u][i],g[u].back());
                            g[u].pop_back();
                            return;
                        }
                    }
                };
                if (e[i].w<w) {
                    del(u,v);
                    del(v,u);
                    g[e[i].u].push_back(e[i].v);
                    g[e[i].v].push_back(e[i].u);
                }
            }
        }
        for (int i=1; i<=q; ++i) if (e[i].op==1) printf("%d
    ", ans[i]);
    }
    View Code

    用lct的话等价于要找边权最大值以及两端点, $lct$维护边权可以把每条边新建一个点, 转化为维护点权

    复杂度是$O((m+q)log{m})$ 

    #include <bits/stdc++.h>
    using namespace std;
    const int N = 2e5+10;
    int n,m,q,fa[N],ans[N],vis[N],val[N];
    struct edge {int u,v,w;} e[N];
    struct Q {int op,u,v,id;} f[N];
    map<pair<int,int>,int> mp;
    int Find(int x) {return fa[x]?fa[x]=Find(fa[x]):x;}
    struct {
        int ch[2],fa,tag,v;
        void rev() {tag^=1;swap(ch[0],ch[1]);}
    } tr[N];
    void pd(int o) {
        if (tr[o].tag) {
            tr[tr[o].ch[0]].rev();
            tr[tr[o].ch[1]].rev();
            tr[o].tag = 0;
        }
    }
    void pu(int o) {
        tr[o].v = max({val[o], tr[tr[o].ch[0]].v, tr[tr[o].ch[1]].v});
    }
    int nroot(int x) {
        return tr[tr[x].fa].ch[0]==x||tr[tr[x].fa].ch[1]==x;
    }
    void rot(int x) {
        int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,w=tr[x].ch[!k];
        if (nroot(y)) tr[z].ch[tr[z].ch[1]==y]=x;
        tr[x].ch[!k]=y,tr[y].ch[k]=w;
        if (w) tr[w].fa=y;
        tr[y].fa=x,tr[x].fa=z;
        pu(x),pu(y);
    }
    void repush(int x) {
        if (nroot(x)) repush(tr[x].fa);
        pd(x);
    }
    void splay(int x) {
        repush(x);
        while (nroot(x)) {
            int y=tr[x].fa,z=tr[y].fa;
            if (nroot(y)) rot((tr[y].ch[0]==x)==(tr[z].ch[0]==y)?y:x);
            rot(x);
        }
        pu(x);
    }
    void access(int x) {
        for (int y=0; x; x=tr[y=x].fa) splay(x),tr[x].ch[1]=y,pu(x);
    }
    void makeroot(int x) {
        access(x),splay(x);
        tr[x].rev();
    }
    int findroot(int x) {
        access(x),splay(x);
        while (tr[x].ch[0]) pd(x),x=tr[x].ch[0];
        return splay(x), x;
    }
    void split(int x, int y) {
        makeroot(x),access(y),splay(y);
    }
    void link(int x, int y) {
        makeroot(x);
        tr[x].fa=y;
    }
    void cut(int x, int y) {
        split(x,y);
        tr[y].ch[0]=tr[x].fa=0;
        pu(y);
    }
    int main() {
        scanf("%d%d%d", &n, &m, &q);
        for (int i=1; i<=m; ++i) { 
            scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);
            if (e[i].u>e[i].v) swap(e[i].u, e[i].v);
        }
        for (int i=1; i<=q; ++i) {
            scanf("%d%d%d", &f[i].op, &f[i].u, &f[i].v);
            if (f[i].u>f[i].v) swap(f[i].u,f[i].v);
        }
        sort(e+1,e+1+m,[](edge a,edge b){return a.w<b.w;});
        for (int i=1; i<=m; ++i) { 
            val[i+n] = mp[{e[i].u,e[i].v}] = i;
        }
        for (int i=1; i<=q; ++i) {
            if (f[i].op==2) { 
                f[i].id = mp[{f[i].u,f[i].v}];
                vis[f[i].id] = 1;
            }
        }
        for (int i=1; i<=m; ++i) if (!vis[i]) {
            int u = Find(e[i].u), v = Find(e[i].v);
            if (u!=v) {
                link(e[i].u,i+n);
                link(e[i].v,i+n);
                fa[u] = v;
            }
        }
        for (int i=q; i; --i) {
            if (f[i].op==1) {
                split(f[i].u,f[i].v);
                ans[i] = e[tr[f[i].v].v].w;
            }
            else { 
                split(f[i].u,f[i].v);
                int x = tr[f[i].v].v, y = mp[{f[i].u,f[i].v}];
                if (x>y) {
                    cut(e[x].u,x+n);
                    cut(e[x].v,x+n);
                    link(e[y].u,y+n);
                    link(e[y].v,y+n);
                }
            }
        }
        for (int i=1; i<=q; ++i) if (f[i].op==1) printf("%d
    ", ans[i]);
    }
    View Code
  • 相关阅读:
    如何写文件上传下载
    填充表格的模板代码
    ArcGIS Server 分布式注意事项
    在android上导入第三方jar包 报错:Could not find class
    @Override annotation 出错
    签到时间
    分页三条件查询
    上传图片
    分页
    二级联动
  • 原文地址:https://www.cnblogs.com/fs-es/p/13785011.html
Copyright © 2011-2022 走看看