zoukankan      html  css  js  c++  java
  • DDP入门

    DDP,即动态动态规划,可以用于解决一类带修改的DP问题。
    我们从一个比较简单的东西入手,最大子段和。
    带修改的最大子段和其实是常规问题了,经典的解决方法是用线段树维护从左,右开始的最大子段和和区间最大子段和,然后进行合并。
    现在我们换一种方法来解决它。我们假设(f[i])表示以i为结尾的最大子段和大小,(g[i])表示[1,i]的最大子段和大小,显然有转移:(f[i] = max(f[i-1]+a[i],a[i]),g[i] = max(g[i-1],f[i]))

    这个DP每次修改显然要(O(n))
    我们考虑到好多在DP的时候,我们都用矩阵来加速递推。
    我们现在引入全新的思想,如何将它改写成矩阵呢?
    其实矩阵乘法能够成立,依赖的是乘法对加法有分配律。之后我们发现,加法对取(min/max)的操作也是有分配律的。比如(a+max(b,c) = max(a+b,a+c))
    那么我们完全可以考虑重新定义矩阵乘法,使得其满足如下的运算:(C[i][j] = max{A[i][k]+B[k][j]})

    这样的话……刚才的转移方程,我们就可以改写成如下的形式了。

    [egin{bmatrix} a_i & -infty & a_i \ a_i & 0 &a_i \ -infty & -infty & 0end{bmatrix} imes egin{bmatrix} f_{i-1}\ g_{i-1} \ 0end{bmatrix}quad = egin{bmatrix}f_i \ g_i \ 0end{bmatrix} ]

    那么我们就可以用线段树维护区间矩阵乘积来计算答案了。

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define fr friend inline
    #define y1 poj
    #define pr pair<int,int>
    #define fi first
    #define sc second
    #define mp make_pair
    
    using namespace std;
    typedef long long ll;
    const int M = 200005;
    const int INF = 1e9+7;
    const double eps = 1e-7;
    
    int read()
    {
       int ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    struct matrix
    {
       int f[3][3];
       matrix(){memset(f,0,sizeof(f));}
       void change(int x)
       {
          f[0][0] = f[1][0] = f[0][2] = f[1][2] = x;
          f[0][1] = f[2][0] = f[2][1] = -INF;
       }
       friend matrix operator + (const matrix &a,const matrix &b)
       {
          matrix c;
          rep(i,0,2) rep(j,0,2) c.f[i][j] = -INF;
          rep(k,0,2)
          rep(i,0,2)
          rep(j,0,2)
          c.f[i][j] = max(c.f[i][j],a.f[i][k] + b.f[k][j]);
          return c;
       }
    };
    
    struct node
    {
       matrix mat;
    }t[M<<2];
    
    int n,q,x,y,op;
    
    void build(int p,int l,int r)
    {
       if(l == r) {t[p].mat.change(read());return;}
       int mid = (l+r) >> 1;
       build(p<<1,l,mid),build(p<<1|1,mid+1,r);
       t[p].mat = t[p<<1].mat + t[p<<1|1].mat;
    }
    
    void modify(int p,int l,int r,int pos,int val)
    {
       if(l == r) {t[p].mat.change(val);return;}
       int mid = (l+r) >> 1;
       if(pos <= mid) modify(p<<1,l,mid,pos,val);
       else modify(p<<1|1,mid+1,r,pos,val);
       t[p].mat = t[p<<1].mat + t[p<<1|1].mat;
    }
    
    matrix query(int p,int l,int r,int kl,int kr)
    {
       if(l == kl && r == kr) return t[p].mat;
       int mid = (l+r) >> 1;
       if(kr <= mid) return query(p<<1,l,mid,kl,kr);
       else if(kl > mid) return query(p<<1|1,mid+1,r,kl,kr);
       else return query(p<<1,l,mid,kl,mid) + query(p<<1|1,mid+1,r,mid+1,kr);
    }
    
    int main()
    {
       n = read(),build(1,1,n),q = read();
       while(q--)
       {
          op = read(),x = read(),y = read();
          if(op == 0) modify(1,1,n,x,y);
          else
          {
         matrix k = query(1,1,n,x,y);
         printf("%d
    ",max(k.f[1][0],k.f[1][2]));
          }
       }
       return 0;
    }
    

    之后我们再来考虑下一个问题。树上最大独立集。
    (f[i][0])表示不选i,子树中最大独立集的大小,(f[i][1])表示选i,子树中最大独立集的大小。
    显然有(f[i][0] = sum max(f[v][0],f[v][1]),f[i][1] = sum f[v][0] + a[i])
    我们要把这玩意改写成矩阵的形式。但是我们首先要使用数据结构维护树,比如树剖。(LCT版的我还不会)
    因为树剖可以把重链整成一段连续的区间,那么我们先把与重链无关的一些东西提取出来。这样,我们设(g[i][0/1])表示不取/取i,i的非重儿子中最大独立集的大小
    这样的话,dp的方程就变成了这样:(f[i][0] =g[i][0] + max(f[son[i]][0],f[son[i]][1]),f[i][1] = g[i][1] + f[son[i]][0])
    然后就可以开心的写成矩阵的形式:

    [egin{bmatrix} g[i][0] & g[i][0] \ g[i][1] & -inftyend{bmatrix} imes egin{bmatrix} f[son[i][0]]\ f[son[i]][1] end{bmatrix}= egin{bmatrix}f[i][0] \ f[i][1]end{bmatrix} ]

    那么现在我们就可以用树剖+矩阵去维护了。这个和普通的树剖有一些区别,就是我们需要先跑一次树DP来计算出来f,g数组,之后初始化矩阵,每次从修改点跳重链跳到根节点,注意每次跳重链的时候要取一段完整的重链,所以我们还需要额外记录链的底部在哪。
    然后就不大难修改了。线段树和上面基本是一样的,树剖也比较简单,修改过程就是一个先减再加的过程。
    看一下luogu的模板

    #include<bits/stdc++.h>
    #define rep(i,a,n) for(int i = a;i <= n;i++)
    #define per(i,n,a) for(int i = n;i >= a;i--)
    #define enter putchar('
    ')
    #define pr pair<int,int>
    #define mp make_pair
    #define fi first
    #define sc second
    using namespace std;
    typedef long long ll;
    const int M = 200005;
    const int N = 10000005;
    const int INF = 1e9;
    
    int read()
    {
       int ans = 0,op = 1;char ch = getchar();
       while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
       while(ch >='0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
       return ans * op;
    }
    
    struct matrix
    {
       int f[2][2];
       matrix(){memset(f,0,sizeof(f));}
       friend matrix operator + (const matrix &a,const matrix &b)
       {
          matrix c;
          rep(i,0,1)
          rep(j,0,1) c.f[i][j] = -INF;
          rep(k,0,1)
          rep(i,0,1)
          rep(j,0,1) c.f[i][j] = max(c.f[i][j],a.f[i][k] + b.f[k][j]);
          return c;
       }
    }val[M];
    
    struct node
    {
       matrix mat;
    }t[M<<1];
    
    struct edge
    {
       int next,to,from;
    }e[M<<1];
    
    int n,m,head[M],ecnt,v[M],top[M],fa[M],hson[M],size[M];
    int ed[M],x,y,pos[M],dfn[M],idx,F[M][2];
    void add(int x,int y) {e[++ecnt] = {head[x],y,x},head[x] = ecnt;}
    
    void dfs1(int x,int f)
    {
       size[x] = 1,fa[x] = f;
       for(int i = head[x];i;i = e[i].next)
       {
          if(e[i].to == f) continue;
          dfs1(e[i].to,x),size[x] += size[e[i].to];
          if(size[e[i].to] > size[hson[x]]) hson[x] = e[i].to;
       }
    }
    
    void dfs2(int x,int t)
    {
       dfn[x] = ++idx,pos[idx] = x,top[x] = t,ed[t] = max(ed[t],idx);
       F[x][0] = 0,F[x][1] = v[x];
       val[x].f[0][0] = val[x].f[0][1] = 0,val[x].f[1][0] = v[x];
       if(hson[x])
       {
          int v = hson[x];
          dfs2(v,t),F[x][0] += max(F[v][0],F[v][1]),F[x][1] += F[v][0];
       }
       for(int i = head[x];i;i = e[i].next)
       {
          int v = e[i].to;
          if(v == fa[x] || v == hson[x]) continue;
          dfs2(v,v),F[x][0] += max(F[v][0],F[v][1]),F[x][1] += F[v][0];
          val[x].f[0][0] += max(F[v][0],F[v][1]);
          val[x].f[0][1] = val[x].f[0][0],val[x].f[1][0] += F[v][0];
       }
    }
    
    void build(int p,int l,int r)
    {
       if(l == r) {t[p].mat = val[pos[l]];return;}
       int mid = (l+r) >> 1;
       build(p<<1,l,mid),build(p<<1|1,mid+1,r);
       t[p].mat = t[p<<1].mat + t[p<<1|1].mat;
    }
    
    void modify(int p,int l,int r,int x)
    {
       if(l == r){t[p].mat = val[pos[x]];return;}
       int mid = (l+r) >> 1;
       if(x <= mid) modify(p<<1,l,mid,x);
       else modify(p<<1|1,mid+1,r,x);
       t[p].mat = t[p<<1].mat + t[p<<1|1].mat;
    }
    
    matrix query(int p,int l,int r,int kl,int kr)
    {
       if(l == kl && r == kr) return t[p].mat;
       int mid = (l+r) >> 1;
       if(kr <= mid) return query(p<<1,l,mid,kl,kr);
       else if(kl > mid) return query(p<<1|1,mid+1,r,kl,kr);
       else return query(p<<1,l,mid,kl,mid) + query(p<<1|1,mid+1,r,mid+1,kr);
    }
    
    void uprange(int x,int y)
    {
       val[x].f[1][0] += y - v[x],v[x] = y;
       matrix A,B;
       while(x)
       {
          B = query(1,1,n,dfn[top[x]],ed[top[x]]),modify(1,1,n,dfn[x]);
          A = query(1,1,n,dfn[top[x]],ed[top[x]]),x = fa[top[x]];
          val[x].f[0][0] += max(A.f[0][0],A.f[1][0]) - max(B.f[0][0],B.f[1][0]);
          val[x].f[0][1] = val[x].f[0][0];
          val[x].f[1][0] += (A.f[0][0] - B.f[0][0]);
       }
    }
    
    int main()
    {
       n = read(),m = read();
       rep(i,1,n) v[i] = read();
       rep(i,1,n-1) x = read(),y = read(),add(x,y),add(y,x);
       dfs1(1,0),dfs2(1,1),build(1,1,n);
       while(m--)
       {
          x = read(),y = read(),uprange(x,y);
          matrix ans = query(1,1,n,dfn[1],ed[1]);
          printf("%d
    ",max(ans.f[0][0],ans.f[1][0]));
       }
       return 0;
    }
    
    
  • 相关阅读:
    mysql远程执行sql脚本
    数据库死锁
    sqlserver 数据库之调优
    sqlserver 数据库之性能优化
    Session共享的解决办法
    关于对session机制的理解--通俗易懂
    kafka之常用命令
    分布式消息队列之kafka
    vuejs调试代码
    json
  • 原文地址:https://www.cnblogs.com/captain1/p/10459348.html
Copyright © 2011-2022 走看看