zoukankan      html  css  js  c++  java
  • 【知识总结】动态 DP

    勾起了我悲伤的回忆 —— NOIP2018 316pts ……

    主要思想:将 DP 过程分解为方便单点修改和一个区间合并的操作(通常类似矩阵乘法),然后用数据结构(通常为线段树)维护。

    例:给定一个长为 (n) 的整数序列,相邻两个数最多选一个,有 (m) 次修改序列中的一个数,求每次修改后选出数之和的最大值。
    (n,mleq 10^5)

    如果不会做不带修改的情况,请默默摁 Ctrl + w 然后去学 DP 入门

    如果不带修改,明显设 (f_{i,0/1}) 表示当第 (i) 个点选 (0) / 不选 (1) 时,前 (i) 个点的和的最大值。于是有如下转移方程:

    [f_{i,0}=f_{i-1,1} ]

    [f_{i,1}=max(f_{i-1,0},f_{i-1,1})+a_i ]

    如果加入修改操作呢?只有这两个 DP 方程比较难办,因为修改一个值就要重新计算后面的所有答案。GG

    接下来是「动态 DP 」中最巧妙的部分:考虑用一个矩阵来表示从 (i-1) 点向 (i) 点转移,用某个表示「初始状态」的矩阵依次乘上每个点的转移就是答案。因为矩阵乘法有结合律,所以可以把答案表示成「初始状态」乘上「修改点前面的矩阵乘积」乘上「当前位置修改后的矩阵」乘上「修改点后面的矩阵乘积」。这样只需要用线段树单点修改和查询区间乘积(事实上这道题只需要查全局乘积)即可。

    然而,这道题中转移的运算并不是加和乘,尤其是其中还有一个碍眼的求最大值。但我们可以把矩阵乘法的定义稍加修改,把原来两个整数的「乘法」改为两个整数的加法,「加法」改为对两个整数取最大值。这样我们就构造如下转移矩阵:

    [egin{bmatrix} f_{i-1,0}&f_{i-1,1} end{bmatrix} egin{bmatrix} 0&a_i\ 0&-infty\ end{bmatrix}= egin{bmatrix} f_{i,0}&f_{i,1}\ end{bmatrix}]

    还有一个很多人没考虑过的细节 (可能是大佬们认为这个问题太显然不需要考虑) :这个「初始状态」是什么呢?对于这道题,前一个数如果不选是不影响当前决策的,而如果选了的话就会造成一个当前点不能选的「约束」。而第一个点无论如何都不会受到这种「约束」,所以第一个点的「前一个点」应该被看作「没有选」,即初始状态为 (egin{bmatrix}0&-inftyend{bmatrix})

    我们把这个问题扩展到树上,即每条边的两端点中至少选一个点(洛谷 4719【模板】动态 DP )。考虑树链剖分来转化成序列问题。设 (f_{i,0/1}) 表示 (i) 点选 / 不选时 (i) 点子树中的最大权值和,(g_{i,0/1}) 表示 (i) 点选 / 不选时 (i) 点子树除 (s_i) 的子树以外的部分中的最大权值和,其中 (s_i)(i) 的重儿子。对于一条重链有如下方程:

    [egin{bmatrix} f_{s_i,0}&f_{s_i,1} end{bmatrix} egin{bmatrix} g_{i,0}&g_{i,1}\ g_{i,0}&-infty\ end{bmatrix}= egin{bmatrix} f_{i,0}&f_{i,1}\ end{bmatrix}]

    这样,每个点的答案是「初始状态」乘上它到所在重链末尾的矩阵乘积。

    至于具体实现,可以开始先一遍 DP 算出所有的 (f)(g) 。每次修改时沿着重链向上爬,暴力修改链首父亲的 (g) 值。链首到链首父亲的边是一条轻边,所以这样每次修改一个点时要更新 (g) 值的点的数量约等于当前点到根的路径上的轻边数量(可能有加一减一之类的细节),是 (O(log n)) 。因此总复杂度 (O(mlog^2n))

    和上面类似的分析,初始状态(叶子节点那个不存在的重儿子的 (f) 值)是 (egin{bmatrix}0&-inftyend{bmatrix}) 。用这个东西去乘相当于取原矩阵的第一行,所以不需要「显式」地乘。

    代码:

    很抱歉我代码里的矩阵行列和上文是反的,所有矩阵乘法的顺序也是反的我也不知道怎么回事 QAQ 。

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <cctype>
    using namespace std;
    
    namespace zyt
    {
    	template<typename T>
    	inline bool read(T &x)
    	{
    		char c;
    		bool f = false;
    		x = 0;
    		do
    			c = getchar();
    		while (c != EOF && c != '-' && !isdigit(c));
    		if (c == EOF)
    			return false;
    		if (c == '-')
    			f = true, c = getchar();
    		do
    			x = x * 10 + c - '0', c = getchar();
    		while (isdigit(c));
    		if (f)
    			x = -x;
    		return true;
    	}
    	template<typename T>
    	inline void write(T x)
    	{
    		static char buf[20];
    		char *pos = buf;
    		if (x < 0)
    			putchar('-'), x = -x;
    		do
    			*pos++ = x % 10 + '0';
    		while (x /= 10);
    		while (pos > buf)
    			putchar(*--pos);
    	}
    	const int N = 1e5 + 10, INF = 0x3f3f3f3f;
    	int n, m, head[N], ecnt, w[N], size[N], son[N], fa[N], dfn[N], dfncnt, top[N], f[N][2], g[N][2], end[N], pos[N];
    	struct edge
    	{
    		int to, next;
    	}e[N << 1];
    	void add(const int a, const int b)
    	{
    		e[ecnt] = (edge){b, head[a]}, head[a] = ecnt++;
    	}
    	void dfs(const int u, const int f)
    	{
    		fa[u] = f, size[u] = 1;
    		for (int i = head[u]; ~i; i = e[i].next)
    		{
    			int v = e[i].to;
    			if (v == f)
    				continue;
    			dfs(v, u);
    			size[u] += size[v];
    			if (size[v] > size[son[u]])
    				son[u] = v;
    		}
    	}
    	void dfs2(const int u, const int t)
    	{
    		top[u] = t, dfn[u] = ++dfncnt, pos[dfncnt] = u, end[t] = u;
    		if (son[u])
    			dfs2(son[u], t);
    		for (int i = head[u]; ~i; i = e[i].next)
    		{
    			int v = e[i].to;
    			if (v == fa[u] || v == son[u])
    				continue;
    			dfs2(v, v);
    		}
    	}
    	void dfs3(const int u)
    	{
    		g[u][0] = 0, g[u][1] = w[u];
    		for (int i = head[u]; ~i; i = e[i].next)
    		{
    			int v = e[i].to;
    			if (v == fa[u] || v == son[u])
    				continue;
    			dfs3(v);
    			g[u][0] += max(f[v][0], f[v][1]);
    			g[u][1] += f[v][0];
    		}
    		f[u][0] = g[u][0], f[u][1] = g[u][1];
    		if (son[u])
    		{
    			dfs3(son[u]);
    			f[u][0] += max(f[son[u]][0], f[son[u]][1]);
    			f[u][1] += f[son[u]][0];
    		}
    	}
    	struct Matrix
    	{
    		int data[2][2], n, m;
    		Matrix(const int _n = 0, const int _m = 0)
    			: n(_n), m(_m)
    		{
    			for (int i = 0; i < n; i++)
    				for (int j = 0; j < m; j++)
    					data[i][j] = -INF;
    		}
    		Matrix operator * (const Matrix &b) const
    		{
    			Matrix ans(n, b.m);
    			for (int i = 0; i < n; i++)
    				for (int k = 0; k < m; k++)
    					for (int j = 0; j < b.m; j++)
    						ans.data[i][j] = max(ans.data[i][j], data[i][k] + b.data[k][j]);
    			return ans;
    		}
    	}val[N];
    	namespace Segment_Tree
    	{
    		struct node
    		{
    			Matrix m;
    		}tree[N << 2];
    		void update(const int rot)
    		{
    			tree[rot].m = tree[rot << 1].m * tree[rot << 1 | 1].m;
    		}
    		void build(const int rot, const int lt, const int rt)
    		{
    			tree[rot].m = Matrix(2, 2);
    			if (lt == rt)
    				return void(tree[rot].m = val[pos[lt]]);
    			int mid = (lt + rt) >> 1;
    			build(rot << 1, lt, mid), build(rot << 1 | 1, mid + 1, rt);
    			update(rot);
    		}
    		void change(const int rot, const int lt, const int rt, const int p)
    		{
    			if (lt == rt)
    				return void(tree[rot].m = val[pos[p]]);
    			int mid = (lt + rt) >> 1;
    			if (p <= mid)
    				change(rot << 1, lt, mid, p);
    			else
    				change(rot << 1 | 1, mid + 1, rt, p);
    			update(rot);
    		}
    		Matrix query(const int rot, const int lt, const int rt, const int ls, const int rs)
    		{
    			if (ls <= lt && rt <= rs)
    				return tree[rot].m;
    			int mid = (lt + rt) >> 1;
    			if (rs <= mid)
    				return query(rot << 1, lt, mid, ls, rs);
    			else if (ls > mid)
    				return query(rot << 1 | 1, mid + 1, rt, ls, rs);
    			else
    				return query(rot << 1, lt, mid, ls, rs) * query(rot << 1 | 1, mid + 1, rt, ls, rs);
    		}
    	}
    	int work()
    	{
    		using namespace Segment_Tree;
    		read(n), read(m);
    		memset(head, -1, sizeof(int[n + 1]));
    		for (int i = 1; i <= n; i++)
    			read(w[i]), val[i] = Matrix(2, 2);
    		for (int i = 1; i < n; i++)
    		{
    			int a, b;
    			read(a), read(b);
    			add(a, b), add(b, a);
    		}
    		dfs(1, 0), dfs2(1, 1), dfs3(1);
    		for (int i = 1; i <= n; i++)
    			val[i].data[0][0] = val[i].data[0][1] = g[i][0], val[i].data[1][0] = g[i][1], val[i].data[1][1] = -INF;
    		build(1, 1, n);
    		while (m--)
    		{
    			int u, x;
    			read(u), read(x);
    			val[u].data[1][0] += x - w[u];
    			w[u] = x;
    			Matrix a, b;
    			while (u)
    			{
    				a = query(1, 1, n, dfn[top[u]], dfn[end[top[u]]]);
    				change(1, 1, n, dfn[u]);
    				b = query(1, 1, n, dfn[top[u]], dfn[end[top[u]]]);
    				u = fa[top[u]];
    				val[u].data[0][0] += max(b.data[0][0], b.data[1][0]) - max(a.data[0][0], a.data[1][0]);
    				val[u].data[0][1] = val[u].data[0][0];
    				val[u].data[1][0] += b.data[0][0] - a.data[0][0];
    			}
    			Matrix ans = query(1, 1, n, dfn[1], dfn[end[1]]);
    			write(max(ans.data[0][0], ans.data[1][0])), putchar('
    ');
    		}
    		return 0;
    	}
    }
    int main()
    {
    #ifdef BlueSpirit
    	freopen("4719.in", "r", stdin);
    #endif
    	return zyt::work();
    }
    
  • 相关阅读:
    LeetCode数字之和总结
    排序类总结
    web sockect的练习
    RNA速率scVelo
    创建Numpy数组的不同方式
    numpy的课程学习二
    scrapy的cmdline命令和其文件写入乱码问题
    scrapy选择器
    python数据分析的numpy学习笔记
    Numpy的学习笔记一
  • 原文地址:https://www.cnblogs.com/zyt1253679098/p/11182552.html
Copyright © 2011-2022 走看看