前置芝士:链式前向星,线段树,dfs序
这里写的是重链剖分。
参考博客指盗图来源: x正义小学生x
树链剖分可以把一棵树“投影“到一个序列上,然后用线段树维护一些东西。
通过重儿子的性质来保证时间复杂度。
我们首先使用两次dfs进行预处理,将树投影到序列上。
对于一个有儿子的节点,我们定义它最大的儿子为重儿子。
图中,3,6,10,5,8就是重儿子。
我们称像1-3-6-10,2-5,4-8这样的为一条重链。
显然会形成很多条重链,每个点属于且只属于一条重链。
我们定义一个数组 (top_x) 表示 (x) 所在的重链的最浅节点。
在第一次dfs中,我们求出每个点的深度dep,父亲节点fa,子树大小sz,重儿子。
在第二次dfs中,我们求出每个点的dfs序(时间戳就是在新序列里的位置),并且保存新序列,建线段树。注意要保存每个树上的点 对应在序列里的位置 。称为 (id_x)
预处理之后,就要对付询问。
询问和修改子树很显然,是询问和修改序列 ([id[x],id[x]+sz[x]-1])
链怎么办呢?树链剖分,意思是将链剖开成多个(一条重链 或者 一条重链的一部分)。
设链的两头为 (x) 和 (y) ,
- 当 (top_x eq top_y) 选其中链头深度较大的重链,询问和修改它,不妨设链头深度大的是 (x) , 这条重链剖出来以后,(x=fa_{top_x}) , 继续循环,直到 (top_x=top_y) 。
- (top_x=top_y) 询问和修改他们所在的重链。
务必注意update的时候两点的位置!
void updrange(int x,int y,int z){
z=(z%p+p)%p;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,id[top[x]],id[x],z); x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,id[x],id[y],z);
return;
}
怎么保证时间复杂度是
- 如果你学过dsu on tree 的话就会知道,每次从u到一个轻儿子,点个数都会减半,
也就是说,根到一个点最多 (log n) 条轻边。
- 从根节点到一个点,最多有 (log n+1) 条重链。 因为最多 (log n) 条轻边来分割这些重链。
线段树有一只log,剖链会剖成的条数也有一只log,一共两只log。
时间复杂度 (O(nlog^2n))
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
typedef long long LL;
int n,m,r,p;
int e,to[N<<1],nxt[N<<1],hd[N];
int tim,sz[N],dep[N],fa[N],top[N],son[N],id[N];
LL stval[N],fnval[N];
struct pos{
int l,r;
LL sum,lazy;
}t[N<<2];
void add(int a,int b){
to[++e]=b; nxt[e]=hd[a]; hd[a]=e;
}
void pushup(int rt){
t[rt].sum=(t[rt<<1].sum+t[rt<<1|1].sum)%p;
}
void pushdown(int rt){
if(t[rt].lazy){
t[rt<<1].lazy=(t[rt<<1].lazy+t[rt].lazy)%p;
t[rt<<1].sum=(t[rt<<1].sum+1ll*(t[rt<<1].r-t[rt<<1].l+1)*t[rt].lazy%p)%p;
t[rt<<1|1].lazy=(t[rt<<1|1].lazy+t[rt].lazy)%p;
t[rt<<1|1].sum=(t[rt<<1|1].sum+1ll*(t[rt<<1|1].r-t[rt<<1|1].l+1)*t[rt].lazy%p)%p;
t[rt].lazy=0;
}
}
void build(int rt,int l,int r){
t[rt].l=l; t[rt].r=r;
if(l==r){
t[rt].sum=fnval[l];
t[rt].lazy=0;
return;
}
int mid=l+r>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int L,int R,LL val){
if(L<=l&&r<=R){
t[rt].lazy=(t[rt].lazy+val)%p;
t[rt].sum=(t[rt].sum+1ll*val*(t[rt].r-t[rt].l+1)%p)%p;
return;
}
pushdown(rt);
int mid=l+r>>1;
if(L<=mid) update(rt<<1,l,mid,L,R,val);
if(R>mid) update(rt<<1|1,mid+1,r,L,R,val);
pushup(rt);
}
LL query(int rt,int l,int r,int L,int R){
LL ret=0;
if(L<=l&&r<=R) return t[rt].sum;
pushdown(rt);
int mid=l+r>>1;
if(L<=mid) ret=(ret+query(rt<<1,l,mid,L,R))%p;
if(R>mid) ret=(ret+query(rt<<1|1,mid+1,r,L,R))%p;
pushup(rt);
return ret;
}
void dfs1(int u,int fat){
sz[u]=1; dep[u]=dep[fat]+1; fa[u]=fat;
for(int i=hd[u];i;i=nxt[i]){
int v=to[i]; if(v==fat) continue;
dfs1(v,u);sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
return;
}//dep,fa,子树大小(含它自己),重儿子编号son
void dfs2(int u,int topf){
id[u]=++tim; fnval[tim]=stval[u]; top[u]=topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i=hd[u];i;i=nxt[i]){
int v=to[i]; if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
return;
}//新编号,赋值到新编号上,所在链的顶端,处理每条链
void updrange(int x,int y,int z){
z=(z%p+p)%p;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,id[top[x]],id[x],z); x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,id[x],id[y],z);
return;
}
LL qrange(int x,int y){
LL ret=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ret=(ret+query(1,1,n,id[top[x]],id[x]))%p; x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ret=(ret+query(1,1,n,id[x],id[y]))%p;
return (ret+p)%p;
}
void updson(int x,int z){
update(1,1,n,id[x],id[x]+sz[x]-1,(z%p+p)%p);
return;
}
LL qson(int x){
return (query(1,1,n,id[x],id[x]+sz[x]-1)+p)%p;
}
int main(){
scanf("%d%d%d%d",&n,&m,&r,&p);
for(int i=1;i<=n;i++)
scanf("%lld",&stval[i]),stval[i]=(stval[i]%p+p)%p;
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v); add(v,u);
}
dfs1(r,0); dfs2(r,r);
build(1,1,n);
for(int i=1,tp,x,y;i<=m;i++){
LL z;
scanf("%d",&tp);
if(tp==1){
scanf("%d%d%lld",&x,&y,&z); updrange(x,y,z);
} else if(tp==2){
scanf("%d%d",&x,&y); printf("%lld
",qrange(x,y));
} else if(tp==3){
scanf("%d%lld",&x,&z); updson(x,z);
} else{
scanf("%d",&x); printf("%lld
",qson(x));
}
}
return 0;
}