zoukankan      html  css  js  c++  java
  • 关于树上DP的转移方式与复杂度证明

    以后要勤写总结了唔

    这种优化针对于转移的状态数与子树规模有关的柿子。

    例如对于n个树型依赖物品的树上背包dp,每个节点是一个物品且大小为1,设待转移结点u则u的背包容量不会超过u的子树规模,转移子节点v占据的容积不会超过v的子树规模。所以我们有以下转移方式:

    • 最脑残的,直接开siz[u]个固定背包容量(第一层for ∑siz[v]),用儿子来”填“(第二层for siz[v])。显然在转移非最后一个儿子时有大量无用状态。
    • 较脑小的(是我 并不),开∑已扫过儿子和当前儿子的siz(第一层for ∑已扫过儿子和当前儿子的siz「要为该儿子‘填’留出空间」),依然是用儿子”填“待转移状态(第二层for siz[v])。然而这种转移复杂度依然不乐观,接近于n^2(不会证,但的确被总复杂度n^2的数据卡掉了)再加上dfs n^3,不够优秀。
    • 卓越算法,既然”填“不行,我们可以”推“(因为有效的状态比无效状态少)。开∑已扫过儿子的siz(第一层for 都可能对答案造成贡献),这次我们用儿子和左边”推“状态,类似dp[u][k+j]的形式。对于一些题面较为复杂的限制条件,我们可以开个temp数组(类似于temp[0/1滚动儿子][0/1限制条件 状态1][状态2]...)完成以上转移,然后再把temp塞到dp数组里。这种转移依然有三层for,但确有O(n^2)的优秀复杂度。以下给出证明:

    首先脑补一下转移的图(其实是我懒得画了),对于当前转移的v子树(之所以说子树是因为最多转移它的siz)与之前已考虑的子子树中所有点构成点对且点对的lca为u,当完成对所有点的dfs,所有点作为lca的情况都被考虑且保证点对不重复(点对的lca唯一),将构成n^2个点对。得证。

     1     F(i,2,son[u][0])
     2     {
     3         int v=son[u][i];
     4         F(j,0,deep[u])
     5         {
     6             F(k,0,deep[son[u][i]])
     7             {
     8                 temp[i&1][0][max(j,k+1)]+=temp[i&1^1][0][j]*f[v][k]%p;       temp[i&1][0][max(j,k+1)]%=p;
     9                 temp[i&1][1][max(j,k)]+=temp[i&1^1][0][j]*f[v][k]%p;         temp[i&1][1][max(j,k)]%=p;
    10                 temp[i&1][1][max(j,k+1)]+=temp[i&1^1][1][j]*f[v][k]%p;        temp[i&1][1][max(j,k+1)]%=p;
    11             }
    12         }
    13         F(j,0,deep[u]) temp[i&1^1][0][j]=temp[i&1^1][1][j]=0;
    14         deep[u]=max(deep[u],deep[son[u][i]]+1);
    15     }
    16     F(j,0,deep[u]) f[u][j]=temp[son[u][0]&1][1][j];
    大概长这样

    相关题目(在本家OJ NOIP模拟3里)

    树上染色

    点与点之间不好推,平时应多注意边的贡献。对于一条边,它的贡献为$一边的黑点数*另一边的黑点数*边权+一边的白点数*另一边的白点数*边权$,我们可以定义状态dp[u][j]为u子树中选出j个黑点的最大收益,简化下,即在u阶段下在$|son[u]|$个分组中每个至多选一个(可以不选)状态权值为dp[v][k]重量为k使背包dp[u][j]收益最大,然后就可以在树上进行背包DP。

    观察$N<=2000$,那么我们就需要top所说的$O(N^2)$转移了。由于这题dp柿子比较简单,用不用temp无所谓,只要“推”状态就可以达到。

     1 #include<cstdio>
     2 #include<vector>
     3 #include<set>
     4 #include<cmath>
     5 #include<cstring>
     6 #include<algorithm>
     7 #define MAXN 2005
     8 #define ll long long
     9 #define reg register
    10 #define F(i,a,b) for(register int (i)=(a);(i)<=(b);++(i))
    11 using namespace std;
    12 inline int read();
    13 struct R{
    14     int u,v,next;
    15     ll w;
    16 }r[MAXN<<1];
    17 int n,black_tot;
    18 int fir[MAXN],o,siz[MAXN];
    19 ll dp[MAXN][MAXN];
    20 void add(int u,int v,ll w)
    21 {
    22     r[++o].u=u;
    23     r[o].v=v;
    24     r[o].w=w;
    25     r[o].next=fir[u];
    26     fir[u]=o;
    27 }
    28 void dfs(int u,int fa)
    29 {
    30     siz[u]=1;
    31     for(reg int i=fir[u];i;i=r[i].next)
    32     {
    33         int v=r[i].v;
    34         if(v==fa) continue;
    35         dfs(v,u);
    36         for(reg int j=min(siz[u],black_tot);j>=0;--j)
    37             for(reg int k=min(siz[v],black_tot-j);k>=0;--k)
    38                 dp[u][j+k]=max(dp[u][j+k],dp[u][j]+dp[v][k]+1ll*k*(black_tot-k)*r[i].w+1ll*(siz[v]-k)*(n-siz[v]-black_tot+k)*r[i].w);
    39         siz[u]+=siz[v];
    40     }
    41 }
    42 int main()
    43 {
    44     n=read(); black_tot=read();
    45     int a,b;
    46     ll t;
    47     F(i,1,n-1)
    48     {
    49         a=read(); b=read(); scanf("%lld",&t);
    50         add(a,b,t); add(b,a,t);
    51     }
    52     dfs(1,0);
    53     printf("%lld",dp[1][black_tot]);
    54     return 0;
    55 }
    56 inline int read()
    57 {
    58     int x=0;
    59     char tc=getchar();
    60     while(tc<'0'||tc>'9') tc=getchar();
    61     while(tc>='0'&&tc<='9') x=x*10+tc-48,tc=getchar();
    62     return x;
    63 }
    不用temp
     1 #include<cstdio>
     2 #include<vector>
     3 #include<set>
     4 #include<cmath>
     5 #include<cstring>
     6 #include<algorithm>
     7 #define MAXN 2005
     8 #define ll long long
     9 #define inf (1e9)+1
    10 #define reg register
    11 #define F(i,a,b) for(register int (i)=(a);(i)<=(b);++(i))
    12 using namespace std;
    13 inline int read();
    14 struct R{
    15     int u,v,w,next;
    16 }r[MAXN<<1];
    17 int n,black_tot;
    18 int fir[MAXN],o;
    19 int siz[MAXN];
    20 ll dp[MAXN][MAXN],temp[2][MAXN];
    21 void add(int u,int v,int w)
    22 {
    23     r[++o].u=u;
    24     r[o].v=v;
    25     r[o].w=w;
    26     r[o].next=fir[u];
    27     fir[u]=o;
    28 }
    29 void dfs(int u,int fa)
    30 {
    31     siz[u]=1;
    32     for(reg int i=fir[u];i;i=r[i].next)
    33     {
    34         int v=r[i].v;
    35         if(v==fa) continue;
    36         dfs(v,u);
    37     }
    38     memset(temp,0,sizeof(temp));
    39     int cur=0;
    40     for(reg int i=fir[u];i;i=r[i].next)
    41     {
    42         int v=r[i].v;
    43         if(v==fa) continue;
    44         for(reg int j=min(siz[u],black_tot);j>=0;--j)
    45         {
    46             for(reg int k=0;k<=min(siz[v],black_tot-j);++k)
    47         //    for(reg int k=min(siz[v],black_tot-j);k>=0;--k)
    48             {
    49                 temp[cur^1][j+k]=max(temp[cur^1][j+k],temp[cur][j]+dp[v][k]+1ll*k*(black_tot-k)*r[i].w+1ll*(siz[v]-k)*(n-siz[v]-black_tot+k)*r[i].w);
    50             }
    51         }
    52         siz[u]+=siz[v];
    53         for(reg int j=0;j<=siz[u];++j) temp[cur][j]=0;
    54         cur^=1;
    55     }
    56     for(reg int j=0;j<=siz[u];++j) dp[u][j]=temp[cur][j];
    57 }
    58 int main()
    59 {
    60 //    freopen("data.in","r",stdin);
    61 //    freopen("data.out","w",stdout);
    62     n=read(); black_tot=read();
    63     int a,b;
    64     ll t;
    65     F(i,1,n-1)
    66     {
    67         a=read(); b=read(); scanf("%lld",&t);
    68         add(a,b,t); add(b,a,t);
    69     }
    70     dfs(1,0);
    71     printf("%lld
    ",dp[1][black_tot]);
    72     return 0;
    73 }
    74 inline int read()
    75 {
    76     int x=0;
    77     char tc=getchar();
    78     while(tc<'0'||tc>'9') tc=getchar();
    79     while(tc>='0'&&tc<='9') x=x*10+tc-48,tc=getchar();
    80     return x;
    81 }
    temp

    可怜与超市

    搜题解有惊喜(滑稽)

    优惠券之间有限制,$1<xi<i$?联想下随机树的生成方法(蓝皮书后),不难发现限制关系是树形的。

    我们把$xi->i$连边建树,那么对于一个节点v能使用优惠券,当且仅当它的父亲节点u能使用优惠券且购买。

    这题稍有一点反套路在于b很大,不能做常规的背包dp定义第二维为背包容量。

    那么我们无妨改变dp数组含义定义$dp[u][j][0/1]$为子树u中购买j商品用(1)或不用(0)卷的最小花费,显然当值大于b时不再能有贡献。

    然后就可以写出dp式子:

    dp[u][j][0]=min(dp[u][j-k][0]+dp[v][k][0])

    dp[u][j][1]=min(dp[u][j-k][1]+min(dp[v][k][0]+dp[v][k][1]))  这式子一开始我还以为是错的e

    然后你就会T70。这就是上文所述的第二种转移。

    我们把它转化为推的形式,且用temp数组维护。

    temp[cur^1][0][j+k]=min(temp[cur^1][0][j+k],min(temp[cur][0][j]+dp[v][k][0],temp[cur][0][j+k]));

    if(j)temp[cur^1][1][j+k]=min(temp[cur^1][1][j+k],min(temp[cur][1][j]+min(dp[v][k][0],dp[v][k][1]),temp[cur][1][j+k]));

    滚动维护前缀子树temp和当前子树v更新。

    siz[u]的更新一定一定要放在后边,不然就会退化成$O(n^3)$

     

  • 相关阅读:
    Python操作Excel
    JMeter生成UUID方式
    JMeter之Beanshell用法
    JMeter后置处理器
    JMeter后置处理器
    Python之正则匹配 re库
    read(),readline() 和 readlines() 比较
    Python的位置参数、默认参数、关键字参数、可变参数之间的区别
    调查管理系统 -(6)自定义Struts2的拦截器&自定义UserAware接口&Action中模型赋值问题&Hibernate懒加载问题
    调查管理系统 -(5)Struts2配置&用户注册/登录/校验
  • 原文地址:https://www.cnblogs.com/hzoi-yzh/p/11201918.html
Copyright © 2011-2022 走看看