背景:很早就想学习树链剖分,趁着最近有点自由安排的时间去学习一下,发现有个很重要的前置知识——线段树。(其实不一定是线段树,但是线段树应该是最常见的),和同学吐槽说树剖的剖和分都很死板,主要还是看线段树的维护功底。但是也要知道剖分完的结果,不然就算线段树玩得飞起,也维护不了。看了网上很多博客,都是说一个geth
,一个mark
完成树链剖分,然后映射到线段树上,进行维护,其实这只是一个大体思想,还是建议自己手动模拟一下去加深理解。
前置知识:
1、重儿子:hs[u]=v
,表示v
是u
的重儿子。意思是v
是u
的儿子中子树规模(包括自己)最大的。
轻儿子:除了重儿子的其他儿子。
2、重链:由重儿子组成的链。
轻链:除了重链的其他链。
3、顶端结点:重链的开头。
说是重轻分解,其实实质是把重链揪出来(即从轻链处砍断连接关系)连在一起拼凑成区间(同一条重链上结点编号映射到数据结构上连续),用数据结构维护,也就是说把树变成由重链组成的,只剩下重链,不考虑轻链,对于映射,同一条重链中浅结点编号小。
个人觉得这句话很通俗易懂了。~
关于树链剖分:
首先:要是一棵树。。然后有几种剖分:1、随便剖分,爱怎么编号怎么编号。2、启发式剖分(也就是常见的重轻分解)。显然!2比较科学,随便的东西肯定不稳定,就算是不随便的也不一定稳定。。(基数排序最后倒数组时的for downto...
就比for to.
稳定,而正for
显然不是随便的东西,我不会证明也一直没想明白,还望看官指点)。我们可以很简单的运用两次dfs
完成对一棵树的剖分(第一次:geth
,第二次:mark
)。第一次主要是得到深度、父亲、规模、重儿子;第二次则是将同一条重链上的结点编号在一起,对应到线段树上。(rank[]、sa[]
这两个数组和在后缀数组中一样,因为不懂后缀数组,所以用这两个提醒自己还是个弱者),以及记录重链顶端结点。
其次:其实可以说,树链剖分的题,暴力求解就是树上倍增(跑LCA
,然后沿途更新),那么如何优化?显然LCA
肯定要跑,有没有办法跑得更快?答案是肯定的,树链的剖分就是让LCA
跑得更快。显然对于V(u,v)
要么在一条重链上,要么不在一条重链上。如果在一条重链上,深度浅的就是LCA
,如果不在呢?不妨定义u
为深度更深的结点,那么倍增的思想告诉我们应该把u
跳到和v
一样浅,然后一起跳。然而轻重分解直接把u
跳到其所在重链顶端(期间维护和求解该链上的答案),判断u,v
在不在一条重链上(tp[u]==tp[v]?
),然后不断进行这个过程直到u,v
在同一重链后运用数据结构维护求解。那么我们又知道了同一条重链新编号连续,那么进行区间维护就很方便了。
最后:看各位的线段树功底了,反正笔者的线段树是很差的。。
(PS:第一次看到给20s的题,有点刺激)
Code:
#pragma comment(linkerr, "/STACK: 1024000000,1024000000")
#include <bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define eb emplace_back
#define em emplace
#define pii pair<int,int>
#define de(x) cout << #x << " = " << x << endl
#define clr(a,b) memset(a,b,sizeof(a))
#define INF (0x3f3f3f3f)
#define LINF ((long long)(0x3f3f3f3f3f3f3f3f))
#define F first
#define S second
#define lson rt<<1,l,m
#define rson rt<<1|1,m+1,r
using namespace std;
const int N = 1e5 + 15;
int n, m;
int d[N], fa[N], sz[N], hs[N];
int nw, sa[N], rk[N], tp[N];
struct Edge
{
int v, nxt;
};
Edge e[N<<1];
int h[N], ect;
void init()
{
ect = nw = 0;
clr(h,-1);
}
void _add( int u, int v )
{
e[ect].v = v;
e[ect].nxt = h[u];
h[u] = ect ++;
}
void geth( int u, int f, int de )
{
fa[u] = f;
sz[u] = 1;
hs[u] = 0;
d[u] = de;
for ( int i = h[u]; i+1; i = e[i].nxt )
{
int v = e[i].v;
if ( v == f ) continue;
geth( v, u, de+1 );
sz[u] += sz[v];
if ( sz[v] > sz[hs[u]] ) hs[u] = v;
}
}
void mark( int u, int tu )
{
tp[u] = tu;
sa[++nw] = u; rk[u] = nw;
if ( !hs[u] ) return ;
mark( hs[u], tu );
for ( int i = h[u]; i+1; i = e[i].nxt )
{
int v = e[i].v;
if ( v != fa[u] && v != hs[u] ) mark(v,v);
}
}
struct T
{
int sm, lazy, lc, rc;
};
T t[N<<2];
int A[N];
int nwlc, nwrc;
void pushup( int rt )
{
t[rt].sm = t[rt<<1].sm + t[rt<<1|1].sm;
t[rt].lc = t[rt<<1].lc;
t[rt].rc = t[rt<<1|1].rc;
if ( t[rt<<1].rc == t[rt<<1|1].lc )
t[rt].sm --;
}
void pushdown( int rt, int l, int r )
{
if ( t[rt].lazy )
{
t[rt].lazy = 0;
t[rt<<1].lazy = t[rt<<1|1].lazy = 1;
t[rt<<1].sm = t[rt<<1|1].sm = 1;
t[rt<<1].lc = t[rt<<1].rc = t[rt].lc;
t[rt<<1|1].lc = t[rt<<1|1].rc = t[rt].rc;
}
}
void build( int rt, int l, int r )
{
t[rt].lazy = 0;
if ( l == r )
{
t[rt].sm = 1;
t[rt].lc = t[rt].rc = A[sa[l]];
return ;
}
int m = (l+r) >> 1;
build(lson); build(rson); pushup(rt);
}
void update( int L, int R, int c, int rt, int l, int r )
{
if ( L <= l && r <= R )
{
t[rt].lc = t[rt].rc = c;
t[rt].sm = t[rt].lazy = 1;
return ;
}
int m = (l+r) >> 1;
pushdown(rt,l,r);
if ( L <= m ) update( L, R, c, lson );
if ( R > m ) update( L, R, c, rson );
pushup(rt);
}
int query( int L, int R, int rt, int l, int r )
{
if ( L == l ) nwlc = t[rt].lc;
if ( R == r ) nwrc = t[rt].rc;
if ( L <= l && r <= R )
return t[rt].sm;
int m = (l+r) >> 1, res = 0, lft = 0;
pushdown(rt,l,r);
if ( L <= m )
{
lft = 1;
res += query( L, R, lson );
}
if ( R > m )
{
res += query( L, R, rson );
if ( lft && t[rt<<1].rc == t[rt<<1|1].lc ) res --;
}
pushup(rt);
return res;
}
int getsum( int u, int v )
{
int lstulc, lstvlc;
lstulc = lstvlc = -1;
int res = 0;
int x = tp[u], y = tp[v];
while ( x != y )
{
if ( d[x] < d[y] ) swap(x,y), swap(u,v), swap(lstulc,lstvlc);
res += query( rk[x], rk[u], 1,1,n );
if ( nwrc == lstulc ) res --;
lstulc = nwlc;
u = fa[x]; x = tp[u];
}
if ( d[u] > d[v] ) swap(u,v), swap( lstulc, lstvlc );
res += query( rk[u], rk[v], 1,1,n );
if ( nwlc == lstulc ) res --;
if ( nwrc == lstvlc ) res --;
return res;
}
void change( int u, int v, int c )
{
int x = tp[u], y = tp[v];
while ( x != y )
{
if ( d[x] < d[y] ) swap(x,y), swap(u,v);
update( rk[x], rk[u], c, 1,1,n );
u = fa[x]; x = tp[u];
}
if ( d[u] > d[v] ) swap( u, v );
update( rk[u], rk[v], c, 1,1,n );
}
int main()
{
init();
scanf("%d%d", &n, &m);
for ( int i = 1; i <= n; i ++ )
scanf("%d", &A[i]);
for ( int i = 1, u, v; i < n; i ++ )
{
scanf("%d%d", &u, &v);
_add(u,v); _add(v,u);
}
geth(1,0,1);
mark(1,1);
build(1,1,n);
while ( m -- )
{
char s[2];
int u, v, c;
scanf("%s %d%d", s, &u, &v);
if ( s[0] == 'C' )
{
scanf("%d", &c);
change( u, v, c );
}
else
printf("%d
", getsum(u,v));
}
return 0;
}