zoukankan      html  css  js  c++  java
  • [POJ1741]Tree

    题目大意:
      给你一棵带权树,求出树中距离$leq k$的点对个数。
    思路:
      运用树上分治的思想,每次找出树的重心,考虑以下三种情况:
        1.两个结点在不同子树内,且距离$leq k$,则算入答案中;
        2.两个结点距离$leq k$,但属于同一棵子树中,需要被算入答案中,但考虑到以后会被子树的重心重新计算,故在这里忽略;
        3.两个结点距离$>k$,显然需要忽略。
      简而言之,就是每次统计最短路径经过重心的、距离$leq k$的点对个数。
      我们可以每次先DP求出这棵子树的重心,再以这个重心为根,遍历整棵子树,将得到的离重心的距离存入一个数组中,然后枚举子树中的每个点对,将距离和$leq k$的计入答案。
      考虑到要去除同一棵子树中的点对,我们需要在记录距离的同时,要记录每个结点所属的子树的编号,然后枚举的时候判断两个点是否属于统一子树即可。
      但是这样还是会TLE。
      考虑每次使用两个数组$a$和$b$,$a$存储当前子树各个结点的距离,$b$存储之前所有结点的距离,统计答案时只需要对$a$排序,然后枚举$b$中每个元素$b_i$,在$a$中二分查找小于等于$k-b_i$的元素个数即可。
      树上分治是$O(log n)$的,排序是$O(nlog n)$的,因此总的时间复杂度是$O(nlog^2 n)$。

      1 #include<cstdio>
      2 #include<cctype>
      3 #include<vector>
      4 #include<cstring>
      5 #include<algorithm>
      6 inline int getint() {
      7     char ch;
      8     while(!isdigit(ch=getchar()));
      9     int x=ch^'0';
     10     while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
     11     return x;
     12 }
     13 const int inf=0x7fffffff;
     14 const int V=10001;
     15 struct Edge {
     16     int to,w;
     17     Edge(const int to,const int w) {
     18         this->to=to;
     19         this->w=w;
     20     }
     21 };
     22 std::vector<Edge> e[V];
     23 inline void add_edge(const int u,const int v,const int w) {
     24     e[u].push_back(Edge(v,w));
     25 }
     26 int size[V];
     27 bool vis[V];
     28 int min_subtree_size,centroid,tree_size;
     29 void get_centroid(const int x,const int par) {
     30     size[x]=1;
     31     int max=0;
     32     for(unsigned i=0;i<e[x].size();i++) {
     33         int &y=e[x][i].to;
     34         if(vis[y]||y==par) continue;
     35         get_centroid(y,x);
     36         size[x]+=size[y];
     37         max=std::max(max,size[y]);
     38     }
     39     max=std::max(max,tree_size-size[x]);
     40     if(max<min_subtree_size) {
     41         min_subtree_size=max;
     42         centroid=x;
     43     }
     44 }
     45 int k;
     46 std::vector<int> dis,tdis;
     47 void get_dist(const int x,const int par,const int d) {
     48     if(d<=k) tdis.push_back(d);
     49     size[x]=1;
     50     for(unsigned i=0;i<e[x].size();i++) {
     51         int &y=e[x][i].to;
     52         if(vis[y]||y==par) continue;
     53         get_dist(y,x,d+e[x][i].w);
     54         size[x]+=size[y];
     55     }
     56 }
     57 int ans=0;
     58 inline void solve(const int x,const int sz) {
     59     tree_size=sz;
     60     min_subtree_size=inf;
     61     get_centroid(x,0);
     62     vis[centroid]=true;
     63     /*dis.clear();
     64     dis.push_back(Vertex(0,centroid));
     65     for(unsigned i=0;i<e[centroid].size();i++) {
     66         int &y=e[centroid][i].to;
     67         if(vis[y]) continue;
     68         get_dist(y,centroid,e[centroid][i].w,y);
     69     }
     70     std::sort(dis.begin(),dis.end());
     71     for(unsigned i=0;i<dis.size();i++) {
     72         for(unsigned j=i+1;j<dis.size();j++) {
     73             if(dis[i].d+dis[j].d>k) break;
     74             if(dis[i].root!=dis[j].root) ans++;
     75         }
     76     }*/
     77     dis.clear();
     78     dis.push_back(0);
     79     for(unsigned i=0;i<e[centroid].size();i++) {
     80         int &y=e[centroid][i].to;
     81         if(vis[y]) continue;
     82         tdis.clear();
     83         get_dist(y,centroid,e[centroid][i].w);
     84         std::sort(tdis.begin(),tdis.end());
     85         for(unsigned i=0;i<dis.size();i++) {
     86             ans+=std::upper_bound(tdis.begin(),tdis.end(),k-dis[i])-tdis.begin();
     87         }
     88         dis.insert(dis.end(),tdis.begin(),tdis.end());
     89     }
     90     int cur=centroid;
     91     for(unsigned i=0;i<e[cur].size();i++) {
     92         int &y=e[cur][i].to;
     93         if(vis[y]) continue;
     94         solve(y,size[y]);
     95     }
     96 }
     97 inline void init() {
     98     ans=0;
     99     memset(vis,0,sizeof vis);
    100     for(int i=0;i<=V;i++) e[i].clear();
    101 }
    102 int main() {
    103     for(;;) {
    104         int n=getint();
    105         k=getint();
    106         if(!n&&!k) return 0;
    107         init();
    108         for(int i=1;i<n;i++) {
    109             int u=getint(),v=getint(),w=getint();
    110             add_edge(u,v,w);
    111             add_edge(v,u,w);
    112         }
    113         solve(1,n); 
    114         printf("%d
    ",ans);
    115     }
    116 }
  • 相关阅读:
    php实现rpc简单的方法
    统计代码量
    laravel的速查表
    header的参数不能带下划线
    PHP简单实现单点登录功能示例
    phpStorm函数注释的设置
    数据结构基础
    laravel生命周期和核心思想
    深入理解php底层:php生命周期
    Jmeter:实例(性能测试目标)
  • 原文地址:https://www.cnblogs.com/skylee03/p/7470625.html
Copyright © 2011-2022 走看看