线段树合并学习笔记
正文
线段树合并简介
把两棵权值线段树的信息合并。
把若干棵共有有m个元素的权值线段树合并,权值范围是1~n,时间复杂度大概是O((mlogn))
实现
int merge(int k1,int k2,int l,int r)//合并k1,k2两棵线段树的信息
{
if ((k1==0)||(k2==0)) return k1+k2;//若一个节点为空则返回另一个
int k=newnode();//新建节点存放合并后的信息
if (l==r)
{
tree[k].sum=tree[k1].sum+tree[k2].sum;、、合并
return k;
}
else
{
int mid=(l+r)/2;
tree[k].ch[0]=merge(tree[k1].ch[0],tree[k2].ch[0],l,mid);//合并左儿子
tree[k].ch[1]=merge(tree[k1].ch[1],tree[k2].ch[1],mid+1,r);//合并右儿子
tree[k].sum=tree[tree[k].ch[0]].sum+tree[tree[k].ch[1]].sum;//更新当前节点信息
recycle(k1);recycle(k2);//回收废节点
return k;
}
}
例题
BZOJ2212/洛谷P3521/LOJ2163
给一棵n(1≤n≤200000)个叶子的二叉树,可以交换每个点的左右子树,要求前序遍历叶子的逆序对最少。
由于交换x节点的左右儿子不会影响x祖先的答案,所以对于每个节点,分别求出交换左右儿子或不交换的最小代价,累加就是答案。
对于每个节点开一棵线段树,每次把左右儿子合并起来就是这个节点的线段树。
为了统计答案,每次合并k1与k2时,k1的右儿子的sum*k2的左儿子的sum就是不交换的答案,因为k1的右儿子(都在l~mid范围)都大于k2的左儿子(都在mid+1~r范围)。而k1的右儿子在k2的左儿子左边。所以相乘即为这个范围的逆序对个数。反之亦然。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
struct qy
{
int ch[2],sum;
};
int n,i,j,k,cnt;
long long ans,sum1,sum2;
int a[400005],fa[400005],root[400005];
int l[400005],next[400005],last[400005],tot;
qy tree[4000005];
int pool[4000005];
int read()
{
int sum=0;char ch=getchar();
while ((ch<'0')||(ch>'9')) ch=getchar();
while ((ch>='0')&&(ch<='9'))
{
sum=sum*10+ch-'0';
ch=getchar();
}
return sum;
}
void insertt(int x,int y)
{
l[++tot]=y;
next[tot]=last[x];
last[x]=tot;
}
void init(int x)
{
int t;
t=read();
cnt++;
insertt(x,cnt);
if (t==0)
init(cnt);
else
a[cnt]=t;
t=read();
cnt++;
insertt(x,cnt);
if (t==0)
init(cnt);
else
a[cnt]=t;
}
int newnode()
{
return pool[pool[0]--];
}
void recycle(int x)
{
tree[x].sum=tree[x].ch[0]=tree[x].ch[1]=0;
pool[++pool[0]]=x;
}
void insert(int k,int l,int r,int x)
{
if (l==r)
{
tree[k].sum++;
}
else
{
int mid=(l+r)/2;
if (x<=mid)
{
tree[k].ch[0]=newnode();
insert(tree[k].ch[0],l,mid,x);
}
else
{
tree[k].ch[1]=newnode();
insert(tree[k].ch[1],mid+1,r,x);
}
tree[k].sum=tree[tree[k].ch[0]].sum+tree[tree[k].ch[1]].sum;
}
}
int merge(int k1,int k2,int l,int r)
{
if ((k1==0)||(k2==0)) return k1+k2;
int k=newnode();
if (l==r)
{
tree[k].sum=tree[k1].sum+tree[k2].sum;
return k;
}
else
{
int mid=(l+r)/2;
sum1+=(long long)tree[tree[k1].ch[1]].sum*tree[tree[k2].ch[0]].sum;
sum2+=(long long)tree[tree[k1].ch[0]].sum*tree[tree[k2].ch[1]].sum;
tree[k].ch[0]=merge(tree[k1].ch[0],tree[k2].ch[0],l,mid);
tree[k].ch[1]=merge(tree[k1].ch[1],tree[k2].ch[1],mid+1,r);
tree[k].sum=tree[tree[k].ch[0]].sum+tree[tree[k].ch[1]].sum;
recycle(k1);recycle(k2);
return k;
}
}
void dg(int x)
{
if (last[x]==0)
{
root[x]=newnode();
insert(root[x],1,n,a[x]);
}
else
{
int son[3];
son[0]=0;
for (int i=last[x];i>=1;i=next[i])
{
son[++son[0]]=l[i];
dg(l[i]);
}
sum1=sum2=0;
root[x]=merge(root[son[1]],root[son[2]],1,n);
ans=ans+min(sum1,sum2);
}
}
int main()
{
freopen("read.in","r",stdin);
for (i=1;i<=4000000;i++)
{
pool[i]=4000000-i+1;
}
pool[0]=4000000;
n=read();
cnt=1;
a[1]=read();
init(1);
dg(1);
printf("%lld",ans);
}