zoukankan      html  css  js  c++  java
  • SPOJ1825 Free tour II 树分治

    题意:带边权树上有白点和黑点,问你最多不经过k个黑点使得路径最长(注意,路径有负数)

    解题思路:基于树的点分治。数的路径问题,具体看09QZC论文,特别注意 当根为黑时的情况

    解题代码:

      1 // File Name: spoj1825.cpp
      2 // Author: darkdream
      3 // Created Time: 2014年10月05日 星期日 20时20分33秒
      4 
      5 #include<vector>
      6 #include<list>
      7 #include<map>
      8 #include<set>
      9 #include<deque>
     10 #include<stack>
     11 #include<bitset>
     12 #include<algorithm>
     13 #include<functional>
     14 #include<numeric>
     15 #include<utility>
     16 #include<sstream>
     17 #include<iostream>
     18 #include<iomanip>
     19 #include<cstdio>
     20 #include<cmath>
     21 #include<cstdlib>
     22 #include<cstring>
     23 #include<ctime>
     24 #define LL long long 
     25 #define maxn 200015
     26 using namespace std;
     27 struct node{
     28     int ne;
     29     int w;
     30     node(int _ne,int _w)
     31     {
     32         ne = _ne ; 
     33         w = _w;
     34     }
     35 };
     36 int n ,K, m ; 
     37 int col[maxn];
     38 int vis[maxn];
     39 vector <node> mp[maxn];
     40 int sum[maxn];
     41 int mx[maxn];
     42 int cnum[maxn];
     43 void getsize(int k,int la)
     44 {
     45     sum[k] = 1; 
     46     mx[k] = 0;
     47     int num = mp[k].size();
     48     int tt = 0 ;
     49     for(int i = 0 ;i < num;i ++)
     50     {
     51        if(!vis[mp[k][i].ne] && mp[k][i].ne != la)
     52        {
     53            getsize(mp[k][i].ne,k);
     54            mx[k] = max(sum[mp[k][i].ne],mx[k]);
     55            sum[k] += sum[mp[k][i].ne];
     56        }
     57     }
     58 }
     59 int root;
     60 int mxv; 
     61 int getroot(int k,int la ,int tans)
     62 {
     63      int tt = max(tans - sum[k],mx[k]);
     64      if(tt < mxv)
     65      {
     66         mxv = tt;
     67         root = k ; 
     68      }
     69      int num = mp[k].size();
     70      for(int i = 0 ;i < num ;i ++)
     71      {
     72        if(!vis[mp[k][i].ne] && mp[k][i].ne != la)
     73        {
     74            getroot(mp[k][i].ne,k,tans);
     75        }
     76      }
     77 }
     78 LL ans = 0 ;
     79 LL dp[maxn];
     80 LL tdp[maxn];
     81 bool cmp(node a, node b)
     82 {
     83     return cnum[a.ne] < cnum[b.ne];
     84 }
     85 void getdep(int k ,int la,int tc,LL dep)
     86 { 
     87     int st = (col[k] == 1?1:0) ;
     88     tdp[tc+st] = max(tdp[tc+st],dep); //这个点是G点的时候
     89     int num = mp[k].size();
     90     for(int i = 0 ;i < num ;i ++)
     91     {
     92         if(!vis[mp[k][i].ne] && mp[k][i].ne != la )
     93         {
     94             getdep(mp[k][i].ne,k,tc + st,dep + mp[k][i].w);
     95         }
     96     }
     97 }
     98 void getcnum(int k ,int la)
     99 {
    100     if(col[k])
    101         cnum[k] = 1; 
    102     else cnum[k] = 0 ; 
    103     int tt = 0 ;
    104     int num = mp[k].size();
    105     for(int i = 0 ;i < num;i ++)
    106     {
    107        if(!vis[mp[k][i].ne] && mp[k][i].ne != la)
    108        {
    109            getcnum(mp[k][i].ne,k);
    110           if(cnum[mp[k][i].ne] > tt)
    111               tt = cnum[mp[k][i].ne];
    112        }
    113     }
    114     cnum[k] += tt;
    115 }
    116 void solve(int k)
    117 {
    118     getsize(k,0);
    119     mxv = 1e9;
    120     getroot(k,0,sum[k]);
    121     k = root;
    122     
    123     getcnum(k,0);    
    124     //printf("*****%d %d
    ",k,cnum[k]);    
    125     int num = mp[k].size();
    126     memset(dp,0,(cnum[k]+3)*sizeof(LL));
    127     int tk ;
    128     int st = 0 ;
    129     if(col[k])
    130     {
    131         tk = K + 1;
    132         st = 1;
    133     }
    134     else tk = K ;
    135     int la =0 ; 
    136     //int size = min(cnum[k],K);
    137     sort(mp[k].begin(),mp[k].end(),cmp);
    138     for(int i = 0 ;i < num ;i ++)
    139     {
    140         if(vis[mp[k][i].ne])
    141             continue;
    142         
    143         memset(tdp,0,(cnum[mp[k][i].ne]+3)*sizeof(tdp[0]));
    144         if(col[k])
    145            getdep(mp[k][i].ne,k,1,mp[k][i].w);        
    146         else 
    147            getdep(mp[k][i].ne,k,0,mp[k][i].w);        
    148     //    printf("**********%d
    ",tk);
    149         
    150         
    151         int tt = min(cnum[mp[k][i].ne]+st,K);    
    152 //        printf("%d %d
    ",cnum[mp[k][i].ne]+st,K);
    153         for(int j = st ;j <= tt;j ++)
    154         {
    155            if(tk - j <= la)
    156            {
    157             if(tdp[j] + dp[tk-j]> ans)
    158             {
    159                 ans = tdp[j] + dp[tk-j];
    160             }
    161            }else{
    162              if(tdp[j] + dp[la]> ans)
    163              {
    164                 ans = tdp[j] + dp[la];
    165              }
    166            }
    167         } 
    168         dp[0] = max(dp[0],tdp[0]);
    169         //printf("%d %d
    ",n,cnum[mp[k][i].ne]);
    170         /*if(tdp[tt+st+1] != 0)
    171         {
    172           printf("&&&&&&&&&&&&&&
    ");
    173         }*/
    174         for(int j = 1 ;j <= tt+st; j ++)
    175         {
    176             dp[j] = max(dp[j],tdp[j]);
    177             dp[j] = max(dp[j],dp[j-1]);
    178         }
    179    //     for(int j = 0;j <= K;j ++)
    180     //        printf("%lld ",dp[j]);
    181     //    puts("");
    182         la = tt + st;
    183     }
    184     //puts("**********8");
    185     vis[k] = 1;
    186     for(int i = 0;i < num;i ++)
    187     {
    188         if(!vis[mp[k][i].ne])
    189             solve(mp[k][i].ne);
    190     }
    191     return; 
    192 }
    193 int main(){
    194    //freopen("out","r",stdin);    
    195    //freopen("output.txt","w",stdin);
    196    while(scanf("%d %d %d",&n,&K,&m) != EOF){
    197     int temp ; 
    198     memset(vis,0,sizeof(vis));
    199     memset(col,0,sizeof(col));
    200     for(int i = 1;i <= n;i ++)
    201         mp[i].clear();
    202     for(int i = 1;i <= m;i ++)
    203     {
    204         scanf("%d",&temp);
    205         col[temp]  = 1;  
    206     }
    207     for(int i = 1;i <= n - 1;i ++)
    208     {
    209         int a, b , w; 
    210         scanf("%d %d %d",&a,&b,&w);
    211         mp[a].push_back(node(b,w));
    212         mp[b].push_back(node(a,w));
    213     }
    214     ans = 0; 
    215     solve(1);
    216     printf("%lld
    ",ans);
    217    }
    218     return 0;
    219 }
    View Code
    没有梦想,何谈远方
  • 相关阅读:
    Springboot使用PlatformTransactionManager接口的事务处理
    js 正则替换html标签
    【转】mysql查询时,查询结果按where in数组排序
    js输出字幕数字a-zA-Z0-9
    tcpdump使用教程
    rsync安装使用教程
    vim配置修改教程
    XD刷机报错bad CRC
    使用docker搭建seafile服务器
    案例:使用sqlplus登录报ORA-12547错误
  • 原文地址:https://www.cnblogs.com/zyue/p/4013037.html
Copyright © 2011-2022 走看看