主席树复习笔记
首先简单复习一下之前学过的主席树和线段树合并的题目
因为主席树打的比较熟,所以就稍微简单一些吧。
LuoguP3834 【模板】可持久化线段树 1(主席树)
非常形象的一个图
本质是一颗权值线段树?
我们每次加入一个点,发现最多只有(log)个点会受到影响,所以我们就把这(log)个点单独建出来,其余的直接套用之前的就好了
这大概就是主席树的核心思想了?
#include<cstdio>
#include<cctype>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 2e5 + 3;
struct node{
int sum;
int lc,rc;
}a[N * 20];
int t,n,m;
int A[N],B[N];
int rt[N];
inline int read(){
int v = 0,c = 1;char ch = getchar();
while(!isdigit(ch)){
if(ch == '-') c = -1;
ch = getchar();
}
while(isdigit(ch)){
v = v * 10 + ch - 48;
ch = getchar();
}
return v * c;
}
inline void ins(int &u,int l,int r,int x){
/// printf("%d %d %d %d
",u,l,r,x);
a[++t] = a[u];
u = t;
if(l == r){
a[u].sum++;
return ;
}
int mid = (l + r) >> 1;
if(x <= mid) ins(a[u].lc,l,mid,x);
else ins(a[u].rc,mid + 1,r,x);
a[u].sum = a[a[u].lc].sum + a[a[u].rc].sum;
}
inline int query(int u1,int u2,int l,int r,int p){
//printf("%d %d")
if(l == r) return l;
int w = a[a[u2].lc].sum - a[a[u1].lc].sum;
int mid = (l + r) >> 1;
if(w >= p) return query(a[u1].lc,a[u2].lc,l,mid,p);
else return query(a[u1].rc,a[u2].rc,mid + 1,r,p - w);
}
int main(){
n = read(),m = read();
for(int i = 1;i <= n;++i) B[i] = A[i] = read();
sort(B + 1,B + n + 1);
B[0] = unique(B + 1,B + n + 1) - B - 1;
// for(int i = 1;i <= B[0];++i) printf("%d ",B[i]);puts("");
for(int i = 1;i <= n;++i){
rt[i] = rt[i - 1];
A[i] = lower_bound(B + 1,B + B[0] + 1,A[i]) - B;
// cout << A[i] << endl;
ins(rt[i],1,B[0],A[i]);
}
while(m--){
int l = read(),r = read(),p = read();
printf("%d
",B[query(rt[l - 1],rt[r],1,B[0],p)]);
}
return 0;
}
线段树合并
线段树合并可以看做是树上启发式合并的一种优化。时间复杂度我不会证明
但应该不会比启发式合并慢
我们在树上的每个点都维护一颗线段树,在dfs中将其与儿子的信息合并为一颗
合并完成后,该线段树维护的就是子树信息了。
这也是线段树合并大多数都是离线的原因。因为线段树继续向上合并时,维护的内容会改变。
这就要求在合并完成时就求出相应答案
例题
LuoguP3605
简单的线段树合并裸题(当然启发式合并也可以过去)
大概思路就和上面所说的
将儿子的线段树全部合并之后
查一下有多少比他权值大的,跟新答案就好了
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cctype>
#include<iostream>
#include<vector>
using namespace std;
const int N = 1e5 + 3;
int n,m,t;
int v[N];
int b[N];
int fa[N];
int rt[N],ans[N];
vector <int> G[N];
struct node{
int sum;
int lc,rc;
}a[N << 5];
inline int read(){
int v = 0,c = 1;char ch = getchar();
while(!isdigit(ch)){
if(ch == '-') c = -1;
ch = getchar();
}
while(isdigit(ch)){
v = v * 10 + ch - 48;
ch = getchar();
}
return v * c;
}
inline void pushup(int u){a[u].sum = a[a[u].lc].sum + a[a[u].rc].sum;}
inline void ins(int &u,int l,int r,int x){
if(!u) u = ++t;
// printf("%d %d %d %d %d
",u,l,r,x,a[u].sum);
if(l == r) {
a[u].sum++;
return ;
}
int mid = (l + r) >> 1;
if(x <= mid) ins(a[u].lc,l,mid,x);
else ins(a[u].rc,mid + 1,r,x);
pushup(u);
}
inline void merge(int &u1,int u2,int l,int r){
if(!u1) {u1 = u2;return;}
if(!u2) return;
if(l == r){
a[u1].sum += a[u2].sum;
return ;
}
int mid = (l + r) >> 1;
merge(a[u1].lc,a[u2].lc,l,mid);
merge(a[u1].rc,a[u2].rc,mid + 1,r);
pushup(u1);
}
inline int query(int u,int l,int r,int x){
if(!u) return 0;
// printf("%d %d %d %d %d
",u,l,r,a[u].sum,x);
if(l > x) return a[u].sum;
if(r <= x) return 0;
int mid = (l + r) >> 1;
if(mid <= x) return query(a[u].rc,mid + 1,r,x);
else return query(a[u].lc,l,mid,x) + query(a[u].rc,mid + 1,r,x);
}
inline void dfs(int x,int f){
fa[x] = f;
for(int i = 0;i < (int)G[x].size();++i){
int y = G[x][i];
if(y == f) continue;
dfs(y,x);
merge(rt[x],rt[y],1,b[0]);
}
// printf("xxxxxx:%d
",x);
ans[x] = query(rt[x],1,b[0],v[x]);
ins(rt[x],1,b[0],v[x]);
}
int main(){
n = read();
for(int i = 1;i <= n;++i) b[i] = v[i] = read();
sort(b + 1,b + n + 1);
b[0] = unique(b + 1,b + n + 1) - b - 1;
for(int i = 1;i <= n;++i)
v[i] = lower_bound(b + 1,b + b[0] + 1,v[i]) - b;
//for(int i = 1;i <= n;++i) printf("))%d
",v[i]);
for(int i = 2;i <= n;++i){
int x = read();
G[i].push_back(x);
G[x].push_back(i);
}
dfs(1,0);
for(int i = 1;i <= n;++i) printf("%d
",ans[i]);
return 0;
}
luoguP4556 [Vani有约会]雨天的尾
还是一道较简单的题目,思路和上一道题目一样,转化成树上差分,然后在线段树上动动手脚
#include<cstdio>
#include<cstring>
#include<cctype>
#include<vector>
#include<iostream>
#include<algorithm>
#define mk make_pair
using namespace std;
const int N = 1e5 + 3;
vector < pair<int,int> > G[N];
struct node{
int to;
int nxt;
}e[N << 1];
int head[N],fa[N][20],deep[N];
int rt[N];
int ans[N];
struct tree{
int sum;
int pos;
int lc,rc;
}a[N * 25];
int n,m,t,tot,maxx;
inline void add(int x,int y){
e[++tot].to = y;
e[tot].nxt = head[x];
head[x] = tot;
}
inline int read(){
int v = 0,c = 1;char ch = getchar();
while(!isdigit(ch)){
if(ch == '-') c = -1;
ch = getchar();
}
while(isdigit(ch)){
v = v * 10 + ch - 48;
ch = getchar();
}
return v * c;
}
inline void dfs1(int x,int f,int dep){
deep[x] = dep;
fa[x][0] = f;
for(int i = head[x];i;i = e[i].nxt){
int y = e[i].to;
if(y == f) continue;
dfs1(y,x,dep + 1);
}
}
inline int lca(int x,int y){
if(deep[x] < deep[y]) swap(x,y);
for(int i = 19;i >= 0;--i)
if(deep[fa[x][i]] >= deep[y]) x = fa[x][i];
if(x == y) return x;
for(int i = 19;i >= 0;--i)
if(fa[x][i] != fa[y][i]) x = fa[x][i],y = fa[y][i];
return fa[x][0];
}
inline void pushup(int u){
if(a[a[u].lc].sum >= a[a[u].rc].sum) a[u].sum = a[a[u].lc].sum,a[u].pos = a[a[u].lc].pos;
else a[u].sum = a[a[u].rc].sum,a[u].pos = a[a[u].rc].pos;
}
inline void ins(int &u,int l,int r,int x,int v){
if(!u) u = ++t;
if(l == r){
a[u].sum += v;
if(a[u].sum > 0) a[u].pos = x;
else a[u].pos = 0;
// printf("%d %d %d %d %d
",u,l,r,a[u].sum,a[u].pos);
return ;
}
int mid = (l + r) >> 1;
if(x <= mid) ins(a[u].lc,l,mid,x,v);
else ins(a[u].rc,mid + 1,r,x,v);
pushup(u);
//printf("%d %d %d %d %d
",u,l,r,a[u].sum,a[u].pos);
}
inline void merge(int &u1,int u2,int l,int r){
//printf("%d %d %d %d %d %d
",u1,u2,l,r,a[u1].sum,a[u2].sum);
if(!u1) {u1 = u2;return;}
if(!u2) return ;
if(l == r){
//printf("%d %d %d %d
",u1,u2,a[u1].sum,a[u2].sum);
a[u1].sum += a[u2].sum;
if(a[u1].sum > 0) a[u1].pos = l;
else a[u1].pos = 0;
return ;
}
int mid = (l + r) >> 1;
merge(a[u1].lc,a[u2].lc,l,mid);
merge(a[u1].rc,a[u2].rc,mid + 1,r);
pushup(u1);
}
inline void dfs2(int x,int f){
for(int i = head[x];i;i = e[i].nxt){
int y = e[i].to;
if(y == f) continue;
dfs2(y,x);
// printf("x:%d y:%d
",x,y);
merge(rt[x],rt[y],1,maxx);
}
// printf("now:%d
",x);
for(int i = 0;i < (int)G[x].size();++i){
int z = G[x][i].first,s = G[x][i].second;
// printf("%d %d
",z,s);
ins(rt[x],1,maxx,z,s);
}
ans[x] = a[rt[x]].pos;
}
int main(){
// freopen("data.in","r",stdin);
//freopen("data2.out","w",stdout);
n = read(),m = read();
for(int i = 1;i < n;++i){
int x = read(),y = read();
add(x,y);
add(y,x);
}
dfs1(1,0,1);
// for(int i = 1;i <= n;++i) printf("%d
",deep[i]);
for(int j = 1;j < 20;++j)
for(int i = 1;i <= n;++i)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
while(m--){
int x = read(),y = read(),z = read();
maxx = max(maxx,z);
G[x].push_back(mk(z,1));
G[y].push_back(mk(z,1));
int LCA = lca(x,y);
G[LCA].push_back(mk(z,-1));
// printf("%d %d
",LCA,fa[LCA][0]);
G[fa[LCA][0]].push_back(mk(z,-1));
}
dfs2(1,0);
//printf("%d %d
",a[rt[3]].sum,a[rt[3]].pos);
for(int i = 1;i <= n;++i) printf("%d
",ans[i]);
return 0;
}
写的挺简略的,数据结构的题目主要还是要以刷题为主吧