模版 动态 dp
终于来写这个东西了。。
LG 模版:给定 n 个点的数,点有点权, $ m $ 次修改点权,求修改完后这个树的最大独立集大小。
我们先来考虑朴素的最大独立集的 dp
现在我们就拥有了一个 $ O(nm) $ 的做法,但是明显它不优秀。
既然支持修改,我们可以考虑树剖(或者LCT),现在树被我们剖成了重链和轻链。此时发现我们其实可以先对轻子树算出 dp 值并且累计到当前根,再去算重链的影响,这样有什么好处呢?好处在于重链我们可以单独算,这样的话 dp 转移就是连续的。同时,当你修改一个点,它所影响的也仅仅是它到根的很多重链,不会影响到这路径上虚儿子的贡献。
但是如果按照原来的 dp 转移,修改点权仍然很困难。为此定义了一种新的矩阵乘法,让两个矩阵相乘时的乘法改成加法,加法改成最大值,也就是:
不难发现这样定义矩阵乘法后矩阵乘法仍然满足结合律。
然后我们考虑对于一个点 $ u $ ,假设它已经算完了轻儿子的贡献,计作 $ g[u][0/1] $ ,考虑从重儿子转移,那么这个点的转移矩阵就是
为什么呢?考虑 $ dp[u][0] $ 可以从什么转移过来,是由 $ dp[u][0] $ 和 $ dp[u][1] $ 中选择较大的和 $ g[u][0] $ 加起来得到的,而 $ dp[u][1] $ 由 $ dp[u][0] + g[u][1] $ 得到,故最后一个位置填 $ -infin $。
写到这里感觉 LCT 会比树剖好写的多,所以后面默认使用 LCT 了。
然后我们考虑对每个链维护它的转移矩阵的乘积。而 $ g $ 的维护就是对一个链的顶端的父亲更新 $ g $ 即可,体现在 LCT 中就是一个点维护它 splay 里面的矩阵的积。注意这里的矩阵乘法的顺序,对一个链做的时候应该从上乘到下,所以 pushup 的时候应该先左后中最后右。
如果我们需要得到一个点的 dp 值,必须把它 Access 并且 旋转到根,类似 Qtree V 的做法,不然它内部的值是假的(是子树或者splay子树内的乘积).
修改权值,比较简单的方法是 LCT 直接 旋转到根 然后直接修改 val 和矩阵就可以了。
感觉比 Red Blue Tree 好写,不用拿 BST 维护虚儿子信息。。(然后估计码一天)
代码还是很好看的(虽然调起来很烦)
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
using namespace std;
#define MAXN 100006
#define max( a , b ) ( (a) > (b) ? (a) : (b) )
#define inf 0x3f3f3f3f
int n , m;
int w[MAXN];
int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn;
void ade( int u , int v ) {
to[++ ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
}
struct mtrx {
#define N 2
int A[2][2];
inline void in( int a , int b ) { A[0][0] = A[0][1] = a , A[1][0] = b , A[1][1] = -inf; }
inline int re( ) { return max( A[0][0] , A[1][0] ); }
inline mtrx operator * ( const mtrx& a ) const {
mtrx ret;
for( int i = 0 ; i < 2 ; ++ i ) for( int j = 0 ; j < 2 ; ++ j )
ret.A[i][j] = max( A[i][0] + a.A[0][j] , A[i][1] + a.A[1][j] );
return ret;
}
};
int ch[MAXN][2] , fa[MAXN] , dp[MAXN][2];
mtrx G[MAXN];
inline bool notroot( int u ) {
return ch[fa[u]][0] == u || ch[fa[u]][1] == u;
}
inline void pu( int u ) {
G[u].in( u[dp][0] , u[dp][1] );
if( ch[u][0] ) G[u] = G[ch[u][0]] * G[u];
if( ch[u][1] ) G[u] = G[u] * G[ch[u][1]];
}
inline void rotate( int u ) {
int f = fa[u] , g = fa[f] , w = ch[f][1]==u , k = ch[u][w^1];
if(notroot( f )) ch[g][ch[g][1]==f] = u;ch[f][w] = k , ch[u][w^1] = f;
fa[f] = u , fa[k] = f , fa[u] = g;
pu( f ) , pu( u );
}
//void rotate( int x ) {
// int f = fa[x] , g = fa[f] , w = ch[fa[x]][1] == x;
// int wf = ch[g][1]==f , k = ch[x][w^1];
// if( notroot(f) ) ch[g][wf] = x; ch[f][w] = k , ch[x][w^1] = f;
// fa[f] = x , fa[k] = f , fa[x] = g;
// pu( f ) , pu( x );
//}
void splay( int x ) {
int f , g;
while( notroot( x ) ) {
f = fa[x] , g = fa[f];
if( notroot( f ) )
rotate( (ch[f][0]==x)^(ch[g][0]==f) ? x : f );
rotate( x );
}
}
void access( int x ) {
for( int p = 0 ; x ; ( p = x , x = fa[x] ) ) {
splay(x);
if( ch[x][1] ) // Heavy -> Light
x[dp][0] += G[ch[x][1]].re() , x[dp][1] += G[ch[x][1]].A[0][0];
if( p ) // Light -> Heavy
x[dp][0] -= G[p].re() , x[dp][1] -= G[p].A[0][0];
ch[x][1] = p;
pu(x);
}
}
void pre( int u , int f ) {
dp[u][1] = w[u];
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == f ) continue;
pre( v , u );
fa[v] = u;
u[dp][0] += max( v[dp][0] , v[dp][1] );
u[dp][1] += v[dp][0];
}
G[u].in( u[dp][0] , u[dp][1] );
}
int main() {
// freopen("in.in","r",stdin);
cin >> n >> m;
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&w[i]);
for( int i = 1 , u , v ; i < n ; ++ i )
scanf("%d%d",&u,&v) , ade( u , v ) , ade( v , u );
pre( 1 , 1 );
// for( int i = 1 ; i <= n ; ++ i ) printf("%d
",G[i].re());
int u , v;
while( m-- ) {
scanf("%d%d",&u,&v);
access( u );
splay( u );
dp[u][1] += v - w[u] , w[u] = v;
pu( u );
printf("%d
",G[u].re());
// for( int i = 1 ; i <= n ; ++ i ) printf("%d ",fa[i]);
// puts("");
}
}
一个例题 NOIP 2018 保卫王国
其实就是板子题,每次询问,强制选本质上就是权值 inf 强制不选择就是 -inf
中间挂了几次。。这题 longlong 得注意。。(矩阵返回值没开LongLong然后wa了两发。。)
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
using namespace std;
#define MAXN 400006
#define max( a , b ) ( (a) > (b) ? (a) : (b) )
#define inf (1ll<<60)
int n , m;
int w[MAXN];
int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn;
void ade( int u , int v ) {
to[++ ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
}
struct mtrx {
#define N 2
long long A[2][2];
inline void in( long long a , long long b ) { A[0][0] = A[0][1] = a , A[1][0] = b , A[1][1] = -inf; }
inline long long re( ) { return max( A[0][0] , A[1][0] ); }
inline mtrx operator * ( const mtrx& a ) const {
mtrx ret;
for( int i = 0 ; i < 2 ; ++ i ) for( int j = 0 ; j < 2 ; ++ j )
ret.A[i][j] = max( A[i][0] + a.A[0][j] , A[i][1] + a.A[1][j] );
return ret;
}
};
int ch[MAXN][2] , fa[MAXN]; long long dp[MAXN][2];
mtrx G[MAXN];
inline bool notroot( int u ) {
return ch[fa[u]][0] == u || ch[fa[u]][1] == u;
}
inline void pu( int u ) {
G[u].in( u[dp][0] , u[dp][1] );
if( ch[u][0] ) G[u] = G[ch[u][0]] * G[u];
if( ch[u][1] ) G[u] = G[u] * G[ch[u][1]];
}
inline void rotate( int u ) {
int f = fa[u] , g = fa[f] , w = ch[f][1]==u , k = ch[u][w^1];
if(notroot( f )) ch[g][ch[g][1]==f] = u;ch[f][w] = k , ch[u][w^1] = f;
fa[f] = u , fa[k] = f , fa[u] = g;
pu( f ) , pu( u );
}
void splay( int x ) {
int f , g;
while( notroot( x ) ) {
f = fa[x] , g = fa[f];
if( notroot( f ) )
rotate( (ch[f][0]==x)^(ch[g][0]==f) ? x : f );
rotate( x );
}
}
void access( int x ) {
for( int p = 0 ; x ; ( p = x , x = fa[x] ) ) {
splay(x);
if( ch[x][1] ) // Heavy -> Light
x[dp][0] += G[ch[x][1]].re() , x[dp][1] += G[ch[x][1]].A[0][0];
if( p ) // Light -> Heavy
x[dp][0] -= G[p].re() , x[dp][1] -= G[p].A[0][0];
ch[x][1] = p;
pu(x);
}
}
void pre( int u , int f ) {
dp[u][1] = w[u];
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == f ) continue;
pre( v , u );
fa[v] = u;
u[dp][0] += max( v[dp][0] , v[dp][1] );
u[dp][1] += v[dp][0];
}
G[u].in( u[dp][0] , u[dp][1] );
}
long long S;
void mdfy( int u , long long x ) {
access( u ) , splay( u );
dp[u][1] += x;
pu( u );
}
int main() {
freopen("defense.in","r",stdin);
freopen("defense.out","w",stdout);
cin >> n >> m;
scanf("%*s");
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&w[i]) , S += w[i];
for( int i = 1 , u , v ; i < n ; ++ i )
scanf("%d%d",&u,&v) , ade( u , v ) , ade( v , u );
pre( 1 , 1 );
// for( int i = 1 ; i <= n ; ++ i ) printf("%d
",G[i].re());
int u , v , x1 , x2;
while( m-- ) {
scanf("%d%d%d%d",&u,&x1,&v,&x2);
mdfy( u , x1 ? -inf : inf ) , mdfy( v , x2 ? -inf : inf );
S += ( ( x1 ^ 1 ) + ( x2 ^ 1 ) ) * inf;
printf("%lld
",S - G[v].re() > 1e10 ? -1 : S - G[v].re());
S -= ( ( x1 ^ 1 ) + ( x2 ^ 1 ) ) * inf;
mdfy( u , x1 ? inf : -inf ) , mdfy( v , x2 ? inf : -inf );
}
}