zoukankan      html  css  js  c++  java
  • 【模板】动态 DP

    luogu传送门

    最近学了一下动态dp,感觉没有想象的难。

    动态DP

    simple的DP是这样的:

    给棵树,每个点给个权值,求一下最大权独立集。

    动态DP是这样的:

    给棵树,每个点给个权值还到处改,每次改的时候求一下最大权独立集。


    题外话:

    ($NOIP2018$保卫王国)

    大佬:动态$dp$板子直接秒。

    神仙:这个是什么让我想一想……倍增?(然后当场切掉)

    除了大佬除了神仙除了我:敲暴力。

    我:(看错题了爆蛋)


    几个前置知识:

    (1)重链剖分

    树剖时选择子树点数最多的儿子作重儿子。

    重剖有几个性质:

    1.终点一定是叶子。(树剖通性)

    2.重链剖分序上一条重链是一个连续的区间。(树剖通性)

    3.任意一个点走到根经过的轻链不超过$logn$条。(重剖特性

    4.……

    (2)矩阵

    著名矩乘由于满足结合律可以用来加速递推/分治。

    大佬们不知道咋想的就搞出了以下运算:

    $$egin{pmatrix}a00&a01\a10&a11end{pmatrix}+egin{pmatrix}b00&b01\b10&b11end{pmatrix} $$

    $$=egin{pmatrix}max(a00+b00,a01+b10)&max(a00+b10,a01+b11)\max(a10+b00,a11+b10)&max(a10+b01,a11+b11)end{pmatrix}$$

    这个东西竟然满足结合律……

    我们可以用他去搞一些事情了。


    树的最大权独立集:

    设$f[u][0/1]$表示在点$u$不取/取的情况下点$u$子树内的最大权独立集的权值。

    有$dp$如下:

    $f[u][0]=sum _{fa[v]=u} max(f[v][0],f[v][1])$

    $f[u][1]=w[u] + sum _{fa[v]=u} f[v][0]$

    让轻重儿子分开讨论有:

    $f[u][0]= max(f[son][0],f[son][1]) + sum _{fa[v]=u,v!=son} max(f[v][0],f[v][1])$

    $f[u][1]= f[son][0] +w[u]+sum_{fa[v]=u,v!=son}f[v][0]$

    设$g[u][0]=sum _{fa[v]=u,v!=son} max(f[v][0],f[v][1]),g[u][1]=w[u]+sum_{fa[v]=u,v!=son}f[v][0]$

    写成上面矩阵的形式是这样的:

    $f[u][0] = max(g[u][0]+f[son][0],g[u][0]+f[son][1])$

    $f[u][1] = max(g[u][1]+f[son][0],-inf+f[son][1])$

    就是$$egin{pmatrix}f[son][0]&f[son][1]end{pmatrix}+egin{pmatrix}g[u][0]&g[u][1]\g[u][0]&-infend{pmatrix}$$

    $$=egin{pmatrix}f[u][0]&f[u][1]end{pmatrix}$$

    神奇。

    考虑到重链剖分序重链上的点是按深度从浅到深排列的,我们应该这样(不然线段树左右要反着合并比较恶心):

    $$egin{pmatrix}g[u][0]&g[u][0]\g[u][1]&-infend{pmatrix}+egin{pmatrix}f[son][0]\f[son][1]end{pmatrix}$$

    $$=egin{pmatrix}f[u][0]\f[u][1]end{pmatrix}$$

    对于每个点保留带$g$的矩阵,这样一个点的$f$矩阵就是重链上从该点的矩阵一直向下加,加到叶子结点。

    区间加法用线段树平衡树等数据结构维护。

    每次修改之后,由于只会修改当前点到根路径上的$f$,我们可以爆跳树链。

    修改点权,当前树链上只会修改当前点的$g[1]$。

    树链之间转移时,考虑轻儿子的$f$对父节点$g$的影响。

    我们可以求出修改前后链顶的$f$值,然后扔到原来方程里更新父亲的$g$。

    重复上述操作一直更新到根。

    由于重剖$logn$条轻链,时间复杂度$O(nlog^2n)$。

    代码:

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    typedef long long ll;
    const int N = 100050;
    const ll Inf = 0x3f3f3f3f3f3f3f3fll;
    template<typename T>
    inline void read(T&x)
    {
        T f = 1,c = 0;char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
        while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();}
        x = f*c;
    }
    ll F[N][2],G[N][2],k[N];
    int n,m,hed[N],cnt;
    struct EG
    {
        int to,nxt;
    }e[N<<1];
    void ae(int f,int t)
    {
        e[++cnt].to = t;
        e[cnt].nxt = hed[f];
        hed[f] = cnt;
    }
    int dep[N],siz[N],top[N],pot[N],fa[N],son[N],tin[N],pla[N],tim;
    struct mt
    {
        ll s[2][2];
        mt(){memset(s,0,sizeof(s));}
        mt(ll g0,ll g1){s[0][0]=s[0][1]=g0,s[1][0]=g1,s[1][1]=-Inf;}
        mt operator + (const mt&a)const
        {
            mt ret;
            ret.s[0][0]=max(s[0][0]+a.s[0][0],s[0][1]+a.s[1][0]);
            ret.s[0][1]=max(s[0][0]+a.s[0][1],s[0][1]+a.s[1][1]);
            ret.s[1][0]=max(s[1][0]+a.s[0][0],s[1][1]+a.s[1][0]);
            ret.s[1][1]=max(s[1][0]+a.s[0][1],s[1][1]+a.s[1][1]);
            return ret;
        }
    }t[N];
    struct segtree
    {
        mt w[N<<2];
        void update(int u){w[u]=w[u<<1]+w[u<<1|1];}
        void build(int l,int r,int u)
        {
            if(l==r){w[u]=t[pla[l]]=mt(G[pla[l]][0],G[pla[l]][1]);return ;}
            int mid = (l+r)>>1;
            build(l,mid,u<<1);
            build(mid+1,r,u<<1|1);
            update(u);
        }
        void insert(int l,int r,int u,int qx)
        {
            if(l==r){w[u]=t[pla[l]];return ;}
            int mid = (l+r)>>1;
            if(qx<=mid)insert(l,mid,u<<1,qx);
            else insert(mid+1,r,u<<1|1,qx);
            update(u);
        }
        mt query(int l,int r,int u,int ql,int qr)
        {
            if(l==ql&&r==qr)return w[u];
            int mid = (l+r)>>1;
            if(qr<=mid)return query(l,mid,u<<1,ql,qr);
            else if(ql>mid)return query(mid+1,r,u<<1|1,ql,qr);
            else return query(l,mid,u<<1,ql,mid)+query(mid+1,r,u<<1|1,mid+1,qr);
        }
    }tr;
    void dfs0(int u,int f)
    {
        fa[u] = f,siz[u] = 1,dep[u] = dep[f]+1;
        for(int j=hed[u];j;j=e[j].nxt)
        {
            int to = e[j].to;
            if(to==f)continue;
            dfs0(to,u);siz[u]+=siz[to];
            if(siz[to]>siz[son[u]])son[u]=to;
        }
    }
    void dfs1(int u,int Top)
    {
        top[u] = Top,pot[u] = u,tin[u] = ++tim,pla[tim] = u;
        if(son[u])dfs1(son[u],Top),pot[u]=pot[son[u]];
        for(int j=hed[u];j;j=e[j].nxt)
        {
            int to = e[j].to;
            if(to!=fa[u]&&to!=son[u])
                dfs1(to,to);
        }
    }
    void dp(int u)
    {
        F[u][0]=0,F[u][1]=k[u];
        for(int j=hed[u];j;j=e[j].nxt)
        {
            int to = e[j].to;
            if(to==fa[u])continue;
            dp(to);
            F[u][0]+=max(F[to][0],F[to][1]);
            F[u][1]+=F[to][0];
        }
        G[u][0] = F[u][0] - max(F[son[u]][0],F[son[u]][1]);
        G[u][1] = F[u][1] - F[son[u]][0];
    }
    mt get_mt(int x){return tr.query(1,n,1,tin[top[x]],tin[pot[x]]);}
    void chg(int x,int y)
    {
        t[x].s[1][0] += y-k[x],k[x] = y;
        while(x)
        {
            mt m0 = get_mt(x);
            tr.insert(1,n,1,tin[x]);
            mt m1 = get_mt(x);
            x = fa[top[x]];
            if(!x)break;
            t[x].s[0][0]+=max(m1.s[0][0],m1.s[1][0])-max(m0.s[0][0],m0.s[1][0]);
            t[x].s[0][1]=t[x].s[0][0];
            t[x].s[1][0]+=m1.s[0][0]-m0.s[0][0];
        }
    }
    int main()
    {
    //    freopen("tt.in","r",stdin);
        read(n),read(m);
        for(int i=1;i<=n;i++)
            read(k[i]);
        for(int u,v,i=1;i<n;i++)
            read(u),read(v),ae(u,v),ae(v,u);
        dfs0(1,0),dfs1(1,1);
        dp(1);tr.build(1,n,1);
        for(int x,y,i=1;i<=m;i++)
        {
            read(x),read(y);
            chg(x,y);mt now = get_mt(1);
            printf("%lld
    ",max(now.s[0][0],now.s[1][0]));
        }
        return 0;
    }
    View Code

    附保卫王国代码:

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    typedef long long ll;
    const int N = 100050;
    const ll inf = 1e16;
    template<typename T>
    inline void read(T&x)
    {
        T f = 1,c = 0;char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
        while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();}
        x = f*c;
    }
    int n,m,hed[N],cnt;
    ll p[N],F[N][2],G[N][2];
    char op[10];
    struct EG
    {
        int to,nxt;
    }e[N<<1];
    void ae(int f,int t)
    {
        e[++cnt].to = t;
        e[cnt].nxt = hed[f];
        hed[f] = cnt;
    }
    int dep[N],siz[N],fa[N],son[N],top[N],pot[N],tin[N],pla[N],tim;
    void dfs0(int u,int f)
    {
        fa[u] = f,siz[u] = 1,dep[u] = dep[f]+1;
        for(int j=hed[u];j;j=e[j].nxt)
        {
            int to = e[j].to;
            if(to==f)continue;
            dfs0(to,u);
            siz[u]+=siz[to];
            if(siz[to]>siz[son[u]])son[u]=to;
        }
    }
    void dfs1(int u,int Top)
    {
        top[u] = Top,tin[u] = ++tim,pla[tim] = u;
        if(son[u])dfs1(son[u],Top),pot[u]=pot[son[u]];
        else pot[u] = u;
        for(int j=hed[u];j;j=e[j].nxt)
        {
            int to = e[j].to;
            if(to!=fa[u]&&to!=son[u])
                dfs1(to,to);
        }
    }
    void dp(int u)
    {
        F[u][0] = 0,F[u][1] = p[u];
        for(int j=hed[u];j;j=e[j].nxt)
        {
            int to = e[j].to;
            if(to==fa[u])continue;
            dp(to);
            F[u][0] += F[to][1];
            F[u][1] += min(F[to][0],F[to][1]);
        }
        G[u][0] = F[u][0] - F[son[u]][1];
        G[u][1] = F[u][1] - min(F[son[u]][0],F[son[u]][1]);
    }
    struct mt
    {
        ll s[2][2];
        mt(){memset(s,0,sizeof(s));}
        mt(ll g0,ll g1){s[0][0]=inf,s[0][1]=g0,s[1][0]=s[1][1]=g1;}
        mt operator + (const mt&a)const
        {
            mt ret;
            ret.s[0][0]=min(s[0][0]+a.s[0][0],s[0][1]+a.s[1][0]);
            ret.s[0][1]=min(s[0][0]+a.s[0][1],s[0][1]+a.s[1][1]);
            ret.s[1][0]=min(s[1][0]+a.s[0][0],s[1][1]+a.s[1][0]);
            ret.s[1][1]=min(s[1][0]+a.s[0][1],s[1][1]+a.s[1][1]);
            return ret;
        }
    }t[N];
    struct segtree
    {
        mt w[N<<2];
        void update(int u){w[u]=w[u<<1]+w[u<<1|1];}
        void build(int l,int r,int u)
        {
            if(l==r){w[u]=t[pla[l]]=mt(G[pla[l]][0],G[pla[l]][1]);return ;}
            int mid = (l+r)>>1;
            build(l,mid,u<<1),build(mid+1,r,u<<1|1);
            update(u);
        }
        void insert(int l,int r,int u,int qx)
        {
            if(l==r){w[u]=t[pla[l]];return ;}
            int mid = (l+r)>>1;
            if(qx<=mid)insert(l,mid,u<<1,qx);
            else insert(mid+1,r,u<<1|1,qx);
            update(u);
        }
        mt query(int l,int r,int u,int ql,int qr)
        {
            if(l==ql&&r==qr)return w[u];
            int mid = (l+r)>>1;
            if(qr<=mid)return query(l,mid,u<<1,ql,qr);
            else if(ql>mid)return query(mid+1,r,u<<1|1,ql,qr);
            else return query(l,mid,u<<1,ql,mid)+query(mid+1,r,u<<1|1,mid+1,qr);
        }
    }tr;
    mt get_mt(int x){return tr.query(1,n,1,tin[top[x]],tin[pot[x]]);}
    void chg(int a,int x)
    {
        if(x==1)t[a].s[1][0]-=inf;
        else t[a].s[1][0]+=inf;
        t[a].s[1][1]=t[a].s[1][0];
    }
    void upd(int x)
    {
        while(x)
        {
            mt m0 = get_mt(x);
            tr.insert(1,n,1,tin[x]);
            mt m1 = get_mt(x);
            x = fa[top[x]];
            if(!x)break;
            t[x].s[0][1]+=m1.s[1][1]-m0.s[1][1];
            t[x].s[1][0]+=min(m1.s[0][1],m1.s[1][1])-min(m0.s[0][1],m0.s[1][1]);
            t[x].s[1][1]=t[x].s[1][0];
        }
    }
    int main()
    {
    //    freopen("tt.in","r",stdin);
        read(n),read(m);scanf("%s",op);
        for(int i=1;i<=n;i++)read(p[i]);
        for(int u,v,i=1;i<n;i++)read(u),read(v),ae(u,v),ae(v,u);
        dfs0(1,0),dfs1(1,1);dp(1);tr.build(1,n,1);
        for(int a,x,b,y,i=1;i<=m;i++)
        {
            read(a),read(x),read(b),read(y);
            if((fa[a]==b||fa[b]==a)&&!x&&!y){puts("-1");continue;}
            chg(a,x);
            upd(a);
            chg(b,y);
            upd(b);
            mt M = get_mt(1);
            ll ans = min(M.s[0][1],M.s[1][1]);
            ans+=(x+y)*inf;
            printf("%lld
    ",ans);
            chg(a,!x);upd(a);
            chg(b,!y);upd(b);
        }
        return 0;
    }
    保卫王国
  • 相关阅读:
    自学Python编程的第二天----------来自苦逼的转行人
    自学Python编程的第一天----------来自苦逼的转行人
    A-B 高精度
    A+B 高精度
    [NOI2002]银河英雄传说
    口袋的天空
    修复公路(并查集)
    并查集
    Surjectivity is stable under base change
    为什么Fourier分析?
  • 原文地址:https://www.cnblogs.com/LiGuanlin1124/p/11101029.html
Copyright © 2011-2022 走看看