zoukankan      html  css  js  c++  java
  • hdu5293 lca+dp+树状数组+时间戳

    题意是给了 n 个点的树,会有m条链条 链接两个点,计算出他们没有公共点的最大价值,  公共点时这样计算的只要在他们 lca 这条链上有公共点的就说明他们相交

    dp[i]为这个点包含的子树所能得到的最大价值 

    sum[i]表示这个点没有选择经过i这个点链条的总价值

    两种选择 

    这个点没有被选择 

             dp[i]=sum[i]=sigma(dp[k])k为i的子树

    选择了某个链 

            假设这条链 为(tyuijk)

           那么dp[i]=(sum[i]-dp[u]-dp[j])+(sum[j]-dp[k])+dp[k] +(sum[u]-dp[y])+(sum[y]-dp[t])+sum[t];

          整理后发现 dp[i]=sum[i] +(sum[j]-dp[j])+(sum[k]-dp[k])+(sum[u]-dp[u])+(sum[y]-dp[y])+(sum[t]-dp[t]);

    使用lca计算出每条链的最近公共祖先,在这个最近公共祖先上判断是否使用这条链,还有我们可以使用时间戳加树状数组来求得sum和dp

    #include <iostream>
    #include <algorithm>
    #include <string.h>
    #include <cstdio>
    #include <vector>
    using namespace std;
    const int maxn=100000+10;
    int to[maxn*2],nx[maxn*2],H[maxn*2],numofedg,timoflook;
    int fa[maxn][20],first[maxn],last[maxn],depth[maxn];
    void addedg(int u, int v)
    {
         numofedg++; to[numofedg]=v; nx[numofedg]=H[u]; H[u]=numofedg;
         numofedg++; to[numofedg]=u; nx[numofedg]=H[v]; H[v]=numofedg;
    }
    void dfs(int cur, int per, int dep)
    {
        first[cur]=++timoflook;
        depth[cur]=dep;
        fa[cur][0]=per;
        for(int i=1; i<20; i++)
        {
            fa[cur][i]=fa[ fa[cur][i-1] ][ i-1 ];
        }
        for(int i=H[cur]; i; i=nx[i])
            {
                if(to[i]==per)continue;
                dfs(to[i],cur,dep+1);
            }
        last[cur]=++timoflook;
    }
    int getlca(int u,int v)
    {
         if(depth[u]<depth[v])swap(u,v);
         for(int i=19; i>=0; i--)
            {
                 if(depth[fa[u][i]]>=depth[v])
                    u=fa[u][i];
                 if(u==v)return u;
            }
         for(int i=19; i>=0; i--)
            {
                 if(fa[u][i]!=fa[v][i])
                 {
                     u=fa[u][i];
                     v=fa[v][i];
                 }
            }
            return fa[u][0];
    }
    struct Edg
    {
      int u,v,lca,val;
    }P[maxn];
    vector<int>E[maxn];
    int dp[maxn],sum[maxn],CS[maxn*3],CD[maxn*3];
    int lowbit(int x)
    {
         return x&-x;
    }
    void add(int x, int d, int *C)
    {
          while(x<=timoflook)
            {
                 C[x]+=d;
                 x+=lowbit(x);
            }
    }
    int getsum(int x, int *C)
    {
         int ret=0;
          while(x>0)
            {
                ret+=C[x];
                x-=lowbit(x);
            }
            return ret;
    }
    void solve(int cur, int per)
    {
         dp[cur]=sum[cur]=0;
         for(int i=H[cur]; i; i=nx[i])
            {
                if(to[i]==per)continue;
                solve(to[i],cur);
                sum[cur]+=dp[to[i]];
            }
         dp[cur]=sum[cur];
         for(int i=0; i<E[cur].size(); i++)
            {
                  int id=E[cur][i];
                  int u=P[id].u;
                  int v=P[id].v;
                  int t1=getsum(first[u],CS);
                  int t2=getsum(first[v],CS);
                  int t3=getsum(first[u],CD);
                  int t4=getsum(first[v],CD);
                  int tmp=t1+t2-t3-t4;
                  dp[cur]=max(dp[cur],tmp+P[id].val+sum[cur]);
            }
         add(first[cur],sum[cur],CS);
         add(last[cur],-sum[cur],CS);
         add(first[cur],dp[cur],CD);
         add(last[cur],-dp[cur],CD);
    
    }
    int main()
    {
        int cas;
        scanf("%d",&cas);
        for(int cc=1; cc<=cas; cc++)
            {
                  int n,m;
                  timoflook=numofedg=0;
                  scanf("%d%d",&n,&m);
                  for(int i=0; i<=n; i++)
                    {
                        CS[i*2]=CS[i*2+1]=CD[i*2]=CD[i*2+1]=0;
                        H[i]=0;E[i].clear();
    
                    }
    
                  for(int i=1; i<n; i++)
                    {
                        int u,v;
                        scanf("%d%d",&u,&v);
                        addedg(u,v);
                    }
                    fa[1][0]=1;
                    dfs(1,1,0);
                    for(int i=0; i<m; i++)
                        {
                               scanf("%d%d%d",&P[i].u,&P[i].v,&P[i].val);
                               P[i].lca=getlca(P[i].u,P[i].v);
                               E[P[i].lca].push_back(i);
                        }
                    solve(1,-1);
                    printf("%d
    ",dp[1]);
            }
        return 0;
    }
    View Code
  • 相关阅读:
    python字符串连接方式(转)
    Python顺序与range和random
    将EXCEL中的列拼接成SQL insert插入语句
    Python OS模块
    Python3.5连接Mysql
    Mysql查看连接端口及版本
    Mysqldb连接Mysql数据库(转)
    Python 文件I/O (转)
    Python 日期和时间(转)
    Python序列的方法(转)
  • 原文地址:https://www.cnblogs.com/Opaser/p/4788669.html
Copyright © 2011-2022 走看看