zoukankan      html  css  js  c++  java
  • 点分治

     点分治

    点分治是树上分治的一种(树上分治还有边分治),常用于解决和树上路径有关的问题。

    因为树上路径有一条性质:树上的任何路径,要么经过根节点$rt$要么就全部在$rt$的一颗子树上。

    正确性显而易见:树上两点的路径是唯一的,如果两点在$rt$的同一子树上,则路径完全在一颗子树上,如果在$rt$的不同子树,则必然经过$rt$。

    有了这条性质,我们就可以对树上的路径进行分治:先将经过$rt$的路径处理完,此时$rt$的各个子树上的路径就互不影响了,故可以递归分治。

    但怎么使得分治均匀呢?如果随意选择根节点,则如果树退化成链,则递归层数为$N$,且每次操作的节点数目都会非常多。

    所以此时,我们选择树的重心,这样可以使每一次分治后剩下的最大子树的大小下降最快(重心的最大子树最小),每一次分治我们都找重心,操作本身复杂度为$O(NlogN)$,可以使分治底层执行次数降到$O(NlogN)$(主定理)。通常情况下,节点的子树超过$2$棵,则复杂度往往会低于$O(NlogN)$

    void getrt(int x,int fa){
        siz[x]=1;
        maxs[x]=0;
        for(int i=head[x];i;i=nxt[i]){
            if(fa==to[i]&&vis[to[i]])
                continue;
            getrt(to[i],x);
            siz[x]+=siz[to[i]];
            maxs[x]=max(maxs[x],siz[to[i]]);
        }
        maxs[x]=max(maxs[x],sum-siz[x]);
        if(maxs[x]<maxs[rt])
            rt=x;
    }
    求树的重心代码(顺便更新根节点)

    这有什么用呢?

    举个例子,如果我们要统计一棵树内各个长度的路径的个数,就可以用点分治来做。

    首先,我们将树的重心设为根,求出树上各点到根的距离($O(N)$),并统计每个长度的路径的数量($O(N^2)$),然后枚举每一个子树,递归执行同样的操作,并在同一个数组上统计数量

    算法的核心在于分治:

    就着代码讲:

    void Divid(int x)
    {
        ans+=solve(x,0);  ①统计经过RT的所有路径(不一定合法,下文会重点讲)
        vis[x] = 1;
        for (int i = head[x];i;i = edges[i].net)  枚举所有子树
        {
            edge v = edges[i];
            if(vis[v.to]) continue;  防止遍历到上层
            ans-=solve(v.to,edges[i].cost);  减掉①中统计的不合法路径*
            S = size[v.to]; root = 0;
            find(v.to,x);  找到新根
            Divid(root);  按子树分治
        }
    }

    *上面代码中提到求出的“经过RT的路径”不一定合法,要减掉一部分,为什么呢?在统计经过rt的路径时,我们将所有点都两两匹配了,此时同一棵子树中的点当然不能配对。

    如图,我们将$RT ightarrow B$和$RT ightarrow C$的路径合并了,其中$RT ightarrow A$部分的路径重复了

    怎么去掉重复的部分呢?见代码,我们【将“经过A点的路径”(不一定合法)加上$RT ightarrow A$的长度】这一部分产生的贡献去掉

    如图,过A点的路径为$B ightarrow A  ightarrow C$,统计结果为$A ightarrow B+A ightarrow C+2*(RT ightarrow A)$(因为所有长度都加了$RT ightarrow A$)

    这样,我们就求出了所有经过RT的路径的贡献

    接下来,就是逐步细化实现了。对于大部分点分治的题目,上面分治的代码都是差不多的,题目的差异在于solve函数。

    下面,我以洛谷P3806 【模板】点分治1为例介绍一下实现的过程 题解

    此题的题意很简单,就是询问树上长度为k的路径是否存在。看到询问路径,就很容易想到点分治,故我们可以套上面的模板。

    这里介绍一下solve函数的实现

    首先,我们要求出从当前根节点到子数中所有点的距离,然后将这些距离组合配对,将所有的和统计到答案中即可

    void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N^2)*/
        tp=0;
        dis[x]=len;
        get_dis(x,0,len);
        for(int i=1;i<=tp;i++)
            for(int j=1;j<=tp;j++)
                if(i!=j)
                    ans[st[i]+st[j]]+=w;
    }

    这是求距离的代码

    void get_dis(int x,int fa,int len){
        if(len<=1e7)
            st[++tp]=len;
        for(int i=head[x];i;i=nxt[i]){
            if(to[i]==fa||vis[to[i]])
                continue;
            dis[to[i]]=len+val[i];
            get_dis(to[i],x,len+val[i]);
        }
    }

    于是,我们能得到以下代码,可以AC(是因为数据水)

     1 #include<bits/stdc++.h>
     2 using namespace std;
     3 typedef long long LL;
     4 const int MAXK=2e7,MAXN=1e4+7,MAXM=2e4+7;
     5 inline void Max(int &x,int y){
     6     x=x>y?x:y;
     7 }
     8 int sz,head[MAXN],to[MAXM],nxt[MAXM],val[MAXM];
     9 inline void add(int x,int y,int z){
    10     nxt[++sz]=head[x]; head[x]=sz; to[sz]=y; val[sz]=z;
    11     nxt[++sz]=head[y]; head[y]=sz; to[sz]=x; val[sz]=z;
    12 }
    13 int rt,siz[MAXN],maxson[MAXN],vis[MAXN],S;
    14 void find(int x/*cur vertex*/,int fa/*father*/){/*find root*/
    15     siz[x]=1;
    16     maxson[x]=0;
    17     for(int i=head[x];i;i=nxt[i]){
    18         if(to[i]==fa||vis[to[i]])
    19             continue;
    20         find(to[i],x);
    21         siz[x]+=siz[to[i]];
    22         Max(maxson[x],siz[to[i]]);
    23     }
    24     Max(maxson[x],S-siz[x]);
    25     if(maxson[x]<maxson[rt])
    26         rt=x;
    27 }
    28 int dis[MAXN],st[MAXN],tp;
    29 void get_dis(int x,int fa,int len){
    30     if(len<=1e7)
    31         st[++tp]=len;
    32     for(int i=head[x];i;i=nxt[i]){
    33         if(to[i]==fa||vis[to[i]])
    34             continue;
    35         dis[to[i]]=len+val[i];
    36         get_dis(to[i],x,len+val[i]);
    37     }
    38 }
    39 int ans[MAXK];
    40 void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N^2)*/
    41     tp=0;
    42     dis[x]=len;
    43     get_dis(x,0,len);
    44     for(int i=1;i<=tp;i++)
    45         for(int j=1;j<=tp;j++)
    46             if(i!=j)
    47                 ans[st[i]+st[j]]+=w;
    48 }
    49 int N,Q,K;
    50 void divide(int x){
    51     solve(x,0,1);
    52     vis[x]=1;
    53     for(int i=head[x];i;i=nxt[i]){
    54         if(vis[to[i]])
    55             continue;
    56         solve(to[i],val[i],-1);
    57         S=siz[x];
    58         rt=0;
    59         maxson[0]=N;
    60         find(to[i],x);
    61         divide(rt);
    62     }
    63 }
    64 int main(){
    65     scanf("%d%d",&N,&Q);
    66     for(int i=1;i<N;i++){
    67         int ii,jj,kk;
    68         scanf("%d%d%d",&ii,&jj,&kk);
    69         add(ii,jj,kk);
    70     }
    71     S=N;
    72     maxson[0]=N;
    73     rt=0;
    74     find(1,0);
    75     divide(rt);
    76     while(Q--){
    77         scanf("%d",&K);
    78         puts(ans[K]?"AYE":"NAY");
    79     }
    80     return 0;
    81 }
    View Code

    分析代码可以发现,实际上代码的时间复杂度为$Theta(N^2 log N)$,在较强的数据中是会TLE的,于是我们要优化

    我们发现我们对于所有可能的询问,都统计了答案,这其实是一种冗余。题目中的m非常小,我们其实可以根据询问来统计答案,效率会提高两个数量级

    于是,我们很容易想到将询问离线,然后每次在表内统计一份结果,并且枚举所有的询问,在表内查询之前是否得到过$答案-当前结果$的值

    新的solve函数

    void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N*M)*/
        ++timeclock;
        tp=0;
        dis[x]=len;
        get_dis(x,0,len);
        for(int i=1;i<=tp;i++)
            for(int j=1;j<=Q;j++){
                int ii=qry[j]-st[i];
                if(ii<0||date[ii]!=timeclock||(b[ii]==1&&ii==st[i]))
                    continue;
                ans[j]+=w;
            }
    }

    最后的代码也就很简单了,用时为原来的几十分之一

     1 #include<bits/stdc++.h>
     2 using namespace std;
     3 typedef long long LL;
     4 const int MAXK=2e7,MAXN=1e4+7,MAXM=2e4+7,MAXQ=1e2+7;
     5 inline void Max(int &x,int y){
     6     x=x>y?x:y;
     7 }
     8 int sz,head[MAXN],to[MAXM],nxt[MAXM],val[MAXM];
     9 inline void add(int x,int y,int z){
    10     nxt[++sz]=head[x]; head[x]=sz; to[sz]=y; val[sz]=z;
    11     nxt[++sz]=head[y]; head[y]=sz; to[sz]=x; val[sz]=z;
    12 }
    13 int rt,siz[MAXN],maxson[MAXN],vis[MAXN],S;
    14 void find(int x/*cur vertex*/,int fa/*father*/){/*find root*/
    15     siz[x]=1;
    16     maxson[x]=0;
    17     for(int i=head[x];i;i=nxt[i]){
    18         if(to[i]==fa||vis[to[i]])
    19             continue;
    20         find(to[i],x);
    21         siz[x]+=siz[to[i]];
    22         Max(maxson[x],siz[to[i]]);
    23     }
    24     Max(maxson[x],S-siz[x]);
    25     if(maxson[x]<maxson[rt])
    26         rt=x;
    27 }
    28 int dis[MAXN],st[MAXN],tp;
    29 int qry[MAXQ],ans[MAXQ],date[MAXK],b[MAXK],timeclock;
    30 int N,Q,K;
    31 void get_dis(int x,int fa,int len){
    32     if(len<=1e7){
    33         st[++tp]=len;
    34         if(date[len]==timeclock)
    35             b[len]++;
    36         else{
    37             b[len]=1;
    38             date[len]=timeclock;
    39         }
    40     }
    41     for(int i=head[x];i;i=nxt[i]){
    42         if(to[i]==fa||vis[to[i]])
    43             continue;
    44         dis[to[i]]=len+val[i];
    45         get_dis(to[i],x,len+val[i]);
    46     }
    47 }
    48 void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N*M)*/
    49     ++timeclock;
    50     tp=0;
    51     dis[x]=len;
    52     get_dis(x,0,len);
    53     for(int i=1;i<=tp;i++)
    54         for(int j=1;j<=Q;j++){
    55             int ii=qry[j]-st[i];
    56             if(ii<0||date[ii]!=timeclock||(b[ii]==1&&ii==st[i]))
    57                 continue;
    58             ans[j]+=w;
    59         }
    60 }
    61 void divide(int x){
    62     solve(x,0,1);
    63     vis[x]=1;
    64     for(int i=head[x];i;i=nxt[i]){
    65         if(vis[to[i]])
    66             continue;
    67         solve(to[i],val[i],-1);
    68         S=siz[x];
    69         rt=0;
    70         maxson[0]=N;
    71         find(to[i],x);
    72         divide(rt);
    73     }
    74 }
    75 int main(){
    76     scanf("%d%d",&N,&Q);
    77     for(int i=1;i<N;i++){
    78         int ii,jj,kk;
    79         scanf("%d%d%d",&ii,&jj,&kk);
    80         add(ii,jj,kk);
    81     }
    82     for(int i=1;i<=Q;i++){
    83         scanf("%d",qry+i);
    84     }
    85     S=N;
    86     maxson[0]=N;
    87     rt=0;
    88     find(1,0);
    89     divide(rt);
    90     for(int i=1;i<=Q;i++)
    91         puts(ans[i]?"AYE":"NAY");
    92     return 0;
    93 }
    View Code
  • 相关阅读:
    cookie 保存信息 例子
    win7 远程桌面 不用域账户登录
    jfreechart demo 源代码 下载
    PostgreSQL学习手册(数据表)
    apache tomcat 伪静态 struts2 伪静态
    List methods
    DX 骨骼动画
    如何成为一名优秀的程序员?
    程序员改编游戏向女友求婚
    引用 病毒是怎么命名的?教你认识病毒命名规则
  • 原文地址:https://www.cnblogs.com/guoshaoyang/p/10994997.html
Copyright © 2011-2022 走看看