zoukankan      html  css  js  c++  java
  • [模板] 动态DP

    一、题目

    点此看题

    二、解法

    动态 (dp) 的思路主要是用矩阵乘法加速 (dp),所以首先要知道矩阵乘法的扩展版:

    [c(i,k)=max{a(i,j)+b(j,k)} ]

    令人震惊的是上面这东西也满足结合律,现在我们来证明一下,假设有三个矩阵 (a,b,c) 相乘,大小分别是 (n imes m,m imes p,p imes q),我们把最终某一个位置上的值暴力展开:

    [(i,j)=max_{k=1}^ma(i,k)+Big(max_{l=1}^p b(k,l)+c(l,j)Big) ]

    [=max_{k=1}^mmax_{l=1}^p a(i,k)+b(k,l)+c(l,j) ]

    [=max_{l=1}^pBig(max_{k=1}^m a(i,k)+b(k,l)Big)+c(l,j) ]

    所以先乘 (a,b) 还是先乘 (b,c) 对答案没有影响,结合律得证。


    首先写出暴力的 (dp) 柿子,设 (f(u,0/1)) 表示 (u) 这个点不选(/)选的最大权值,转移:

    [f(u,0)=summax(f(v,0),f(v,1)) ]

    [f(u,1)=a(u)+sum f(v,0) ]

    先来考虑一下链怎么做,我们构造一个像这样的转移矩阵:

    [left(egin{matrix}0&0\a(u)&-inftyend{matrix} ight) imesleft(egin{matrix}f(v,0)\f(v,1)end{matrix} ight)=left(egin{matrix}f(u,0)\f(u,1)end{matrix} ight) ]

    然后要求根的 (dp) 值就直接把所有矩阵乘起来就行了,时间复杂度 (O(nlog n))

    那么我们能不能把上面的做法搬到树上呢?考虑把树剖分成链然后套上面的做法,也就是用树链剖分。每个点的转移矩阵就针对他的重儿子来定义,但同时我们要考虑轻儿子对他 (dp) 值的贡献,所以再定义 (f'(u,0/1)) 表示 (u) 不选(/)选,考虑 (u)(u) 的所有轻儿子的最大值,那么有如下转移:

    [f(u,0)=f'(u,0)+max{f(son,0),f(son,1)} ]

    [f(u,1)=f'(u,1)+f(son,0) ]

    写成矩阵就是这个样子的:

    [left(egin{matrix}f'(u,0)&f’(u,0)\f'(u,1)&-inftyend{matrix} ight) imesleft(egin{matrix}f(son,0)\f(son,1)end{matrix} ight)=left(egin{matrix}f(u,0)\f(u,1)end{matrix} ight) ]

    先考虑怎么统计答案,我们找到根所在的那条重链,把所有转移矩阵乘起来就行了。

    再考虑如何修改,修改一个点的点权只会对它的祖先产生影响。而且由于路径上只有 (O(log n)) 条轻边,所以一共只需要改 (O(log n)) 个矩阵,这部分可以看看代码:

    void modify(int u,int w)//把u点权改成w 
    {
    	val[u].a[1][0]+=w-a[u];
    	a[u]=w;
    	while(u)
    	{
    		matrix t1=ask(1,1,n,num[top[u]],num[bot[u]]);//算出f(u,0/1)
    		upd(1,1,n,num[u]);//在线段树上更新那个位置的矩阵
    		matrix t2=ask(1,1,n,num[top[u]],num[bot[u]]);//算出新的f(u,0/1)
    		u=fa[top[u]];//要更新重链顶端父亲的转移矩阵
    		val[u].a[0][0]+=max(t2.a[0][0],t2.a[1][0])-max(t1.a[0][0],t1.a[1][0]);
    		val[u].a[0][1]=val[u].a[0][0];
    		val[u].a[1][0]+=t2.a[0][0]-t1.a[0][0];
    	}
    }
    

    用一个线段树维护矩阵套上树链剖分:(O(2^3cdot nlog^2 n))

    #include <cstdio>
    #include <iostream>
    using namespace std;
    const int M = 100005;
    const int inf = 1e9;
    int read()
    {
    	int x=0,f=1;char c;
    	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
    	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
    	return x*f;
    }
    int n,m,tot,cnt,f[M],a[M],id[M],fa[M];
    int siz[M],son[M],num[M],top[M],bot[M],dp[M][2];
    //top表示重链头
    //bot表示重链尾
    //num表示这个点在线段树上的位置
    //id表示线段树上位置所对应的点 
    //dp表示初始的dp数组 
    struct edge
    {
    	int v,next;
    	edge(int V=0,int N=0) : v(V) , next(N) {}
    }e[2*M];
    struct matrix
    {
    	int a[2][2];
    	matrix() {a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
    	matrix operator * (const matrix &b) const
    	{
    		matrix r;
    		for(int i=0;i<2;i++)
    			for(int j=0;j<2;j++)
    				for(int k=0;k<2;k++)
    					r.a[i][k]=max(r.a[i][k],a[i][j]+b.a[j][k]);
    		return r;
    	}
    	void print()
    	{
    		puts("---------");
    		for(int i=0;i<2;i++,puts(""))
    			for(int j=0;j<2;j++)
    				printf("%d ",a[i][j]);
    	}
    }val[M],tr[4*M];
    //线段树部分
    void up(int i)
    {
    	tr[i]=tr[i<<1]*tr[i<<1|1];
    }
    void build(int i,int l,int r)
    {
    	if(l==r)
    	{
    		tr[i]=val[id[l]];
    		return ;
    	}
    	int mid=(l+r)>>1;
    	build(i<<1,l,mid);
    	build(i<<1|1,mid+1,r);
    	up(i);
    }
    void upd(int i,int l,int r,int x)//修改x这个位置的矩阵
    {
    	if(l==r)
    	{
    		tr[i]=val[id[x]];
    		return ;
    	}
    	int mid=(l+r)>>1;
    	if(mid>=x) upd(i<<1,l,mid,x);
    	else upd(i<<1|1,mid+1,r,x);
    	up(i);
    }
    matrix ask(int i,int l,int r,int L,int R)
    {
    	if(L<=l && r<=R) return tr[i];
    	int mid=(l+r)>>1;
    	if(R<=mid) return ask(i<<1,l,mid,L,R);
    	if(L>mid) return ask(i<<1|1,mid+1,r,L,R);
    	return ask(i<<1,l,mid,L,R)*ask(i<<1|1,mid+1,r,L,R);
    }
    //树链剖分部分 
    void dfs1(int u,int p)
    {
    	siz[u]=1;fa[u]=p;
    	for(int i=f[u];i;i=e[i].next)
    	{
    		int v=e[i].v;
    		if(v==p) continue;
    		dfs1(v,u);
    		siz[u]+=siz[v];
    		if(siz[son[u]]<siz[v]) son[u]=v;
    	}
    }
    void dfs2(int u,int tp)
    {
    	top[u]=tp;
    	num[u]=++cnt;
    	id[cnt]=u;
    	val[u].a[0][0]=val[u].a[0][1]=0;
    	val[u].a[1][0]=dp[u][1]=a[u];
    	if(son[u])
    	{
    		dfs2(son[u],tp),bot[u]=bot[son[u]];
    		dp[u][0]+=max(dp[son[u]][0],dp[son[u]][1]);
    		dp[u][1]+=dp[son[u]][0];
    	}
    	else bot[u]=u;//如果没有重儿子底部就是自己
    	for(int i=f[u];i;i=e[i].next)
    	{
    		int v=e[i].v;
    		if(v==fa[u] || v==son[u]) continue;
    		dfs2(v,v);
    		dp[u][0]+=max(dp[v][0],dp[v][1]);
    		dp[u][1]+=dp[v][0];
    		val[u].a[0][0]+=max(dp[v][0],dp[v][1]);
    		val[u].a[0][1]=val[u].a[0][0];
    		val[u].a[1][0]+=dp[v][0];
    		//(0,0)/(0,1)表示这个点不选,(1,0)表示这个点要选 
    	}
    }
    void modify(int u,int w)//把u点权改成w 
    {
    	val[u].a[1][0]+=w-a[u];
    	a[u]=w;
    	while(u)
    	{
    		matrix t1=ask(1,1,n,num[top[u]],num[bot[u]]);
    		upd(1,1,n,num[u]);
    		matrix t2=ask(1,1,n,num[top[u]],num[bot[u]]);
    		u=fa[top[u]];
    		val[u].a[0][0]+=max(t2.a[0][0],t2.a[1][0])-max(t1.a[0][0],t1.a[1][0]);
    		val[u].a[0][1]=val[u].a[0][0];
    		val[u].a[1][0]+=t2.a[0][0]-t1.a[0][0];
    	}
    }
    signed main()
    {
    	n=read();m=read();
    	for(int i=1;i<=n;i++)
    		a[i]=read();
    	for(int i=1;i<n;i++)
    	{
    		int u=read(),v=read();
    		e[++tot]=edge(v,f[u]),f[u]=tot;
    		e[++tot]=edge(u,f[v]),f[v]=tot;
    	}
    	dfs1(1,0);
    	dfs2(1,1);
    	build(1,1,n);
    	while(m--)
    	{
    		int x=read(),y=read();
    		modify(x,y);
    		matrix t1=ask(1,1,n,num[1],num[bot[1]]);
    		printf("%d
    ",max(t1.a[0][0],t1.a[1][0]));
    	}
    }
    
  • 相关阅读:
    添加arcgis portal数据存储bad login user
    使用python从地图服务中提取数据
    山体
    也能用高德输入点击初始结果
    从源代码构建Qt6开发工具
    rust组件安装
    ubuntu apt-get 安装指定版本软件
    Ubuntu上如何查询和安装指定版本的软件
    gnutls not found using pkg-config
    Package not found
  • 原文地址:https://www.cnblogs.com/C202044zxy/p/14528935.html
Copyright © 2011-2022 走看看