zoukankan      html  css  js  c++  java
  • 树形DP

    树形DP,顾名思义就是在树上进行dp,dp的时候要充分利用树的性质,注意考虑所有能转移的节点

    例:树的直径

    给你一颗点数为n的树,让你求这棵树的直径是多少,也就是求最长的两个点之间的距离。

    N<=100000

    in

    7 
    1 6 13 
    6 3 9 
    3 5 7 
    4 1 3 
    2 4 20 
    4 7 2 

    out

    52

    两种做法,复杂度都是O(N)

    1.dfs(或者bfs)

    先从任意一个点跑一遍dfs,然后遍历每一个点找到距离这个点最远的,然后再从这个点开始跑一边dfs,再找到距离这个点最远的,那么这两个点之间的距离就是树的直径

    不能跑负边权的树

    代码:

    #include<bits/stdc++.h>
    using namespace std;
    const int N=2333333,inf=0x3f3f3f3f;
    int n;
    
    int dis[N];
    int head[N],cnt;
    struct edge
    {
        int to,dis,nxt;
    }edg[N<<1];
    
    inline void add(int u,int v,int w)
    {
        edg[++cnt].dis=w;
        edg[cnt].to=v;
        edg[cnt].nxt=head[u];
        head[u]=cnt;
    }
    
    void dfs(int x,int fa)
    {
        for(int i=head[x];i;i=edg[i].nxt)
        {
            int v=edg[i].to;
            if(v==fa) continue;
            dis[v]=dis[x]+edg[i].dis;
            dfs(v,x);
        }
    }
    
    int main()
    {
        int n;
        cin>>n;
        for(int i=1;i<=n-1;i++)
        {
            int u,v,w;
            scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);
            add(v,u,w);
        }
        memset(dis,0x3f,sizeof(dis));
        dis[1]=0;
        dfs(1,0);
        int s=0,maxn=-1;
        for(int i=1;i<=n;i++)
        {
            if(dis[i]>maxn&&dis[i]!=inf)
            {
                maxn=dis[i];
                s=i;
            }
        }
        memset(dis,0x3f,sizeof(dis));
        dis[s]=0;
        dfs(s,0);
        for(int i=1;i<=n;i++)
        {
            if(dis[i]>maxn&&dis[i]!=inf)
            {
                maxn=dis[i];
                s=i;
            }
        }
        cout<<maxn;
    }

    2.树形dp

    设f[i]表示i这个点到子树里面的最远点是多远的,然后对于每一个点u求出以这个点为根的最远路径距离,直接找{f[son_i]+edge_i}的前两大值加起来即可。然后再在每一个点构成的答案里面取最大值就是全局的最优值了。

    代码:

    #include<bits/stdc++.h>
    using namespace std;
    const int N=2333333,inf=0x3f3f3f3f;
    int n;
    
    int f[N];
    int head[N],cnt;
    int ans;
    struct edge
    {
        int to,dis,nxt;
    }edg[N<<1];
    
    inline void add(int u,int v,int w)
    {
        edg[++cnt].dis=w;
        edg[cnt].to=v;
        edg[cnt].nxt=head[u];
        head[u]=cnt;
    }
    
    void dp(int x,int fa)
    {
        int maxn=0;
        for(int i=head[x];i;i=edg[i].nxt)
        {
            int v=edg[i].to;
            if(v==fa) continue;
            dp(v,x);
            ans=max(ans,maxn+edg[i].dis+f[v]);
            f[x]=max(f[x],f[v]+edg[i].dis);
            maxn=max(maxn,f[v]+edg[i].dis);
        }
    }
    
    int main()
    {
        int n;
        cin>>n;
        for(int i=1;i<=n-1;i++)
        {
            int u,v,w;
            scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);
            add(v,u,w);
        }
        dp(1,0);
        cout<<ans;
    }

    例:没有上司的舞会

                        

                         

    这道题算是树形dp中比较经典而且比较简单的题了

    首先注意到对于每一个节点,可以选或者不选

    如果选的话,他的所有儿子都不能选

    如果不选的话,他的所有儿子可以选也可以不选,显然要在选或者不选里面取最大的

    设f[i][0/1]表示当前在第i个点,这个点 不选/选 的最大价值

    于是状态转移方程

    f[i][1]=∑f[son[i]][0]

    f[i][0]=∑max(f[son[i]][0],f[son[i]][1]);

    代码:

    #include<bits/stdc++.h>
    using namespace std;
    
    typedef long long ll;
    
    inline ll read()
    {
        ll ans=0;
        char last=' ',ch=getchar();
        while(ch<'0'||ch>'9') last=ch,ch=getchar();
        while(ch>='0'&&ch<='9') ans=ans*10+ch-'0',ch=getchar();
        if(last=='-') ans=-ans;
        return ans;
    }
    
    struct edge
    {
        int to,nxt;
    }edg[6050];
    int head[6050],cnt,in[6050];
    
    int n,r[6050],f[6050][3];
    int root;
    
    inline void add(int u,int v)
    {
        edg[++cnt].to=v;
        edg[cnt].nxt=head[u];
        head[u]=cnt;
    }
    
    void dp(int now)
    {
        f[now][1]=r[now];
        for(int i=head[now];i;i=edg[i].nxt)
        {
            int v=edg[i].to;
            dp(v);
            f[now][1]+=f[v][0];
            f[now][0]+=max(f[v][0],f[v][1]);
        }
    }
    
    int main()
    {
        n=read();
        for(int i=1;i<=n;i++) r[i]=read();
        for(int i=1;i<=n-1;i++)
        {
            int v=read(),u=read();in[v]++;
            add(u,v);
        }
        for(int i=1;i<=n;i++)
        {
            if(in[i]==0) root=i;
        }
        dp(root);
        cout<<max(f[root][0],f[root][1]);
    }
  • 相关阅读:
    echarts中3D地球模型
    面试题68
    Java正确创建对象数组
    8.Arrays类和比较器
    7.Base64类和UUID类
    6.大数字处理类
    3.JVM重要知识点
    2.JVM基础知识点
    1.JVM入门知识
    6.适配器模式
  • 原文地址:https://www.cnblogs.com/lcezych/p/11482580.html
Copyright © 2011-2022 走看看