例题：https://www.luogu.com.cn/problem/P3834（静态区间第 k 小）
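主席树（可持久化线段树）会保留每一次插入/修改后的历史版本，并支持在任意历史版本上查询。本题在离散化后的值域上建树：第 i 个版本记录前缀 a[1..i] 中每个值出现的次数，于是用版本 r 减去版本 l-1 就得到区间 [l,r] 内各值的出现次数，再在两棵树上同步向下二分即可求出区间第 k 小。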
本题代码如下：
// Luogu P3834 - static range k-th smallest via a persistent value-domain
// segment tree ("chairman tree").
//
// Input values are coordinate-compressed to 1..cnt. Version i (root[i])
// counts, for every compressed value, how many of a[1..i] equal it. The
// counts for the subarray a[l..r] are therefore version r minus version l-1,
// and the k-th smallest is found by descending both versions in lockstep.
#include <algorithm>
#include <cstdio>

const int maxn = 200001;

int n, m, tot;               // tot: tree nodes allocated so far (node 0 is the empty node)
long long a[maxn], b[maxn];  // a: input values, b: sorted distinct values

struct stree {
    int lc, rc;  // child node indices
    int sum;     // how many inserted values fall inside this node's value range
} tree[maxn << 5];  // ~2*cnt nodes for build() plus ~log2(cnt)+1 per insert()

int root[maxn];  // root[i]: version after inserting a[1..i]

// Builds an all-zero tree over the value range [l, r] and returns its root.
int build(int l, int r) {
    int p = ++tot;
    if (l == r) return p;
    int mid = (l + r) >> 1;
    tree[p].lc = build(l, mid);
    tree[p].rc = build(mid + 1, r);
    return p;
}

// Returns the root of a new version equal to version `now` plus one
// occurrence of compressed value x. Only the root-to-leaf path is copied;
// every other node is shared with `now`.
int insert(int now, int l, int r, int x) {
    int p = ++tot;
    tree[p] = tree[now];
    tree[p].sum++;
    if (l == r) return p;
    int mid = (l + r) >> 1;
    if (x <= mid)
        tree[p].lc = insert(tree[now].lc, l, mid, x);
    else
        tree[p].rc = insert(tree[now].rc, mid + 1, r, x);
    return p;
}

// Returns the compressed index of the k-th smallest value among the values
// counted by version q but not by version p (the subarray a[l..r] when
// p = root[l-1], q = root[r]). [l, r] here is the value range of the
// current node. Requires 1 <= k <= number of values in that subarray.
int ask(int p, int q, int l, int r, int k) {
    if (l == r) return l;
    int mid = (l + r) >> 1;
    // How many of the subarray's values fall in the left half [l, mid].
    int lcnt = tree[tree[q].lc].sum - tree[tree[p].lc].sum;
    if (k <= lcnt) return ask(tree[p].lc, tree[q].lc, l, mid, k);
    return ask(tree[p].rc, tree[q].rc, mid + 1, r, k - lcnt);
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 1; i <= n; i++) {
        scanf("%lld", &a[i]);
        b[i] = a[i];
    }
    std::sort(b + 1, b + n + 1);
    int cnt = std::unique(b + 1, b + n + 1) - b - 1;
    root[0] = build(1, cnt);
    for (int i = 1; i <= n; i++) {
        int rank = std::lower_bound(b + 1, b + cnt + 1, a[i]) - b;
        root[i] = insert(root[i - 1], 1, cnt, rank);
    }
    while (m--) {
        int l, r, k;
        scanf("%d%d%d", &l, &r, &k);
        // One answer per line, as the problem requires.
        printf("%lld\n", b[ask(root[l - 1], root[r], 1, cnt, k)]);
    }
    return 0;
}
例题2：https://www.luogu.com.cn/problem/P3919（可持久化数组）
这是主席树的裸题，直接套模板即可。每次单点修改只新建从根到叶子的一条路径（O(log n) 个节点），其余节点与旧版本共享。查询操作不改动数组，新版本直接复用被查询版本的根。
代码如下:
// Luogu P3919 - persistent array implemented with a persistent segment tree.
//
// Version 0 is the initial array. Operation i creates version i:
//   "v 1 x c": a copy of version v with a[x] = c (one root-to-leaf path copied);
//   "v 2 x"  : print a[x] in version v; version i reuses version v's root.
#include <cstdio>

const int maxn = 1e6 + 10;

int n, m, tot;  // tot: tree nodes allocated so far
int root[maxn], a[maxn];

struct stree {
    int lc, rc;  // child node indices
    int val;     // array value; meaningful at leaves only
} tree[maxn << 5];

// Builds the initial tree over positions [l, r] from a[] and returns its root.
int build(int l, int r) {
    int p = ++tot;
    if (l == r) {
        tree[p].val = a[l];
        return p;
    }
    int mid = (l + r) >> 1;
    tree[p].lc = build(l, mid);
    tree[p].rc = build(mid + 1, r);
    return p;
}

// Returns the root of a new version identical to version `now` except that
// position x holds dat. Nodes off the root-to-leaf path are shared.
int insert(int now, int l, int r, int x, int dat) {
    int p = ++tot;
    tree[p] = tree[now];
    if (l == r) {
        tree[p].val = dat;
        return p;
    }
    int mid = (l + r) >> 1;
    if (x <= mid)
        tree[p].lc = insert(tree[now].lc, l, mid, x, dat);
    else
        tree[p].rc = insert(tree[now].rc, mid + 1, r, x, dat);
    return p;
}

// Returns the value at position x in the version rooted at p,
// where [l, r] is the position range covered by p.
int ask(int p, int l, int r, int x) {
    if (l == r) return tree[p].val;
    int mid = (l + r) >> 1;
    if (x <= mid) return ask(tree[p].lc, l, mid, x);
    return ask(tree[p].rc, mid + 1, r, x);
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    root[0] = build(1, n);
    for (int i = 1; i <= m; i++) {
        int v, op, x;
        scanf("%d%d%d", &v, &op, &x);
        if (op == 1) {
            int c;
            scanf("%d", &c);
            root[i] = insert(root[v], 1, n, x, c);
        } else {
            // One answer per line; the query itself creates an identical version.
            printf("%d\n", ask(root[v], 1, n, x));
            root[i] = root[v];
        }
    }
    return 0;
}