zoukankan      html  css  js  c++  java
  • 动态dp 学习笔记

    前言

    被迫营业*2,不过去仔细学一下也挺好的。为了营业去学了好多新东西(((

    由于本人水平有限,如有不严谨的地方还请指出。

    ddp 主要用来处理树上dp问题,有时候出题人比较恶心带上修改,ddp就是用来支持快速修改的。

    使用的前提是转移比较简洁,可以写成矩阵,基本上都是单点修改。

    我感觉这个算法还是看题比较好理解。

    例题一

    这里会介绍三种常见的ddp维护方法。

    P4719 【模板】"动态 DP"&动态树分治

    首先考虑不带修改的情况。

    (f(u, 0 / 1)) 分别表示:强制选择 (u) 这个点,强制不选择 (u) 这个点 时,以 (u) 为根的子树的最大独立集。

    这时候以 (u) 为根的子树的答案就是 (max(f(u,0),f(u,1)))

    有转移

    [f(u,0) = sum_{vin son(u)} max(f(v, 0),f(v, 1))\ f(u,1) = a_u + sum_{vin son(u)} f(v, 0) ]

    接下去引入 ddp 的思想。

    考虑将重儿子和轻儿子分开考虑。

    (g(u, 0 / 1)) 表示,只考虑 (u) 的轻儿子,强制选 (u) 或不选 (u) 这个点时,以 (u) 为根的子树的最大独立集。

    (wson(u))(u) 的重儿子。

    可以得到

    [g(u, 0) = sum_{vin son(u), v ot = wson(u)} max(f(v,0),f(v,1))\ g(u, 1) = a_u + sum_{vin son(u), v ot = wson(u)} f(v,0)\ f(u, 0) = g(u, 0) + max(f(son_u, 0), f(son_u, 1))\ f(u, 1) = g(u, 1) + f(son_u, 0) ]

    注意转移方程中,把 (a_u) 的贡献加到了 (g(u,1)) 中,不然第四条多个 (a_u) 方程不够简洁,不方便写成矩阵。

    (f) 的转移写成矩阵:

    [egin{bmatrix} g(u,0)&g(u,0)\ g(u,1)&-infty end{bmatrix} * egin{bmatrix} f(son_u,0)\ f(son_u,1) end{bmatrix} = egin{bmatrix} f(u,0)\ f(u,1) end{bmatrix} ]

    注意上面的 (*) 是广义矩阵乘法: (a_{i,j}=max{b_{i,k}+c_{k,j}}),这个运算符也是有结合律的。

    一开始的时候我们先一趟dp求出 (f,g)

    这时候对于点 (u) 为根的子树查询答案会非常方便:设 (u) 所在重链底端是 (End(u)),我们把 (u)(End(u)) 这段区间的矩阵全部按顺序乘起来就好了。

    可以脑补一下这个过程:重链底端是个叶子,然后不断加入重链周围的轻子树以及重儿子,拼凑成了整颗子树。

    考虑如何带上修改。

    假设修改了点 (u)

    直接影响到的是 (f(u,1),g(u,1))

    接着可以想象,往祖先走的时候同重链的 (f) 都能被矩阵直接更新,那么影响到的就是所有轻边的转移。

    于是考虑计算更新之后对于轻边父亲的 (g) 的贡献。

    发现 (g) 的转移和 (f) 有关,并且我们可以知道这条重链的 (f) 以及更新前重链的 (f),那么把之前的贡献减掉,把现在的贡献加上,就更新完毕了。

    快速查询两点间矩阵乘积可以通过 树链剖分+线段树维护区间矩阵乘积 来维护,而且修改也可以通过跳轻边很方便地维护。

    总共会跳到 (O(log n)) 条轻边,还有每跳一次线段树修改的 (O(log n)),总复杂度是 (O(nlog^2 n))

    实现的时候建议这种小型矩阵手动展开,常数上可以减小好多。

    远古代码。但是好像不是特别丑

    Code
    #include<bits/stdc++.h>
    using namespace std;
    typedef long long LL;
    typedef double db;
    #define pb(x) push_back(x)
    #define mkp(x,y) make_pair(x,y)
    inline int read() {
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
    	while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
    	return x*f;
    }
    const int N=100005;
    const int M=N<<2;
    const int inf=1e8;
    int n,m,a[N];
    int siz[N],dfn[N],tmr,son[N],fa[N],top[N],rev[N],ed[N],f[N][2];
    struct edge{
    	int nxt,to;
    }e[N<<1];
    int head[N],num_edge;
    void addedge(int fr,int to){
    	++num_edge;
    	e[num_edge].nxt=head[fr];
    	e[num_edge].to=to;
    	head[fr]=num_edge;
    }
    struct Matrix{
    	int a[2][2];
    	Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
    	int*operator[](const int&k){return a[k];}
    	Matrix operator * (const Matrix&b){
    		Matrix res;
    //		for(int i=0;i<2;++i)
    //			for(int j=0;j<2;++j)
    //				for(int k=0;k<2;++k)
    //					res.a[i][j]=max(res.a[i][j],a[i][k]+b.a[k][j]);
    		res[0][0]=max(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
    		res[0][1]=max(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
    		res[1][0]=max(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
    		res[1][1]=max(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
    		return res;
    	}
    }mat[N],val[M];
    void dfs1(int u,int ft){
    	siz[u]=1,f[u][1]=a[u];
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;if(v==ft)continue;
    		fa[v]=u,dfs1(v,u),siz[u]+=siz[v];
    		if(siz[v]>siz[son[u]])son[u]=v;
    		f[u][0]+=max(f[v][0],f[v][1]);
    		f[u][1]+=f[v][0];
    	}
    }
    void dfs2(int u,int tp){
    	top[u]=tp,dfn[u]=++tmr,rev[tmr]=u;
    	if(son[u])dfs2(son[u],tp),ed[u]=ed[son[u]];
    	else ed[u]=u;
    	int g[2];g[0]=0,g[1]=a[u];
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;
    		if(v==son[u]||v==fa[u])continue;
    		dfs2(v,v);
    		g[0]+=max(f[v][0],f[v][1]);
    		g[1]+=f[v][0];
    	}
    	mat[u][0][0]=g[0],mat[u][0][1]=g[0];
    	mat[u][1][0]=g[1],mat[u][1][1]=-inf;
    }
    #define lc (p<<1)
    #define rc (p<<1|1)
    void pushup(int p){val[p]=val[lc]*val[rc];}
    void build(int l,int r,int p){
    	if(l==r)return val[p]=mat[rev[l]],void();
    	int mid=(l+r)>>1;
    	build(l,mid,lc),build(mid+1,r,rc);
    	pushup(p);
    }
    Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
    	if(ql<=l&&r<=qr)return val[p];
    	int mid=(l+r)>>1;
    	if(qr<=mid)return query(ql,qr,l,mid,lc);
    	if(mid<ql)return query(ql,qr,mid+1,r,rc);
    	return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
    }
    void change(int pos,int l=1,int r=n,int p=1){
    	if(l==r)return val[p]=mat[rev[l]],void();
    	int mid=(l+r)>>1;
    	if(pos<=mid)change(pos,l,mid,lc);
    	else change(pos,mid+1,r,rc);
    	pushup(p);
    }
    void update(int x,int v){
    	mat[x][1][0]+=v-a[x],a[x]=v;
    	while(x){
    		Matrix lst=query(dfn[top[x]],dfn[ed[x]]);
    		change(dfn[x]);
    		Matrix now=query(dfn[top[x]],dfn[ed[x]]);
    		x=fa[top[x]];
    		mat[x][0][0]+=max(now[0][0],now[1][0])-max(lst[0][0],lst[1][0]);
    		mat[x][0][1]=mat[x][0][0];
    		mat[x][1][0]+=now[0][0]-lst[0][0];
    	}
    }
    signed main(){
    	n=read(),m=read();
    	for(int i=1;i<=n;++i)a[i]=read();
    	for(int i=1;i<n;++i){
    		int x=read(),y=read();
    		addedge(x,y),addedge(y,x);
    	}
    	dfs1(1,0),dfs2(1,1),build(1,n,1);
    	while(m--){
    		int x=read(),v=read();
    		update(x,v);
    		Matrix t=query(dfn[1],dfn[ed[1]]);
    		printf("%d
    ",max(t[0][0],t[1][0]));
    	}
    	return 0;
    }
    

    树剖+线段树的复杂度是两只 log,这使得人们思考有没有更快的方法。

    P4751 【模板】"动态DP"&动态树分治(加强版)

    可以发现上面那种方法的在做的其实就是维护链上矩阵积,维护链上信息使我们想到了 (O(nlog n)) 的 LCT。

    考虑只维护实边信息,虚儿子信息在 access 的时候更新上去。

    更新一个节点信息的时候可以先 access 再 splay,这时候修改它对于任何节点都是没有影响的,可以直接修改。

    查询一个节点的信息会有点特殊,需要执行的操作是:(access(fa_x),splay(x))。因为把 (fa_x) 以上的节点移到别的 splay 里面去,(splay(x))(x) 下面挂的节点才是 (x) 的子树内的节点。

    这里附上一份 LCT 实现,顺便去学了一下。

    成功把复杂度降掉一只 (log),LCT 的常数非常大就是了。

    说句闲话,这题貌似正常常数的 LCT 都能过去(

    Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp make_pair
    #define pb push_back
    #define sz(v) (int)(v).size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    	return f?x:-x;
    }
    const int N = 1000005;
    const int inf = 0x3f3f3f3f;
    int n, m, a[N], lastans;
    vector<int> e[N];
    struct Matrix {
    	int a[2][2];
    	Matrix(){ memset(a, -0x3f, sizeof a); }
    	inline int* operator [](const int &k) { return a[k]; }
    	inline Matrix operator * (const Matrix &t) const {
    		Matrix res;
    		res.a[0][0] = max(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
    		res.a[0][1] = max(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
    		res.a[1][0] = max(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
    		res.a[1][1] = max(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
    		return res;
    	}
    };
    int fa[N], ch[N][2], dp[N][2];
    Matrix val[N], sum[N];
    inline bool nroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }
    inline void pushup(int x) {
    	sum[x] = val[x];
    	if(ch[x][0]) sum[x] = sum[ch[x][0]] * sum[x];
    	if(ch[x][1]) sum[x] = sum[x] * sum[ch[x][1]];
    }
    inline void rotate(int x) {
    	int y = fa[x], z = fa[y], k = ch[y][1] == x, w = ch[x][!k];
    	if(nroot(y)) ch[z][ch[z][1] == y] = x;
    	ch[x][!k] = y, ch[y][k] = w;
    	fa[w] = y, fa[y] = x, fa[x] = z;
    	pushup(y);
    }
    inline void splay(int x) {
    	while(nroot(x)) {
    		int y = fa[x], z = fa[y];
    		if(nroot(y)) rotate((ch[z][1] == y) ^ (ch[y][1] == x) ? x : y);
    		rotate(x);
    	}
    	pushup(x);
    }
    inline void access(int x) {
    	for(int y = 0; x; x = fa[y = x]) {
    		splay(x);
    		if(y) {
    			val[x][0][0] -= max(sum[y][0][0], sum[y][1][0]);
    			val[x][0][1] = val[x][0][0];
    			val[x][1][0] -= sum[y][0][0];
    		}
    		if(ch[x][1]) {
    			int t = ch[x][1];
    			val[x][0][0] += max(sum[t][0][0], sum[t][1][0]);
    			val[x][0][1] = val[x][0][0];
    			val[x][1][0] += sum[t][0][0];
    		}
    		ch[x][1] = y, pushup(x);
    	}
    }
    void dfs(int u, int ft) {
    	dp[u][1] = a[u], fa[u] = ft;
    	for(int v : e[u]) if(v != ft) {
    		dfs(v, u);
    		dp[u][0] += max(dp[v][0], dp[v][1]);
    		dp[u][1] += dp[v][0];
    	}
    	val[u][0][0] = val[u][0][1] = dp[u][0];
    	val[u][1][0] = dp[u][1], val[u][1][1] = -inf;
    	sum[u] = val[u];
    }
    signed main() {
    	n = read(), m = read();
    	rep(i, 1, n) a[i] = read();
    	rep(i, 2, n) {
    		int x = read(), y = read();
    		e[x].pb(y), e[y].pb(x);
    	}
    	dfs(1, 0);
    	while(m--) {
    		int x = read() ^ lastans, y = read();
    		access(x), splay(x);
    		val[x][1][0] += y - a[x], a[x] = y, pushup(x);
    		splay(1);
    		printf("%d
    ", lastans = max(sum[1][0][0], sum[1][1][0]));
    	}
    	return 0;
    }
    

    考虑到这棵树并不会动,用 LCT 维护有点浪费,于是考虑搞一种新的方法来划分树。

    有人从上古论文里翻出来了一个科技叫做“全局平衡二叉树”。

    注意到 LCT 就是每一条链建平衡树,考虑用类似的思想。

    建立的方法就是:先树剖,对于每一条重链每次取带权重心建立二叉树,连实边。对于轻子树建立的二叉树的根往当前二叉树上的节点拉虚边。

    容易发现在同一颗二叉树内往父亲跳的时候,每跳一次子树大小都会至少翻倍;切换一次二叉树意味着切换一次重边,只会有 (O(log n)) 次。仔细想想,这两个 (log) 并不是乘起来的,是加起来的,因为子树大小翻倍至多 (log n) 次,跳轻链也至多 (log n) 次。所以树高是 (O(log n)) 级别的,粗略分析上限是 (2log n),注意有个常数。

    仍然采用矩阵维护,维护方法类似 LCT,只维护实边信息,虚边一路跳到根更新 (g)

    如果我们要查询某个子树的答案怎么办?

    首先找到这个节点在全局平衡二叉树上所在的二叉树。

    考虑到全局平衡二叉树上每一个由实边连接的二叉树都是一条重链,并且先序遍历就是这条重链,根据之前重链剖分时的思路,我们要求的是一个点到重链底端的矩阵积。

    就相当于我们在二叉排序树上查询序列后缀积,这个随便写写就好了。

    Code
    #include <bits/stdc++.h>
    using namespace std;
    typedef double db;
    typedef long long LL;
    #define fi first
    #define se second
    #define pb push_back
    #define mkp make_pair
    #define rep(i, x, y) for(int i = x, i##end = y; i <= i##end; ++i)
    #define per(i, x, y) for(int i = x, i##end = y; i >= i##end; --i)
    template<class T> inline bool ckmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
    template<class T> inline bool ckmin(T &x, T y) { return x > y ? x = y, 1 : 0; }
    inline int read() {
        int x = 0, f = 1; char ch = getchar();
        while(!isdigit(ch)) { if(ch == '-') f = 0; ch = getchar(); }
        while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
        return f ? x : -x;
    }
    
    const int N = 1000005;
    const int inf = 0x3f3f3f3f;
    int n, m, a[N], f[N][2], g[N][2];
    vector<int> e[N];
    namespace Tree {
    
    int siz[N], fa[N], son[N];
    void dfs(int u, int ft) {
    	siz[u] = 1, fa[u] = ft;
    	f[u][1] = a[u];
    	for(int v : e[u]) if(v != ft) {
    		dfs(v, u), siz[u] += siz[v];
    		if(siz[v] > siz[son[u]]) son[u] = v;
    		f[u][0] += max(f[v][0], f[v][1]);
    		f[u][1] += f[v][0];
    	}
    	g[u][0] = f[u][0] - max(f[son[u]][0], f[son[u]][1]);
    	g[u][1] = f[u][1] - f[son[u]][0];
    }
    
    }
    
    struct Matrix {
    	int a[2][2];
    	Matrix(){ memset(a, -0x3f, sizeof a); }
    	inline int* operator [](const int &k) { return a[k]; }
    	inline Matrix operator * (const Matrix &t) const {
    		Matrix res;
    		res.a[0][0] = max(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
    		res.a[0][1] = max(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
    		res.a[1][0] = max(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
    		res.a[1][1] = max(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
    		return res;
    	}
    	void print() {
    		cerr << a[0][0] << ' ' << a[0][1] << '
    ' << a[1][0] << ' ' << a[1][1] << '
    ';
    	}
    };
    
    namespace bst {
    int fa[N], ch[N][2], stk[N], top, tsz[N], rt;
    bool isrt[N];
    Matrix val[N], sum[N];
    inline void pushup(int u) {
    	sum[u] = val[u];
    	if(ch[u][0]) sum[u] = sum[ch[u][0]] * sum[u];
    	if(ch[u][1]) sum[u] = sum[u] * sum[ch[u][1]];
    }
    inline int build2(int l, int r) {
    	if(l > r) return 0;
    	int ALL = 0, now = 0;
    	rep(i, l, r) ALL += tsz[i];
    	rep(i, l, r) {
    		now += tsz[i];
    		if(now << 1 >= ALL) {
    			int u = stk[i];
    			fa[ch[u][0] = build2(l, i - 1)] = u;
    			fa[ch[u][1] = build2(i + 1, r)] = u;
    			return pushup(u), u;
    		}
    	}
    	assert(0);
    }
    int build(int tp) {
    	for(int i = tp; i; i = Tree::son[i]) {
    		for(int v : e[i]) if(v != Tree::fa[i] && v != Tree::son[i])
    			fa[build(v)] = i;
    		val[i][0][0] = val[i][0][1] = g[i][0];
    		val[i][1][0] = g[i][1], val[i][1][1] = -inf;
    	}
    	top = 0;
    	for(int i = tp; i; i = Tree::son[i])
    		stk[++top] = i, tsz[top] = Tree::siz[i] - Tree::siz[Tree::son[i]];
    	int tmp = build2(1, top);
    	isrt[tmp] = 1;
    	return tmp;
    }
    void modify(int x, int y) {
    	val[x][1][0] += y - a[x], a[x] = y;
    	for(int i = x; i; i = fa[i]) {
    		Matrix pre = sum[i];
    		pushup(i);
    		Matrix suf = sum[i];
    		if(isrt[i] && fa[i]) {
    			int f = fa[i];
    			val[f][0][0] += max(suf[0][0], suf[1][0]) - max(pre[0][0], pre[1][0]);
    			val[f][0][1] = val[f][0][0];
    			val[f][1][0] += suf[0][0] - pre[0][0];
    		}
    	}
    }
    
    }
    
    signed main() {
    	n = read(), m = read();
    	rep(i, 1, n) a[i] = read();
    	rep(i, 2, n) {
    		int x = read(), y = read();
    		e[x].pb(y), e[y].pb(x);
    	}
    	Tree::dfs(1, 0);
    	bst::rt = bst::build(1);
    	int lastans = 0;
    	while(m--) {
    		int x = read() ^ lastans, y = read();
    		bst::modify(x, y);
    		printf("%d
    ", lastans = max(bst::sum[bst::rt][0][0], bst::sum[bst::rt][1][0]));
    	}
    	return 0;
    }
    

    到现在为止三种维护ddp常用的方法都已经介绍完毕,用哪种请读者自己选择。

    从我个人角度不建议写树剖,因为复杂度多一只 (log),可能被卡。而且其实树剖+线段树是三种写法里码量最大的。

    给出我的实现下,这两道题在洛谷评测的最大测试点用时:

    P4719(普通版) P4751(加强版)
    树剖+线段树 179ms >3.7s(TLE)
    LCT 77ms 2.95s
    全局平衡二叉树 55ms 1.42s

    毕竟每个人的实现都会有偏差,但是总体是可以看出每种方法的常数差别,在加强版体现的尤为突出。

    小总结

    上面解题的步骤其实是比较清晰的,也是一般做 ddp 题的步骤:

    • 写出不修改情况下的状态转移方程

    • 分离轻重儿子的贡献

    • 把转移写成矩阵

    • 大力码码码

    一般都会在最开始暴力跑一趟树形dp求出不包括重儿子的答案塞进矩阵。

    修改一般采用的方法是,消去原贡献,加入新贡献。

    查询只要理解ddp本质都没问题。

    例题二

    P6021 洪水

    小清新题,和模板没太大区别。

    题意:给一棵树,每个点有点权 (a_i),每次询问:在某个子树内以点权为代价删除(堵上)一些点使得根与子树内所有叶子不连通的最小代价;带单点修改。

    不带修情况

    (f_i) 表示把以 (i) 为根的子树完全堵上的答案。

    [f_u=min(a_u,sum_{v in son(u)} f_v) ]

    转移简洁并且单点修改使我们想到使用 ddp 来维护。

    分离轻重儿子(这里的 (son(u)) 表示 (u) 的重儿子):

    [f_u=min(a_u,f_{wson(u)}+sum_{v ot= wson(u)}f_v) ]

    [f_u=min(a_u,f_{wson(u)}+g_u) ]

    其中 (g_u) 是轻儿子的贡献

    把转移方程写成矩阵

    重定义广义矩阵乘法:

    [res_{i,j}=min(a_{i,k}+b_{k,j}) ]

    构造矩阵:

    如果矩阵是一维的

    [egin{bmatrix} x&y end{bmatrix} * egin{bmatrix} f_{wson(u)} end{bmatrix} =egin{bmatrix} f_u end{bmatrix} ]

    那么 (x=g_u,y=a_u-f_{son(u)})

    发现左矩阵做了一个重儿子的东西,不能做ddp。

    考虑再加一维:

    [egin{bmatrix} x&y\ z&w end{bmatrix} * egin{bmatrix} f_{wson(u)}\ p end{bmatrix} = egin{bmatrix} f_u\ q end{bmatrix} ]

    因为 (a_u=a_u+0) ,考虑直接把 (p,q) 设成 (0) , 那么根据转移方程,(x=g_u,y=a_u) ,这样 (f_u) 已经被正确表示了。

    但是底下的 (q) 在矩乘之后不一定是 (0) ,考虑通过 (z,w) 来维护 (q)

    直接展开 (q)(q=min(z+f_{son(u)},w))(p=0) 就不写了)

    (f_{son(u)}) 是非负的,所以 (w=0,zge 0) 即可

    矩阵构造完毕!

    [egin{bmatrix} g_u&a_u\ 0&0 end{bmatrix} * egin{bmatrix} f_{wson(u)}\ 0 end{bmatrix} = egin{bmatrix} f_u\ 0 end{bmatrix} ]

    封死一颗子树的代价就是 (f_u)

    这里要提一个小细节,就是叶子节点没有轻儿子的时候 (g) 怎么办。

    我想到了两种处理方法:

    一种处理方法是把 (g_u) 设为 (a_u),因为叶子节点在ddp的时候要满足 (f_u=g_u)

    还有一种方法就是把 (g_u) 设成 (+infty),直接禁止从“封死所有子树”这种方法的转移。并且矩阵左上角不和自己的 (a)(min),只维护轻子树的 dp值 和,调用 dp 值的时候再和自己的 (a)(min)

    第一种在树剖的时候比较好搞;如果用 LCT 维护我暂时没想到什么好的维护方法所以用了第二种。

    修改

    (x)(v)

    直接影响的就是这个节点的 (a_u)

    但是叶子节点还得同时更改 (g_u),千万别忘。

    至于轻边父亲的修改,减去原贡献加上新贡献就好了。

    LCT 同理,不过在修改一个节点矩阵的时候要先 accesssplay,这时候修改它对于任何节点都是没有影响的。

    查询

    树剖直接用线段树把这个点到重链底端的矩阵全部乘起来就好了。

    LCT 比较特殊。要把这个节点在 原树 上的父亲 access 一下,再 splay 这个节点,这样子这个节点的信息就是这颗子树的信息。

    我access在lct上的父亲调了一晚上(((

    矩乘只是 (2 imes2 imes2) ,建议手动暴力展开,可以快非常多。

    因为树剖代码是在之前的代码上改的,怕有些地方与描述不同,因此重写了一份 LCT 的代码

    树剖版代码
    #include<bits/stdc++.h>
    using namespace std;
    typedef long long LL;
    typedef double db;
    #define pb(x) push_back(x)
    #define mkp(x,y) make_pair(x,y)
    //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
    //char buf[1<<21],*p1=buf,*p2=buf;
    inline int read() {
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
    	while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
    	return x*f;
    }
    int rdc(){
    	char ch=getchar();
    	while(ch!='Q'&&ch!='C')ch=getchar();
    	return ch=='Q';
    }
    const int N=200005;
    const LL inf=1e14;
    const int T=N<<2;
    int n,dp[N];
    LL a[N];
    int head[N],num_edge;
    int dfn[N],rev[N],tmr,fa[N],siz[N],son[N],top[N],ed[N];
    struct edge{
    	int nxt,to;
    }e[N<<1];
    void addedge(int fr,int to){
    	++num_edge;
    	e[num_edge].nxt=head[fr];
    	e[num_edge].to=to;
    	head[fr]=num_edge;
    }
    struct Matrix {
    	LL a[2][2];
    	Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=inf;}
    	LL*operator[](const int&k){return a[k];}
    	Matrix operator * (const Matrix&b){
    		Matrix res;
    		res.a[0][0] = min(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
    		res.a[0][1] = min(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
    		res.a[1][0] = min(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
    		res.a[1][1] = min(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
    		return res;
    	}
    	void print(){
    		printf("%lld %lld
    %lld %lld
    
    ",a[0][0],a[0][1],a[1][0],a[1][1]);
    	}
    }mat[N],val[T];
    void dfs1(int u,int ft){
    	if(!e[head[u]].nxt)return dp[u]=a[u],siz[u]=1,void();
    	LL sum=0;siz[u]=1;
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;if(v==ft)continue;
    		fa[v]=u,dfs1(v,u),sum+=dp[v],siz[u]+=siz[v];
    		if(siz[v]>siz[son[u]])son[u]=v;
    	}
    	dp[u]=min(sum,a[u]);
    }
    void dfs2(int u,int tp){
    	top[u]=tp,dfn[u]=++tmr,rev[tmr]=u,ed[tp]=tmr;
    	mat[u][0][0]=0,mat[u][0][1]=a[u],
    	mat[u][1][0]=0,mat[u][1][1]=0;
    	if(!son[u])return mat[u][0][0]=a[u],void();
    	dfs2(son[u],tp);
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;
    		if(v==fa[u]||v==son[u])continue;
    		dfs2(v,v),mat[u][0][0]+=dp[v];
    	}
    }
    #define lc (p<<1)
    #define rc (p<<1|1)
    void pushup(int p){val[p]=val[lc]*val[rc];}
    void build(int l,int r,int p=1){
    	if(l==r)return val[p]=mat[rev[l]],void();
    	int mid=(l+r)>>1;
    	build(l,mid,lc),build(mid+1,r,rc);
    	pushup(p);
    }
    Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
    	if(ql<=l&&r<=qr)return val[p];
    	int mid=(l+r)>>1;
    	if(qr<=mid)return query(ql,qr,l,mid,lc);
    	if(mid<ql)return query(ql,qr,mid+1,r,rc);
    	return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
    }
    void change(int pos,int l=1,int r=n,int p=1){
    	if(l==r)return val[p]=mat[rev[pos]],void();
    	int mid=(l+r)>>1;
    	if(pos<=mid)change(pos,l,mid,lc);
    	else change(pos,mid+1,r,rc);
    	pushup(p);
    }
    void update(int x,int v){
    	mat[x][0][1]+=v,a[x]+=v;
    	if(siz[x]==1)mat[x][0][0]+=v;
    	while(x){
    		Matrix lst=query(dfn[top[x]],ed[top[x]]);
    		change(dfn[x]);
    		Matrix now=query(dfn[top[x]],ed[top[x]]);
    		x=fa[top[x]];
    		mat[x][0][0]+=now[0][0]-lst[0][0];
    	}
    }
    signed main(){
    	n=read();
    	for(int i=1;i<=n;++i)a[i]=read();
    	for(int i=1;i<n;++i){
    		int x=read(),y=read();
    		addedge(x,y),addedge(y,x);
    	}
    	dfs1(1,0),dfs2(1,1),build(1,n);
    	
    	for(int m=read();m;--m){
    		int opt=rdc(),x=read();
    		if(opt){
    			Matrix t=query(dfn[x],ed[top[x]]);
    			printf("%lld
    ",t[0][0]);
    		}
    		else update(x,read());
    	}
    	return 0;
    }
    
    LCT 版代码
    #include <bits/stdc++.h>
    using namespace std;
    typedef double db;
    typedef long long LL;
    #define fi first
    #define se second
    #define pb push_back
    #define mkp make_pair
    #define sz(v) (int)(v).size()
    #define rep(i, x, y) for(int i = x, i##end = y; i <= i##end; ++i)
    #define per(i, x, y) for(int i = x, i##end = y; i >= i##end; --i)
    template<class T> inline bool ckmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
    template<class T> inline bool ckmin(T &x, T y) { return x > y ? x = y, 1 : 0; }
    inline int read() {
        int x = 0, f = 1; char ch = getchar();
        while(!isdigit(ch)) { if(ch == '-') f = 0; ch = getchar(); }
        while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
        return f ? x : -x;
    }
    inline int rdch() {
    	char ch = getchar();
    	while(ch != 'Q' && ch != 'C') ch = getchar();
    	return ch == 'Q';
    }
    const int N = 200005;
    const LL inf = 1e14;
    int n, m, lef[N], tfa[N];
    LL a[N], dp[N];
    vector<int> e[N];
    int fa[N], ch[N][2];
    struct Matrix {
    	LL a[2][2];
    	Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = inf; }
    	inline LL* operator [](const int &k) { return a[k]; }
    	inline Matrix operator * (const Matrix &t) const {
    		Matrix res;
    		res.a[0][0] = min(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
    		res.a[0][1] = min(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
    		res.a[1][0] = min(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
    		res.a[1][1] = min(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
    		return res;
    	}
    } val[N], sum[N];
    inline bool nroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }
    inline void pushup(int x) {
    	sum[x] = val[x];
    	if(ch[x][0]) sum[x] = sum[ch[x][0]] * sum[x];
    	if(ch[x][1]) sum[x] = sum[x] * sum[ch[x][1]];
    }
    inline void rotate(int x) {
    	int y = fa[x], z = fa[y], k = ch[y][1] == x, w = ch[x][!k];
    	if(nroot(y)) ch[z][ch[z][1] == y] = x;
    	ch[x][!k] = y, ch[y][k] = w;
    	fa[w] = y, fa[y] = x, fa[x] = z;
    	pushup(y);
    }
    inline void splay(int x) {
    	while(nroot(x)) {
    		int y = fa[x], z = fa[y];
    		if(nroot(y)) rotate((ch[z][1] == y) ^ (ch[y][1] == x) ? x : y);
    		rotate(x);
    	}
    	pushup(x);
    }
    inline void access(int x) {
    	for(int y = 0; x; x = fa[y = x]) {
    		splay(x);
    		if(y) val[x][0][0] -= min(sum[y][0][0], sum[y][0][1]);
    		if(ch[x][1]) val[x][0][0] += min(sum[ch[x][1]][0][0], sum[ch[x][1]][0][1]);
    		ch[x][1] = y, pushup(x);
    	}
    }
    void dfs(int u, int ft) {
    	lef[u] = 1;
    	for(int v : e[u]) if(v != ft)
    		tfa[v] = fa[v] = u, dfs(v, u), dp[u] += dp[v], lef[u] = 0;
    	if(lef[u]) dp[u] = a[u];
    	val[u][0][0] = lef[u] ? inf : dp[u], val[u][0][1] = a[u];
    	val[u][1][0] = val[u][1][1] = 0;
    	pushup(u);
    	ckmin(dp[u], a[u]);
    }
    signed main() {
    	n = read();
    	rep(i, 1, n) a[i] = read();
    	rep(i, 2, n) {
    		int x = read(), y = read();
    		e[x].pb(y), e[y].pb(x);
    	}
    	dfs(1, 0);
    	for(m = read(); m--; ) {
    		int op = rdch(), x = read();
    		if(op) {
    			if(tfa[x]) access(tfa[x]);
    			splay(x), printf("%lld
    ", min(sum[x][0][0], sum[x][0][1]));
    		} else {
    			int y = read();
    			access(x), splay(x);
    			val[x][0][1] += y;
    			pushup(x);
    		}
    	}
    }
    

    例题三

    P5024 [NOIP2018 提高组] 保卫王国

    分别钦定两个城市必取或者必不取的最小独立集。

    对于一定驻扎,把点权设为 (-infty)。对于一定不驻扎,点权设为 (+infty)

    然后跑最小独立集即可。

    最后输出的时候加或减一下之前偏移的 (infty)

    可以偷懒把点权取反拉最大独立集的板子(

    但是这里修改带四倍常数,用树剖写时间非常紧,uoj 上根本过不去,建议写全局平衡二叉树,我懒得重写了。

    Code
    #include<bits/stdc++.h>
    using namespace std;
    typedef long long LL;
    typedef double db;
    #define pb(x) push_back(x)
    #define mkp(x,y) make_pair(x,y)
    //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
    //char buf[1<<21],*p1=buf,*p2=buf;
    inline int read() {
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
    	while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
    	return x*f;
    }
    const int N=100005;
    const int M=N<<2;
    const LL inf=1e12;
    int n,m;
    LL p[N];
    int siz[N],dfn[N],tmr,son[N],fa[N],top[N],rev[N],ed[N],f[N][2];
    char cynAKIOI[114514];
    struct edge{
    	int nxt,to;
    }e[N<<1];
    int head[N],num_edge;
    void addedge(int fr,int to){
    	++num_edge;
    	e[num_edge].nxt=head[fr];
    	e[num_edge].to=to;
    	head[fr]=num_edge;
    }
    struct Matrix{
    	LL p[2][2];
    	Matrix(){p[0][0]=p[0][1]=p[1][0]=p[1][1]=-inf;}
    	LL*operator[](const int&k){return p[k];}
    	Matrix operator * (const Matrix&b){
    		Matrix res;
    //		for(int i=0;i<2;++i)
    //			for(int j=0;j<2;++j)
    //				for(int k=0;k<2;++k)
    //					res.p[i][j]=max(res.p[i][j],p[i][k]+b.p[k][j]);
    		res[0][0]=max(p[0][0]+b.p[0][0],p[0][1]+b.p[1][0]);
    		res[0][1]=max(p[0][0]+b.p[0][1],p[0][1]+b.p[1][1]);
    		res[1][0]=max(p[1][0]+b.p[0][0],p[1][1]+b.p[1][0]);
    		res[1][1]=max(p[1][0]+b.p[0][1],p[1][1]+b.p[1][1]);
    		return res;
    	}
    }mat[N],val[M];
    void dfs1(int u,int ft){
    	siz[u]=1,f[u][1]=p[u];
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;if(v==ft)continue;
    		fa[v]=u,dfs1(v,u),siz[u]+=siz[v];
    		if(siz[v]>siz[son[u]])son[u]=v;
    		f[u][0]+=max(f[v][0],f[v][1]);
    		f[u][1]+=f[v][0];
    	}
    }
    void dfs2(int u,int tp){
    	top[u]=tp,dfn[u]=++tmr,rev[tmr]=u,ed[tp]=tmr;
    	if(son[u])dfs2(son[u],tp);
    	LL g[2];g[0]=0,g[1]=p[u];
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;
    		if(v==son[u]||v==fa[u])continue;
    		dfs2(v,v);
    		g[0]+=max(f[v][0],f[v][1]);
    		g[1]+=f[v][0];
    	}
    	mat[u][0][0]=g[0],mat[u][0][1]=g[0];
    	mat[u][1][0]=g[1],mat[u][1][1]=-inf;
    }
    #define lc (p<<1)
    #define rc (p<<1|1)
    void pushup(int p){val[p]=val[lc]*val[rc];}
    void build(int l,int r,int p){
    	if(l==r)return val[p]=mat[rev[l]],void();
    	int mid=(l+r)>>1;
    	build(l,mid,lc),build(mid+1,r,rc);
    	pushup(p);
    }
    Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
    	if(ql<=l&&r<=qr)return val[p];
    	int mid=(l+r)>>1;
    	if(qr<=mid)return query(ql,qr,l,mid,lc);
    	if(mid<ql)return query(ql,qr,mid+1,r,rc);
    	return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
    }
    void change(int pos,int l=1,int r=n,int p=1){
    	if(l==r)return val[p]=mat[rev[l]],void();
    	int mid=(l+r)>>1;
    	if(pos<=mid)change(pos,l,mid,lc);
    	else change(pos,mid+1,r,rc);
    	pushup(p);
    }
    void update(int x,LL v){
    	mat[x][1][0]+=v,p[x]+=v;
    	while(x){
    		Matrix lst=query(dfn[top[x]],ed[top[x]]);
    		change(dfn[x]);
    		Matrix now=query(dfn[top[x]],ed[top[x]]);
    		x=fa[top[x]];
    		mat[x][0][0]+=max(now[0][0],now[1][0])-max(lst[0][0],lst[1][0]);
    		mat[x][0][1]=mat[x][0][0];
    		mat[x][1][0]+=now[0][0]-lst[0][0];
    	}
    }
    signed main(){
    	n=read(),m=read(),scanf("%s",cynAKIOI);
    	for(int i=1;i<=n;++i)p[0]+=(p[i]=read());
    	for(int i=1;i<n;++i){
    		int x=read(),y=read();
    		addedge(x,y),addedge(y,x);
    	}
    	dfs1(1,0),dfs2(1,1),build(1,n,1);
    	while(m--){
    		int a=read(),x=read(),b=read(),y=read();
    		LL ad1=x?-inf:inf,out1=x?0:inf;
    		LL ad2=y?-inf:inf,out2=y?0:inf;
    		update(a,ad1),update(b,ad2);
    		Matrix res=query(dfn[1],ed[1]);
    		LL out=p[0]-max(res[0][0],res[1][0])+out1+out2;
    		out<inf?printf("%lld
    ",out):puts("-1");
    		update(a,-ad1),update(b,-ad2);
    	}
    	return 0;
    }
    

    例题四

    P3781 [SDOI2017]切树游戏

    先考虑不带修改的情况如何dp。

    (dp(u,msk)) 表示以 (u) 为根的联通子树 (operatorname{xor}) 起来为 (msk) 的方案数。

    转移是:

    [f(u, msk) = sum_{vin son(u),Xoplus Y = msk}f(u, X)*f(v,Y) ]

    暴力转移是 (O(m^2)),非常明显可以 FWT 优化成 (O(mlog m))

    统计答案的时候,假设询问 (k),那么就是 (sum_{i=1}^{n} dp(i,k))

    我觉得这里还是有必要提一下暴力 dp 的边界以及转移的细节。

    我一开始写的边界处理是:(dp(u,w_u)=1),在把所有孩子合并上来之后再给 (dp(u,0)) 加一,这样它父亲调用它的时候那个 (0) 就相当于不选自己。

    凭直觉就知道在这种鬼地方多个类似 if 的东西可能非常难办。以及 FWT 和 IFWT 的位置可能影响我们维护修改的难度。还有我们统计答案的方式是遍历所有节点而非在单一节点统计答案。这些问题在一开始就得解决。

    以下记 (hat{a}) 表示 (a) FWT 之后的数组。

    首先解决统计答案的问题。

    考虑记 (g(u,msk)=f(u,msk)+sum_{vin_{son(u)}}f(v,msk))

    这样子我们调用 (g(1,msk)) 就能得到整颗树的答案了。

    接下去看怎么把转移写简洁。

    最后单独给 (0) 加一肯定要去掉。那么在转移方程后加一项就行了

    [f(u, msk) = sum_{vin son(u),Xoplus Y = msk}f(u, X)*f(v,Y)+f(u,msk) ]

    FWT 之后有

    [hat{f}(u,msk)=hat{f}(u,msk)hat{f}(v,msk)+hat{f}(u,msk)=hat{f}(u,msk)(hat{f}(v,msk)+1) ]

    于是这个转移可以写的非常简洁:

    [hat{f}(u,msk)=hat{w}(u,msk)prod_{vin son(u)} (hat{f}(v,msk)+1) ]

    (hat{w}(u)) 表示这个点的点权 FWT 之后的序列。

    考虑 (g) 怎么搞。如果 IFWT 回去再统计又会使转移非常麻烦。

    注意到点值是可以直接加的,那不妨维护 (hat{g}),最后 IFWT 回去输出答案。

    [hat{g}(u,msk)=hat{f}(u,msk)+sum_{vin son(u)}hat{g}(v,msk) ]

    现在转移方程非常简洁了,只不过复杂度是 (O(qnmlog m)),考虑怎么优化。

    注意以下的 (f,g) 全部定义为多项式,乘法定义为按位相乘。

    考虑 ddp。分离轻重儿子:

    [hat{f}(u)=(hat{f}(wson_u)+1)hat{w}(u)prod_{vin son(u),v ot=wson(u)}(hat{f}(v)+1)\ hat{g}(u)=hat{f}(u)+hat{g}(wson_u)+sum_{vin won(u),v ot=wson(u)}hat{g}(v) ]

    [hat{F}(u)=hat{w}(u)prod_{vin son(u),v ot=wson(u)}(hat{f}(v)+1)\ hat{G}(u)=sum_{vin won(u),v ot=wson(u)}hat{g}(v) ]

    那么 dp 就写成了下面的形式

    [hat{f}(u)=(hat{f}(wson_u)+1)hat{F}(u)\=hat{F}(u)+hat{F}(u)hat{f}(wson_u)\ hat{g}(u)=hat{f}(u)+hat{g}(wson_u)+hat{G}(u)\=hat{F}(u)+hat{F}(u)hat{f}(wson_u)+hat{g}(wson_u)+hat{G}(u) ]

    然后就构造矩阵转移

    [egin{bmatrix} hat{F}(u) & 0 & hat{F}(u)\ hat{F}(u)& 1 & hat{F}(u)+hat{G}(u)\ 0 & 0 & 1 end{bmatrix} * egin{bmatrix} hat{f}(wson_u)\ hat{g}(wson_u)\ 1 end{bmatrix} =egin{bmatrix} hat{f}(u)\ hat{g}(u)\ 1 end{bmatrix} ]

    修改的时候只需要跳重链修改 (hat{F},hat{G}) 就好,就是消去原贡献加入现在的贡献。

    但是 (F) 消贡献是除掉一个东西,并且 XOR 的 FWT 是可以出负数的,加上模数非常小,很有可能除一个 (0) 下去(样例就是),看起来非常棘手。

    事实上这个处理非常简单:对于每个节点开桶存乘了几个 (0),除以 (0) 的时候操作桶就行了。

    复杂度是 (O(qm(log n+log m)))

    但是矩阵乘法带 (27) 常数,全局平衡二叉树带 (2) 倍常数,带进去一算是惊人的 1e9,加上大量封装,根本过不去。

    这时候有个小 trick,有些矩阵矩乘之后常数不变,这个矩阵也是这样。

    [egin{bmatrix} a_1 & 0 & c_1\ b_1 & 1 & d_1\ 0 & 0 & 1 end{bmatrix} * egin{bmatrix} a_2 & 0 & c_2\ b_2 & 1 & d_2\ 0 & 0 & 1 end{bmatrix} = egin{bmatrix} a_1a_2 & 0 & a_1c_2+c_1\ b_1a_2+b_2 & 1 & b_1c_2+d_2+d_1\ 0 & 0 & 1 end{bmatrix} ]

    于是只用维护四个值,常数就从 (27) 降到了 (8)

    到此为止思路结束了,码代码就靠自己了(逃

    不过这题别写树剖,多个 (log) 运算量差不多是 1e9,加上洛谷有个毒瘤加了组对着树剖卡的数据,基本不用想过。

    Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    	return f?x:-x;
    }
    inline int rdch() {
    	char ch = getchar();
    	while(ch != 'Q' && ch != 'C') ch = getchar();
    	return ch == 'Q';
    }
    const int N = 30005;
    const int mod = 10007;
    const int iv2 = (mod + 1) >> 1;
    int inv[N];
    int n, m, w[N];
    inline int qpow(int n, int k) {
    	int res = 1;
    	for(; k; k >>= 1, n = n * n % mod)
    		if(k & 1) res = res * n % mod;
    	return res;
    }
    inline int add(int x, int y) { return (x += y) >= mod ? x - mod : x; }
    inline int sub(int x, int y) { return (x -= y) < 0 ? x + mod : x; }
    struct pint {
    	int v, c;
    	pint() { v = 1, c = 1; }
    	pint(int v_) {
    		if(!v_) v = 1, c = 1;
    		else v = v_, c = 0;
    	}
    	inline int val() const { return c ? 0 : v; }
    	friend pint operator * (pint a, const int &b) {
    		if(!b) return ++a.c, a;
    		else return (a.v *= b) %= mod, a;
    	}
    	friend pint operator / (pint a, const int &b) {
    		if(!b) return --a.c, a;
    		else return (a.v *= inv[b]) %= mod, a;
    	}
    };
    inline vector<int> change(const vector<pint> &a) {
    	vector<int> res(m);
    	for(int i = 0; i < m; ++i) res[i] = a[i].val();
    	return res;
    }
    inline vector<pint> operator * (const vector<pint> &a, const vector<int> &b) {
    	vector<pint> res(m);
    	for(int i = 0; i < m; ++i) res[i] = a[i] * b[i];
    	return res;
    }
    inline vector<pint> operator / (const vector<pint> &a, const vector<int> &b) {
    	vector<pint> res(m);
    	for(int i = 0; i < m; ++i) res[i] = a[i] / b[i];
    	return res;
    }
    inline vector<int> operator + (const vector<int> &a, const vector<int> &b) {
    	vector<int> res(m);
    	for(int i = 0; i < m; ++i) res[i] = add(a[i], b[i]);
    	return res;
    }
    inline vector<int> operator - (const vector<int> &a, const vector<int> &b) {
    	vector<int> res(m);
    	for(int i = 0; i < m; ++i) res[i] = sub(a[i], b[i]);
    	return res;
    }
    inline vector<int> operator * (const vector<int> &a, const vector<int> &b) {
    	vector<int> res(m);
    	for(int i = 0; i < m; ++i) res[i] = a[i] * b[i] % mod;
    	return res;
    }
    inline vector<int> addone(vector<int> a) {
    	for(int i = 0; i < m; ++i) a[i] = add(a[i], 1);
    	return a;
    }
    inline vector<int> XOR(vector<int> a) {
    	for(int i = 1; i < m; i <<= 1)
    		for(int j = 0; j < m; j += i << 1)
    			for(int k = 0; k < i; ++k) {
    				int X = a[j + k], Y = a[i + j + k];
    				a[j + k] = add(X, Y), a[i + j + k] = sub(X, Y);
    			}
    	return a;
    }
    inline vector<int> IXOR(vector<int> a) {
    	for(int i = 1; i < m; i <<= 1)
    		for(int j = 0; j < m; j += i << 1)
    			for(int k = 0; k < i; ++k) {
    				int X = a[j + k], Y = a[i + j + k];
    				a[j + k] = (X + Y) * iv2 % mod, a[i + j + k] = (X - Y + mod) * iv2 % mod;
    			}
    	return a;
    }
    int rt, tfa[N], fa[N], cnz[N], siz[N], son[N], stk[N], top, ch[N][2], tsz[N];
    bool isrt[N];
    vector<int> e[N];
    struct Matrix {
    	vector<int> a00, a10, a02, a12;
    	inline Matrix operator * (const Matrix &t) const {
    		Matrix res;
    		res.a00 = a00 * t.a00;
    		res.a10 = a10 * t.a00 + t.a10;
    		res.a02 = a00 * t.a02 + a02;
    		res.a12 = a10 * t.a02 + a12 + t.a12;
    		return res;
    	}
    } val[N], sum[N];
    Matrix mat;
    vector<pint> F[N];
    vector<int> G[N], ans, f[N], g[N], a[N];
    inline void pushup(int u) {
    	sum[u] = val[u];
    	if(ch[u][0]) sum[u] = sum[ch[u][0]] * sum[u];
    	if(ch[u][1]) sum[u] = sum[u] * sum[ch[u][1]];
    }
    inline void get(int u) {
    	val[u].a00 = val[u].a10 = val[u].a02 = val[u].a12 = change(F[u]);
    	val[u].a12 = val[u].a12 + G[u];
    }
    void dfs(int u, int ft) {
    	f[u].resize(m), g[u].resize(m);
    	f[u][w[u]] = 1, f[u] = XOR(f[u]);
    	a[u] = f[u];
    	
    	siz[u] = 1;
    	for(int v : e[u]) if(v != ft) {
    		tfa[v] = u, dfs(v, u), siz[u] += siz[v];
    		if(siz[v] > siz[son[u]]) son[u] = v;
    		f[u] = f[u] * addone(f[v]), g[u] = g[u] + g[v];
    	}
    	g[u] = g[u] + f[u];
    	
    	F[u].resize(m), G[u].resize(m);
    	for(int i = 0; i < m; ++i) F[u][i] = a[u][i];
    	for(int v : e[u]) if(v != ft && v != son[u]) {
    		F[u] = F[u] * addone(f[v]), G[u] = G[u] + g[v];
    	}
    	get(u);
    }
    inline int build2(int l, int r) {
    	if(l > r) return 0;
    	int ALL = 0, now = 0;
    	for(int i = l; i <= r; ++i) ALL += tsz[i];
    	for(int i = l; i <= r; ++i) {
    		now += tsz[i];
    		if(now << 1 >= ALL) {
    			int u = stk[i];
    			fa[ch[u][0] = build2(l, i - 1)] = u;
    			fa[ch[u][1] = build2(i + 1, r)] = u;
    			return pushup(u), u;
    		}
    	}
    	return -1;
    }
    int build(int tp) {
    	for(int i = tp; i; i = son[i])
    		for(int v : e[i]) if(v != son[i] && v != tfa[i])
    			fa[build(v)] = i;
    	top = 0;
    	for(int i = tp; i; i = son[i]) stk[++top] = i, tsz[top] = siz[i] - siz[son[i]];
    	int tmp = build2(1, top);
    	return isrt[tmp] = 1, tmp;
    }
    void modify(int x, int y) {
    	F[x] = F[x] / a[x];
    	memset(a[x].data(), 0, m << 2);
    	a[x][y] = 1, w[x] = y, a[x] = XOR(a[x]);
    	F[x] = F[x] * a[x], get(x);
    	for(; x; x = fa[x]) {
    		if(fa[x] && isrt[x]) {
    			F[fa[x]] = F[fa[x]] / addone(sum[x].a02), G[fa[x]] = G[fa[x]] - sum[x].a12;
    			pushup(x);
    			F[fa[x]] = F[fa[x]] * addone(sum[x].a02), G[fa[x]] = G[fa[x]] + sum[x].a12;
    			get(fa[x]);
    		} else pushup(x);
    	}
    }
    signed main() {
    	inv[1] = 1;
    	for(int i = 2; i < mod; ++i) inv[i] = inv[mod % i] * (mod - mod / i) % mod;
    	n = read(), m = read();
    	rep(i, 1, n) w[i] = read();
    	rep(i, 2, n) {
    		int x = read(), y = read();
    		e[x].pb(y), e[y].pb(x);
    	}
    	dfs(1, 0);
    	rt = build(1);
    	ans = IXOR(sum[rt].a12);
    	for(int q = read(); q; --q) {
    		int op = rdch(), x = read();
    		if(op == 1) {
    			printf("%d
    ", ans[x]);
    		} else {
    			int y = read();
    			modify(x, y);
    			ans = IXOR(sum[rt].a12);
    		}
    	}
    	return 0;
    }
    

    参考资料

    shadowice1984 P3781 的题解

    Tweetuzki P4719 的题解

    Great_Influence 对全局平衡二叉树的讲解

    路漫漫其修远兮,吾将上下而求索
  • 相关阅读:
    配置sonar、jenkins进行持续审查
    maven命令解释
    Maven中-DskipTests和-Dmaven.test.skip=true的区别
    maven之一:maven安装和eclipse集成
    Eclipse安装Maven插件
    IntelliJ IDEA单元测试和代码覆盖率图解
    关于Spring中的<context:annotation-config/>配置
    Java开发之@PostConstruct和@PreConstruct注解
    Java定时任务的三种实现方法
    JAVA之Mybatis基础入门二 -- 新增、更新、删除
  • 原文地址:https://www.cnblogs.com/zzctommy/p/14731712.html
Copyright © 2011-2022 走看看