zoukankan      html  css  js  c++  java
  • ACM之路(19)—— 主席树初探

      长春赛的 I 题是主席树,现在稍微的学了一点主席树,也就算入了个门吧= =

      简单的来说主席树就是每个节点上面都是一棵线段树,但是这么多线段树会MLE吧?其实我们解决的办法就是有重复的节点给他利用起来,具体见幻神博客

      不妨以1~n上的求任意区间第k小的问题,就是上面博客中所写,我们从1访问到n的预处理中,每一个时间都新建一个线段树,这棵树上记录着已经出现的各个数字,这样我们求[L,R]上的第k小,我们拿R时刻的线段树减去(L-1)时刻的线段树,就是这个区间内需要的线段树,这个线段树上存在的数字其实就是[L,R]上存在的数字,我们在这里寻找我们需要的第k小就可以了。具体实现方法见上面的博客。

      我自己的模板如下:

     1 #include <stdio.h>
     2 #include <algorithm>
     3 #include <string.h>
     4 #define t_mid (l+r>>1)
     5 using namespace std;
     6 const int N = 100000 + 5;
     7 
     8 int n,q,tot,sz;
     9 int a[N],b[N];
    10 int rt[N*20],sum[N*20],ls[N*20],rs[N*20];
    11 void build(int &o,int l,int r)
    12 {
    13     o = ++tot;
    14     sum[o] = 0;
    15     if(l==r) return;
    16     build(ls[o],l,t_mid);
    17     build(rs[o],t_mid+1,r);
    18 }
    19 
    20 void update(int &o,int l,int r,int last,int p)
    21 {
    22     o = ++tot;
    23     ls[o] = ls[last];
    24     rs[o] = rs[last];
    25     sum[o] = sum[last] + 1;
    26     if(l==r) return;
    27     if(p <= t_mid) update(ls[o],l,t_mid,ls[last],p);
    28     else update(rs[o],t_mid+1,r,rs[last],p);
    29 }
    30 
    31 int query(int ql,int qr,int l,int r,int k)
    32 {
    33     if(l==r) return l;
    34     int cnt = sum[ls[qr]] - sum[ls[ql]];
    35     if(cnt >= k) return query(ls[ql],ls[qr],l,t_mid,k);
    36     else return query(rs[ql],rs[qr],t_mid+1,r,k-cnt);
    37 }
    38 
    39 void work()
    40 {
    41     int ql,qr,k;
    42     scanf("%d%d%d",&ql,&qr,&k);
    43     int ans = query(rt[ql-1],rt[qr],1,sz,k);
    44     printf("%d
    ",b[ans]);
    45 }
    46 
    47 int main()
    48 {
    49     while(scanf("%d%d",&n,&q)==2)
    50     {
    51         tot = 0;
    52         for(int i=1;i<=n;i++) scanf("%d",a+i),b[i]=a[i];
    53         sort(b+1,b+1+n);
    54         sz = unique(b+1,b+1+n) - (b+1);
    55         build(rt[0],1,sz);
    56         
    57         for(int i=1;i<=n;i++)
    58         {
    59             int t = lower_bound(b+1,b+1+sz,a[i]) - b;
    60             update(rt[i],1,sz,rt[i-1],t);
    61         }
    62         while(q--) work();
    63     }
    64 }
    求区间第K小

      然后如果是在一棵树上,求其一条链上的区间第k小呢?其实也差不多,我们就想着怎么把这棵需要的线段树抽取出来就行。这棵树实际上就是 u - lca(u,v) + v - father(lca(u,v))。具体的画画图就可以懂了。这里还涉及到求LCA的方法,具体方法见《挑战程序设计》中的倍增法求LCA即可。

      我自己的模板如下:

      1 #include <stdio.h>
      2 #include <algorithm>
      3 #include <string.h>
      4 #include <vector>
      5 #include <math.h>
      6 #define t_mid (l+r>>1)
      7 using namespace std;
      8 const int N = 100000 + 5;
      9 const int MAX_LOG_N = 16 + 5;
     10 
     11 int n,q,tot,sz;
     12 int a[N],b[N];
     13 int rt[N*20],sum[N*20],ls[N*20],rs[N*20];
     14 int parent[MAX_LOG_N][N],depth[N];
     15 vector<int> G[N];
     16 
     17 void getDepth(int v,int p,int d)
     18 {
     19     parent[0][v] = p;
     20     depth[v] = d;
     21     for(int i=0;i<G[v].size();i++)
     22     {
     23         if(G[v][i] != p) getDepth(G[v][i],v,d+1);
     24     }
     25 }
     26 
     27 void init()
     28 {
     29     getDepth(1,-1,0);
     30     for(int k=0;k+1<MAX_LOG_N;k++)
     31     {
     32         for(int v=1;v<=n;v++)
     33         {
     34             if(parent[k][v] < 0) parent[k+1][v] = -1;
     35             else parent[k+1][v] = parent[k][parent[k][v]];
     36         }
     37     }
     38 }
     39 
     40 int lca(int u,int v)
     41 {
     42     if(depth[u]>depth[v]) swap(u,v);
     43     for(int k=0;k<MAX_LOG_N;k++)
     44     {
     45         if((depth[v]-depth[u]) >> k & 1)
     46         {
     47             v = parent[k][v];
     48         }
     49     }
     50     if(u==v) return u;
     51     for(int k=MAX_LOG_N-1;k>=0;k--)
     52     {
     53         if(parent[k][u] != parent[k][v])
     54         {
     55             u = parent[k][u];
     56             v = parent[k][v];
     57         }
     58     }
     59     return parent[0][u];
     60 }
     61 
     62 void build(int &o,int l,int r)
     63 {
     64     o = ++tot;
     65     sum[o] = 0;
     66     if(l==r) return;
     67     build(ls[o],l,t_mid);
     68     build(rs[o],t_mid+1,r);
     69 }
     70 
     71 void update(int &o,int l,int r,int last,int p)
     72 {
     73     o = ++tot;
     74     ls[o] = ls[last];
     75     rs[o] = rs[last];
     76     sum[o] = sum[last] + 1;
     77     if(l==r) return;
     78     if(p <= t_mid) update(ls[o],l,t_mid,ls[last],p);
     79     else update(rs[o],t_mid+1,r,rs[last],p);
     80 }
     81 
     82 int query(int u,int v,int x,int y,int l,int r,int k)
     83 {
     84     if(l==r) return l;
     85     int cnt = sum[ls[u]] + sum[ls[v]] - sum[ls[x]] - sum[ls[y]];
     86     if(cnt >= k) return query(ls[u],ls[v],ls[x],ls[y],l,t_mid,k);
     87     else return query(rs[u],rs[v],rs[x],rs[y],t_mid+1,r,k-cnt);
     88 }
     89 
     90 void work()
     91 {
     92     int u,v,k;
     93     scanf("%d%d%d",&u,&v,&k);
     94     int _lca = lca(u,v);
     95     int _lca_fa = parent[0][_lca];
     96     int ans = query(rt[u],rt[v],rt[_lca],rt[_lca_fa],1,sz,k);
     97     printf("%d
    ",b[ans]);
     98 }
     99 
    100 void dfs(int u,int fa)
    101 {
    102     for(int i=0;i<G[u].size();i++)
    103     {
    104         int v = G[u][i];
    105         if(v==fa) continue;
    106         int t = lower_bound(b+1,b+1+sz,a[v]) - b;
    107         update(rt[v],1,sz,rt[u],t);
    108         dfs(v,u);
    109     }
    110 }
    111 
    112 int main()
    113 {
    114     while(scanf("%d%d",&n,&q)==2)
    115     {
    116         tot = 0;
    117         for(int i=1;i<=n;i++) G[i].clear();
    118         for(int i=1;i<=n;i++) scanf("%d",a+i),b[i]=a[i];
    119         sort(b+1,b+1+n);
    120         sz = unique(b+1,b+1+n) - (b+1);
    121         for(int i=1;i<n;i++)
    122         {
    123             int u,v;scanf("%d%d",&u,&v);
    124             G[u].push_back(v);
    125             G[v].push_back(u);
    126         }
    127         build(rt[0],1,sz);
    128         init();
    129         
    130         int t = lower_bound(b+1,b+1+sz,a[1]) - b;
    131         update(rt[1],1,sz,rt[0],t);
    132         dfs(1,-1);
    133         
    134         while(q--) work();
    135     }
    136 }
    求树上的一条链的第K小

      好,接下来就是解决那个烦人的 I 题了。

      我们首先需要用主席树来解决区间内不同的数的个数,这东西比较奥义- -直接上模板好了。。反正随便百度一下"主席树求区间内不同数的个数"都会出来spoj的D-query那题,随便看下原理就行= =。。。然后用二分解决 I 题(固定左端点,二分右端点,具体见代码。。)。

      看我直接丢 I 题的代码~:

      1 #include <stdio.h>
      2 #include <algorithm>
      3 #include <string.h>
      4 #include <map>
      5 #define t_mid (l+r>>1)
      6 using namespace std;
      7 const int N = 2*100000 + 50;
      8 
      9 int rt[N*20*2],sum[N*20*2],ls[N*20*2],rs[N*20*2];
     10 int a[N],n,m,tot;
     11 void build(int &o,int l,int r)
     12 {
     13     o = ++tot;
     14     sum[o] = 0;
     15     if(l == r) return;
     16     build(ls[o],l,t_mid);
     17     build(rs[o],t_mid+1,r);
     18 }
     19 
     20 void update(int &o,int l,int r,int last,int pos,int dt)
     21 {
     22     o = ++tot;
     23     sum[o] = sum[last];
     24     ls[o] = ls[last];
     25     rs[o] = rs[last];
     26     if(l==r) {sum[o]+=dt;return;}
     27     if(pos <= t_mid) update(ls[o],l,t_mid,ls[last],pos,dt);
     28     else update(rs[o],t_mid+1,r,rs[last],pos,dt);
     29     sum[o] = sum[ls[o]] + sum[rs[o]];
     30 }
     31 
     32 int query(int l,int r,int o,int pos)
     33 {
     34     if(l == r) return sum[o];
     35     if(pos <= t_mid) return sum[rs[o]] + query(l,t_mid,ls[o],pos);
     36     else return query(t_mid+1,r,rs[o],pos);
     37 }
     38 
     39 /*
     40 int query(int l,int r,int L,int R,int x){
     41     if(L <= l && r <= R) return sum[x];
     42     int mid = (l+r) >> 1 , ret = 0;
     43     if(L <= mid) ret += query(l,mid,L,R,ls[x]);
     44     if(R > mid) ret += query(mid+1,r,L,R,rs[x]);
     45     return ret;
     46 }
     47 */
     48 
     49 int main()
     50 {
     51     int T;scanf("%d",&T);
     52     for(int kase=1;kase<=T;kase++)
     53     {
     54         scanf("%d%d",&n,&m);
     55         int pre = 0;
     56         map<int,int> mp;
     57         tot = 0;
     58         for(int i=1;i<=n;i++) scanf("%d",a+i);
     59         build(rt[0],1,n);
     60 
     61         for(int i=1;i<=n;i++)
     62         {
     63             if(mp.find(a[i]) == mp.end())
     64             {
     65                 mp[a[i]] = i;
     66                 update(rt[i],1,n,rt[i-1],i,1);
     67             }
     68             else
     69             {
     70                 int temp = 0;
     71                 update(temp,1,n,rt[i-1],mp[a[i]],-1);
     72                 update(rt[i],1,n,temp,i,1);
     73             }
     74             mp[a[i]] = i;
     75         }
     76         //scanf("%d",&m);
     77         printf("Case #%d:",kase);
     78         while(m--)
     79         {
     80             int ql,qr;scanf("%d%d",&ql,&qr);
     81             int L = min((ql+pre)%n+1,(qr+pre)%n+1);
     82             int R = max((ql+pre)%n+1,(qr+pre)%n+1);
     83             //L = ql, R = qr;
     84             int k = (query(1,n,rt[R],L)+1)>>1;
     85             int l = L, r = R;
     86             //printf("!! %d %d 
    ",L,R);
     87             int ans = -1;
     88             while(l<=r)
     89             {
     90                 int mid = l + r >> 1;
     91                 int t = query(1,n,rt[mid],L);
     92                 //printf("mid is %d %d
    ",mid,t);
     93                 if(t < k) l = mid + 1;
     94                 else
     95                 {
     96                     r = mid - 1;
     97                     ans = mid;
     98                 }
     99             }
    100             /*while(l < r)
    101             {
    102                 int mid = l + r >> 1;
    103                 int t = query(1,n,rt[mid],L);
    104                 if(t < k) l = mid + 1;
    105                 else r = mid;
    106             }*/
    107             
    108             printf(" %d",ans);
    109             pre = ans;
    110         }
    111         puts("");
    112     }
    113 }
    114 
    115 /*
    116 100
    117 20 100
    118 1 2 3 4 3 2 1 2 4 2 2 3 1 2 3 1 4 4 2 1
    119 1 20
    120 1 10
    121 2 5
    122 4 6
    123 3 2
    124 4 7
    125 
    126 100
    127 5 100
    128 0 1 0 2 3
    129 1 5
    130 */
    131 /*
    132 #include<iostream>
    133 //#include<bits/stdc++.h>
    134 #include<cstdio>
    135 #include<string>
    136 #include<cstring>
    137 #include<map>
    138 #include<queue>
    139 #include<set>
    140 #include<stack>
    141 #include<ctime>
    142 #include<algorithm>
    143 #include<cmath>
    144 #include<vector>
    145 #define showtime fprintf(stderr,"time = %.15f
    ",clock() / (double)CLOCKS_PER_SEC)
    146 //#pragma comment(linker, "/STACK:1024000000,1024000000")
    147 using namespace std;
    148 typedef long long ll;
    149 typedef long long LL;
    150 #define MP make_pair
    151 #define PII pair<int,int>
    152 #define PLI pair<long long ,int>
    153 #define PFI pair<double,int>
    154 #define PLL pair<ll,ll>
    155 #define PB push_back
    156 #define F first
    157 #define S second
    158 #define lson l,mid,rt<<1
    159 #define rson mid+1,r,rt<<1|1
    160 #define debug cout<<"?????"<<endl;
    161 //freopen("1005.in","r",stdin);
    162 //freopen("data.out","w",stdout);
    163 const int INF = 0x3f3f3f3f;
    164 const double eps = 1e-2;
    165 const int N = 4e5 + 50 ;
    166 const double PI = acos(-1.);
    167 const double E = 2.71828182845904523536;
    168 const int MOD = 1e9+7;
    169 typedef vector<ll> Vec;
    170 typedef vector<Vec> Mat;
    171 int n,m;
    172 struct node{int l,r,sum;}T[N*40];
    173 int a[N],root[N],pre[N],tot;
    174 int q,x,y;
    175 int ans[N];
    176 vector<int> v;
    177 int getid(int x){ return lower_bound(v.begin(),v.end(),x) - v.begin() + 1;}
    178 void init(){
    179     tot = 0;
    180     memset(root,0,sizeof(root));
    181     memset(pre,-1,sizeof(pre));
    182     v.clear();
    183 }
    184 void update(int l,int r,int val,int &x,int y,int pos){
    185     T[++tot] = T[y] , T[tot].sum += val , x = tot;
    186     if(l == r) return ;
    187     int mid = (l + r) >> 1;
    188     if(pos <= mid) update(l,mid,val,T[x].l,T[y].l,pos);
    189     else update(mid+1,r,val,T[x].r,T[y].r,pos);
    190 }
    191 **
    192  *        【x=L,y=R】 不同数字的有多少个
    193  *        query(1,n,x,y,root[y]);  第y颗树。
    194  *
    195 int query(int l,int r,int L,int R,int x){
    196     if(L <= l && r <= R) return T[x].sum;
    197     int mid = (l+r) >> 1 , ret = 0;
    198     if(L <= mid) ret += query(l,mid,L,R,T[x].l);
    199     if(R > mid) ret += query(mid+1,r,L,R,T[x].r);
    200     return ret;
    201 }
    202 int main(){
    203     int kase = 1,T;
    204     cin >> T;
    205     while(T --){
    206         cin >> n >> m;
    207         init();
    208         for(int i = 1 ; i <= n ; i ++) scanf("%d",&a[i]) , v.push_back(a[i]);
    209         sort(v.begin(),v.end());
    210         v.erase(unique(v.begin(),v.end()),v.end());
    211         for(int i = 1 ; i <= n ; i ++){
    212             int id = getid(a[i]);
    213             if(pre[id] == -1){
    214                 update(1,n,1,root[i],root[i-1],i);
    215                 pre[id] = i;
    216             }else{
    217                 int tmp;
    218                 update(1,n,-1,tmp,root[i-1],pre[id]);
    219                 update(1,n,1,root[i],tmp,i);
    220                 pre[id] = i;
    221             }
    222         }
    223         ans[0] = 0;
    224         printf("Case #%d:",kase ++);
    225         for(int i = 1 ; i <= m ; i ++){
    226             scanf("%d%d",&x,&y);
    227             int l,r;
    228             l = min((x+ans[i-1])%n+1,(y+ans[i-1])%n+1);
    229             r = max((x+ans[i-1])%n+1,(y+ans[i-1])%n+1);
    230             //l = x ; r = y;
    231             //printf("%d %d !!
    ",l,r);
    232             int k = (query(1,n,l,r,root[r])+1) / 2;
    233             int ll = l , rr = r;
    234             while(ll < rr){
    235                 int mid = (ll + rr) / 2;
    236                 int t = query(1,n,l,mid,root[mid]);
    237                 if(t < k) ll = mid+1;
    238                 else rr = mid;
    239             }
    240             printf(" %d",rr);
    241             ans[i] = rr;
    242         }
    243         puts("");
    244     }
    245     return 0;
    246 }
    247 */
    长春 I 题

      有几点想说明的:1.下面注释的是大力的代码,但是超时了,因为他的query方法和我的有点小差别,虽然都能实现需要的功能,但是似乎我的query方法复杂度更小一点(??)。。不过我的也是卡过的,但是我觉得在长春现场赛的话应该能过,感觉HDU的评测机这次有点坑。。2.我的代码本来是WA的,因为数组开小了,我上面的两个代码都是*20的,都没问题,这里必须要开*40的才行,被坑了这一次以后我下次都开大一点的好了,反正*40内存也够用= =。。那么主席树就写到这里好了,以后刷了题目有什么要补充的再补充好了~(话说我的数据结构真的好烂啊,,以后搞splay怎么办啊233。。)

  • 相关阅读:
    Postman提取接口返回值设置变量
    Python-浅拷贝与深拷贝
    Python列表
    typeorm查询两个没有关联关系的实体
    springboot去掉数据源自动加载
    docker搭建redis集群
    实习工作记录(一)大文件上传vue+WebUploader
    js重点之promise
    css重点
    git简单命令整理
  • 原文地址:https://www.cnblogs.com/zzyDS/p/5931453.html
Copyright © 2011-2022 走看看