树链剖分
树链剖分是搞啥的?
一种类似于原子崩坏的超能力。
先把树炸成好多好多条链。然后把这些链组装起来,以进行路径修改,路径查询。
几个概念
- 重儿子,轻儿子
- 轻边
- 重边,重链
请尽情发挥艺术天赋,动手画棵树。
两个看起来没什么用的事实
-
v
是u
的一个轻儿子:size[v]*2 < size[u]
-
在DFS的时候,如果我们先采访重儿子,那么重链的DFS序连续。
一个看起来还是没什么用的事实
x
到根的路径中,重链,轻边的条数都是log(n)
级别的
假的证明:
我们从x
向根节点走。
根据上面那条性质: 树的总结点个数 <= pow(2, x到根,轻边的个数)
所以x
到根路径上,轻边的条树,是log
级别的。
而,两条重链之间,至少有一条轻边。所以重链的条数,也是log
级别的。
似乎还是有那么一点点用的
推广一下上面的结论。
( ⊙ o ⊙ )!任意两点之间的路径,重链,轻边的条数都是log(n)
级别的哎!
只要我们要维护的信息满足区间可合并【区间最值,区间和,区间GCD】
我们就可以把一条链,分成log
条。
又每一条链DFS序连续。因此我们可以用线段树维护每一条链。
容易搞错的地方
dfn[x]
与x
区分开、- 在找两个点的LCA的时候,
top
深度高的先往上爬。
几个栗子
POJ3237
题意:区间更新,区间取反,区间最值
我发明了一种神妙的战法。在push_up
的时候把懒惰标记也传上去了。
调代码的时候,
WA! 懒惰标记上天啦!
我的内心是崩溃的.....
所以,我是来搞笑的吗?
code
#include <iostream>
#include <vector>
#include <cstdio>
using namespace std;
typedef pair<int,int> pii;
const int N = 200000+10;
const int INF = 1e9+7;
vector<pii> g[N];
int edge_id[N];
int size[N];
int dep[N];
int son[N], par[N], dis[N];
int dfn[N], moment, who[N];
int top[N];
int T, n;
int u[N], v[N], w[N];
void init() {
moment = 0;
for(int i=0;i<N;i++)
g[i].clear();
}
void presolve(int u,int p) {
size[u] = 1;
par[u] = p;
int maxSize=0, id=-1;
for(int i=0;i<g[u].size();i++) {
int v=g[u][i].first;
if(v==p) continue;
dis[v] = g[u][i].second;
dep[v] = dep[u]+1;
presolve(v, u);
size[u] += size[v];
if(size[v] > maxSize)
maxSize = size[v], id = v;
}
son[u] = id;
}
void dfs(int u,int p) {
dfn[u] = ++moment;
who[moment] = u;
top[u] = p;
//printf("u = %d, p = %d
", u, p);
if (son[u] != -1) {
//printf("%d -> %d
",u, son[u]);
dfs(son[u], p);
}
for(int i=0;i<g[u].size();i++) {
int v=g[u][i].first;
if(v==son[u] || v==par[u]) continue;
dfs(v, v);
}
}
struct Data {
int l, r;
int mx, mn;
int ne;
Data operator + (const Data & o) const {
Data ans = o; ans.ne = 0;
ans.mx = max(o.mx, mx);
ans.mn = min(o.mn, mn);
return ans;
}
} nod[N<<2];
void build(int l,int r,int rt) {
nod[rt].l = l, nod[rt].r = r; nod[rt].ne = 0;
if (l==r) {
nod[rt].mn = nod[rt].mx = dis[who[l]];
return;
}
int mid=(l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}
void push_down(int rt) {
if (nod[rt].ne) {
//printf("##### rt = %d
", rt);
nod[rt<<1].mx = -nod[rt<<1].mx;
nod[rt<<1].mn = -nod[rt<<1].mn;
swap(nod[rt<<1].mx, nod[rt<<1].mn);
nod[rt<<1|1].mx = -nod[rt<<1|1].mx;
nod[rt<<1|1].mn = -nod[rt<<1|1].mn;
swap(nod[rt<<1|1].mx, nod[rt<<1|1].mn);
nod[rt<<1].ne ^= 1;
nod[rt<<1|1].ne ^= 1;
nod[rt].ne = 0;
}
}
int query(int l,int r,int rt,int L,int R) {
//printf("gg
");
//printf("! %d %d %d %d
", l,r,L,R);
if (L<=l&&r<=R) {
//printf("rt=%d, [%d, %d] mx=%d
", rt,l,r,nod[rt].mx);
return nod[rt].mx;
}
//if (rt!=8)
push_down(rt);
int mid = (l+r)>>1;
int ans = - INF;
if (L <= mid) ans = query(l,mid,rt<<1,L,R);
if (R > mid) ans = max(ans, query(mid+1,r,rt<<1|1,L,R));
//printf("%d %d %d %d %d
", l,r,L,R,ans);
return ans;
}
void update(int l,int r,int rt,int pos,int x) {
if (l == r) {
nod[rt].mx = nod[rt].mn = x;
nod[rt].ne = 0;
return;
}
//if (rt!=8)
push_down(rt);
int mid = (l+r)>>1;
if (pos <= mid)
update(l,mid,rt<<1,pos,x);
else
update(mid+1,r,rt<<1|1,pos,x);
nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}
void neg_update(int l,int r,int rt,int L,int R) {
//printf("%d %d %d %d
", l,r, L,R);
if (L<=l&&r<=R) {
nod[rt].ne ^= 1;
nod[rt].mn *= -1;
nod[rt].mx *= -1;
//printf("rt=%d [%d, %d] %d %d, ne=%d
", rt,nod[rt].l, nod[rt].r, nod[rt].mn, nod[rt].mx, nod[rt].ne);
swap(nod[rt].mn, nod[rt].mx);
return;
}
//if(rt!=8)
push_down(rt);
int mid = (l+r)>>1;
if (L <= mid)
neg_update(l,mid,rt<<1,L,R);
if (R > mid)
neg_update(mid+1,r,rt<<1|1,L,R);
nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}
int max_on_path(int u, int v) {
int ans = - INF;
int f1 = top[u], f2 = top[v];
while (f1 != f2) {
//printf("%d %d
", f1, f2);
if (dep[f1] < dep[f2]) {
swap(f1, f2);
swap(u, v);
}
//printf("[%d, %d]
", dfn[f1], dfn[u]);
ans = max(ans, query(1,n,1,dfn[f1],dfn[u]));
//printf("%d
", ans);
u = par[f1]; f1 = top[u];
//printf("%d %d %d
", u, f1, f2);
}
if (u == v) return ans;
if (dep[u] > dep[v]) swap(u, v);
ans = max(ans, query(1,n,1,dfn[son[u]],dfn[v]));
return ans;
}
void update_on_tree(int u, int v) {
int f1 = top[u], f2 = top[v];
while (f1 != f2) {
if (dep[f1] < dep[f2]) {
swap(f1, f2);
swap(u, v);
}
neg_update(1,n,1,dfn[f1],dfn[u]);
u = par[f1]; f1 = top[u];
}
if (u == v) return;
if (dep[u] > dep[v]) swap(u, v);
neg_update(1,n,1,dfn[son[u]], dfn[v]);
}
int main() {
scanf("%d", &T);
while (T --) {
scanf("%d", &n);
init();
for (int i=1;i<n;i++) {
scanf("%d %d %d", &u[i], &v[i], &w[i]);
g[u[i]].push_back(make_pair(v[i], w[i]));
g[v[i]].push_back(make_pair(u[i], w[i]));
}
presolve(1, 1);
dfs(1, 1);
for (int i=1;i<n;i++) {
if (par[u[i]] == v[i])
edge_id[i] = dfn[u[i]];
else
edge_id[i] = dfn[v[i]];
}
/*
for(int i=1;i<=n;i++) {
printf("i = %d
", i);
printf("dfn = %d, son = %d, par = %d, dis = %d, top = %d
", dfn[i],son[i],par[i],dis[i],top[i]);
}
*/
build(1, n, 1);
char op[5]; int a, b;
while (scanf("%s", op)) {
if (op[0] == 'D') break;
scanf("%d %d", &a, &b);
if (op[0] == 'C') {
update(1,n,1,edge_id[a],b);
}
if (op[0] == 'Q') {
int ans = max_on_path(a, b);
if (a == b) ans = 0;
printf("%d
", ans);
}
if (op[0] == 'N') {
update_on_tree(a, b);
}
}
}
}
Gym 101741 C
题意
给出一棵树,与m条路径,选出最小的点集,使得每条路径都包含点集中的点。
做法
-
在一条链上,就是个按右端点排序的贪心。
-
把路径按照LCA的深度从高到低排序。
-
然后遍历所有路径,如果路径上没有选择的点,那就选择LCA。否则,什么都不做。这个可以通过单点更新,区间查询来实现。
code
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 400000+10;
int n;
int son[N],size[N],par[N],top[N],dep[N];
int dfn[N],who[N],moment;
vector<int> g[N];
void preSovle(int u,int p) {
size[u]=1;
par[u] =p;
dep[u] =dep[p]+1;
int mx=0, bst=-1;
for(int i=0;i<g[u].size();i++) {
int v=g[u][i];
if (v==p) continue;
preSovle(v, u);
size[u] += size[v];
if (size[v] > mx) {
mx = size[v];
bst = v;
}
}
son[u] = bst;
}
void dfs(int u,int p) {
top[u]=p;
dfn[u]=++moment;
who[moment]=u;
if(son[u]!=-1){
dfs(son[u],p);
}
for(int i=0;i<g[u].size();i++) {
int v=g[u][i];
if(v!=par[u]&&v!=son[u])
dfs(v,v);
}
}
struct Query {
int u,v;
int lca;
bool operator < (const Query & o) const {
return dep[lca] > dep[o.lca];
}
} q[N];
int sum[N<<2];
void update(int l,int r,int rt,int pos){
if(l==r) {
sum[rt] ++;
return;
}
int mid=(l+r)>>1;
if (pos<=mid) update(l,mid,rt<<1,pos);
else update(mid+1,r,rt<<1|1,pos);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
int query(int l,int r,int rt,int L,int R){
if(L<=l&&r<=R) {
return sum[rt];
}
int mid=(l+r)>>1;
int ans=0;
if(L<=mid) ans+=query(l,mid,rt<<1,L,R);
if(R >mid) ans+=query(mid+1,r,rt<<1|1,L,R);
return ans;
}
pair<int,int> sum_path(int u,int v) { // (lca, sum)
int f1=top[u], f2=top[v];
int ans=0;
while(f1!=f2) {
//printf("%d %d
", f1, f2);
if (dep[f1] < dep[f2]) {
swap(f1,f2);
swap(u,v);
}
ans+=query(1,n,1,dfn[f1],dfn[u]);
u=par[f1], f1=top[u];
//printf("%d %d
", f1, f2);
}
if(dep[u]>dep[v]) swap(u,v);
ans+=query(1,n,1,dfn[u],dfn[v]);
return make_pair(u, ans);
}
int main() {
scanf("%d",&n);
for(int i=1;i<n;i++) {
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
preSovle(1,1);
dfs(1,1);
int m; scanf("%d", &m);
for(int i=1;i<=m;i++) {
scanf("%d %d", &q[i].u, &q[i].v);
q[i].lca = sum_path(q[i].u,q[i].v).first;
//printf("ok
");
}
sort(q+1,q+1+m);
vector<int> ret;
for(int i=1;i<=m;i++) {
//printf("%d
", q[i].lca);
if(sum_path(q[i].u, q[i].v).second == 0) {
ret.push_back(q[i].lca);
update(1,n,1,dfn[q[i].lca]);
}
}
printf("%d
", ret.size());
for(auto x: ret) {
printf("%d ", x);
}
}
BZOJ2243
题意
一棵树,路径赋值,路径颜色段数查询。
1121
:有三段
1121233
:有五段
题解
线段树维护:
- 区间左端的颜色
- 区间右端的颜色
- 区间内总共有多少段颜色
信息可以合并哎!
code
写残了...
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
const int N = 400000+10;
vector<int> g[N];
int a[N],son[N],par[N],dep[N],sz[N],top[N],dfn[N],who[N],moment;
int n, m;
void init() {
for(int i=0;i<N;i++)
g[i].clear();
moment=0;
}
void pre(int u,int p){
par[u]=p,dep[u]=dep[p]+1,sz[u]=1;
int hson=-1,mx=0;
for(int i=0;i<g[u].size();i++) {
int v=g[u][i]; if(v==p) continue;
pre(v,u);
if(sz[v]>mx) mx=sz[v], hson=v;
sz[u]+=sz[v];
}
son[u]=hson;
}
void dfs(int u,int p){
top[u]=p;
dfn[u]=++moment, who[moment]=u;
if(son[u]!=-1) dfs(son[u],p);
for(int i=0;i<g[u].size();i++) {
int v=g[u][i]; if(v==son[u]||v==par[u]) continue;
dfs(v,v);
}
}
struct Data {
int lef;
int rig;
int tot;
int lazy;
Data operator + (const Data & o) const {
Data ret;
ret.lef = lef; ret.rig = o.rig;
ret.tot = (rig == o.lef) ? (tot+o.tot-1) : (tot+o.tot);
ret.lazy = 0;
return ret;
}
} nod[N<<2];
Data rev(Data nod) {
swap(nod.lef, nod.rig);
return nod;
}
Data set_value(int rt, int x) {
Data ans;
ans.lef = ans.rig = x;
ans.tot = 1; ans.lazy = 0;
return ans;
}
void push_down(int rt) {
if (nod[rt].lazy) {
nod[rt<<1] = set_value(rt<<1, nod[rt].lazy);
nod[rt<<1].lazy = nod[rt].lazy;
nod[rt<<1|1]=set_value(rt<<1|1,nod[rt].lazy);
nod[rt<<1|1].lazy = nod[rt].lazy;
nod[rt].lazy = 0;
}
}
void build(int l,int r,int rt) {
nod[rt].lazy = 0;
if (l==r) {
nod[rt].lef = nod[rt].rig = a[who[l]];
nod[rt].tot = 1;
return;
}
int mid = (l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}
void update(int l,int r,int rt,int L,int R,int x) {
if (L<=l&&r<=R) {
nod[rt]=set_value(rt, x);
nod[rt].lazy = x;
return;
}
push_down(rt);
int mid = (l+r)>>1;
if (L<=mid) update(l,mid,rt<<1,L,R,x);
if (R >mid) update(mid+1,r,rt<<1|1,L,R,x);
nod[rt] = nod[rt<<1] + nod[rt<<1|1];
}
Data query(int l,int r,int rt,int L,int R) {
if (L<=l&&r<=R) {
return nod[rt];
}
push_down(rt);
int mid=(l+r)>>1;
if (L<=mid && R<=mid) return query(l,mid,rt<<1,L,R);
if (L>mid && R>mid) return query(mid+1,r,rt<<1|1,L,R);
return query(l,mid,rt<<1,L,R) + query(mid+1,r,rt<<1|1,L,R);
}
int nex[N];
int get_path(int u, int v) {
int cu = u, cv = v;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u,v), swap(cu, cv);
Data tmp = query(1,n,1,dfn[top[u]], dfn[u]);
nex[u] = top[u];
u = par[top[u]];
}
if (dep[u] > dep[v]) swap(u, v), swap(cu, cv);
nex[v] = u;
Data ans, ans_;
bool find = 0, find_ = 0;
while(cu != u) {
if (find == 1)
ans = query(1,n,1,dfn[nex[cu]], dfn[cu]) + ans;
else
find = 1, ans = query(1,n,1,dfn[nex[cu]],dfn[cu]);
cu = par[top[cu]];
}
ans = rev(ans);
if (find == 1)
ans = ans + query(1,n,1,dfn[u],dfn[v]);
else
find = 1, ans = query(1,n,1,dfn[u],dfn[v]);
while(cv != v) {
if (find_ == 1)
ans_ = query(1,n,1,dfn[nex[cv]], dfn[cv]) + ans_;
else
find_ = 1, ans_ = query(1,n,1,dfn[nex[cv]],dfn[cv]);
cv = par[top[cv]];
}
if (find && find_)
ans = ans + ans_;
else if (find_)
ans = ans_;
else if (find)
ans = ans;
return ans.tot;
}
void modidy(int u,int v,int x) {
while(top[u]!=top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u,v);
update(1,n,1,dfn[top[u]],dfn[u],x);
u=par[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
update(1,n,1,dfn[u],dfn[v],x);
}
int main() {
while (~ scanf("%d %d", &n, &m)) {
init();
for (int i=1;i<=n;i++)
scanf("%d", &a[i]);
for (int i=1;i<n;i++) {
int u, v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
pre(1,1);
dfs(1,1);
build(1,n,1);
for(int i=1;i<=m;i++) {
char op[2]; int u, v, x;
scanf("%s%d%d",op,&u,&v);
if (op[0]=='Q') {
printf("%d
", get_path(u,v));
} else {
scanf("%d", &x); modidy(u,v,x);
}
}
}
}