zoukankan      html  css  js  c++  java
  • 点分治

    点分治

    蒟蒻迟迟没法开点分治的坑,主要是因为最近找了几道点分治题,全都可以用长链剖分写,由于博主又懒又菜,所以点分治没有得到练习的机会。终于,最近安排了专题分享,最菜的chd捡了一个分治专题,不得不学学这些东西了。

    举个简单的例子,我们对树上的路径有一些询问。我们考虑对于树上的任意一个点,它的子树中的路径可以分为两类:一种是经过根节点的路径,一种是不经过根节点的路径。由于不经过根节点的路径可以递归处理,这样分治的思路就很明显了。对于每一个点的子树,可以选取一个点作为根节点,递归到当前层时,只处理过根节点的路径,然后递归处理我们选定的根节点的儿子,就可以考虑到所有路径。

    但是需要考虑一种极端情况:假如题目给出了一条长度为(n)的链,而我们第(i)层递归选定第(i)个点作为根节点往下递归,那就要递归(n)层,总的复杂度无法保证。这时我们就要考虑选取根节点的技巧了,每次递归选取树的重心为根节点,由于重心有一个性质:删除重心后得到的森林中的每一棵树的大小都不超过原树的一半,这样就能保证递归层数是(logn)层。

    那么怎样求重心呢?考虑树的重心定义为一棵树中删除它后能使得到的森林中最大的树最小的点。一遍dfs就能求出,具体的实现下面会有。

    接下来看题吧。

    P4178 Tree

    做这道题之前可以先看看它的弱化版CF161D Distance in Tree。那道是求树上长度等于(k)的路径条数,而这道是求长度(le k)的路径条数。那道题是直接点分治统计,现在这道题就是加一个树状数组统计前缀和。必要的分析上面都已经说过,具体看代码。

    #include<cstdio>
    #include<cctype>
    #define R register
    #define I inline
    using namespace std;
    const int S=40003,N=80003;
    char buf[1000000],*p1,*p2;
    I char gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,S,stdin),p1==p2)?EOF:*p1++;}
    I int rd(){
    	R int f=0; R char c=gc();
    	while(c<48||c>57) c=gc();
    	while(c>47&&c<58) f=f*10+(c^48),c=gc();
    	return f;
    }
    int h[S],s[N],g[N],w[N],t[S],v[S],a[S],p[S],q[S],c,r,u,n,m,o;
    I int max(int x,int y){return x>y?x:y;}
    I void add(int x,int y,int z){s[++c]=h[x],h[x]=c,g[c]=y,w[c]=z;}
    I void mdf(int x,int v){for(;x<=m;x+=x&-x) a[x]+=v;}
    I int qry(int x){R int r=0; for(;x;x^=x&-x) r+=a[x]; return r;}
    void gts(int x,int f){t[x]=1;//getsize
    	for(R int i=h[x],y;i;i=s[i]) if((y=g[i])^f&&!v[y]) gts(y,x),t[x]+=t[y];
    }
    void gtr(int x,int f,int a){R int m=0,i,y;//getroot 重心
    	for(i=h[x];i;i=s[i]) if((y=g[i])^f&&!v[y]) m=max(m,t[y]),gtr(y,x,a);
    	m=max(m,a-t[x]); if(m<u) u=m,r=x;
    }
    void dfs(int x,int f,int d){if(d>m) return ; p[++p[0]]=q[++q[0]]=d;//统计路径
    	for(R int i=h[x],y;i;i=s[i]) if((y=g[i])^f&&!v[y]) dfs(y,x,d+w[i]);
    }
    void dac(int x){//divide and conquer
    	q[0]=0,u=n,gts(x,0),gtr(x,0,t[x]),v[r]=1; R int i,j,y; 
    	for(i=h[r];i;i=s[i])
    		if(!v[y=g[i]]){p[0]=0,dfs(y,r,w[i]);
    			for(j=p[0];j;--j) if(p[j]<=m) o+=qry(m-p[j]);
    			for(j=p[0];j;--j) if(p[j]<=m) mdf(p[j],1),++o;
    		}
    	for(i=q[0];i;--i) if(q[i]<=m) mdf(q[i],-1);
    	for(i=h[r];i;i=s[i]) if(!v[y=g[i]]) dac(y);
    }
    int main(){
    	R int i,x,y,z;
    	for(n=rd(),i=1;i<n;++i) x=rd(),y=rd(),z=rd(),add(x,y,z),add(y,x,z);
    	m=rd(),dac(1),printf("%d
    ",o);
    	return 0;
    }
    

    P3806 【模板】点分治1

    洛谷把这道题作了板子题,我觉得非常有道理。有些题解其实写的是(O(n^2))的,而不是他们所说的(O(nlogn))(O(n(logn)^2))(例如作者写这篇文章时的洛谷第一篇题解)。

    注意这道题是要求(m(mle 100))个长度为(k)的路径是否出现过。

    看看我们是怎样统计过当前根节点的路径的?求出从根节点到所有点的路径再两两组合?这样统计是(O(n^2))的。考虑把路径按长度排序,同时记录所来自的子树(相同子树的路径不能对答案造成贡献),对于每个询问拿一个指针从左往右扫,同时二分能和它配对的值,注意判掉在同一棵子树内的路径。总的复杂度是(O(mn(logn)^2))。实现起来有些繁琐,但事实上不难。

    #include<cstdio>
    #include<cctype>
    #include<algorithm>
    #define R register
    #define I inline
    using namespace std;
    const int S=10003,N=20003,M=103;
    char buf[1000000],*p1,*p2;
    I char gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,S,stdin),p1==p2)?EOF:*p1++;}
    I int rd(){
    	R int f=0; R char c=gc();
    	while(c<48||c>57) c=gc();
    	while(c>47&&c<58) f=f*10+(c^48),c=gc();
    	return f;
    }
    struct D{int d,t;}p[S];
    int h[S],s[N],g[N],w[N],t[S],v[S],q[M],b[M],c,r,n,m,u,e;
    I int max(int x,int y){return x>y?x:y;}
    I int cmp(D x,D y){return x.d^y.d?x.d<y.d:x.t>y.t;}
    I void add(int x,int y,int z){s[++c]=h[x],h[x]=c,g[c]=y,w[c]=z;}
    void gts(int x,int f){t[x]=1;
    	for(R int i=h[x],y;i;i=s[i]) if((y=g[i])^f&&!v[y]) gts(y,x),t[x]+=t[y];
    }
    void gtr(int x,int f,int a){R int m=0,i,y;
    	for(i=h[x];i;i=s[i]) if((y=g[i])^f&&!v[y]) m=max(m,t[y]),gtr(y,x,a);
    	m=max(m,a-t[x]); if(m<u) u=m,r=x;
    }
    void dfs(int x,int f,int d){p[++e]=(D){d,c};
    	for(R int i=h[x],y;i;i=s[i]) if((y=g[i])^f&&!v[y]) dfs(y,x,d+w[i]);
    }
    I int fnd(int x){
    	R int s=0,l=1,r=e,m;
    	while(l<=r){m=l+r>>1;
    		if(p[m].d<x) l=m+1;
    		else s=m,r=m-1;
    	}return s;
    }
    void dac(int x){R int i,j,k,y;
    	for(c=0,e=0,u=n,gts(x,0),gtr(x,0,t[x]),v[r]=1,i=h[r];i;i=s[i])
    		if(!v[y=g[i]]) ++c,dfs(y,r,w[i]);
    	p[++e]=(D){0,0},sort(p+1,p+1+e,cmp);
    	for(k=1;k<=m;++k){
    		if(b[k]) continue;
    		for(i=1;i<e&&p[i].d+p[e].d<q[k];++i);
    		while(i<e){
    			if(q[k]-p[i].d<p[i].d) break;
    			for(j=fnd(q[k]-p[i].d);p[j].d+p[i].d==q[k]&&p[i].t==p[j].t;++j);
    			if(p[i].d+p[j].d==q[k]) b[k]=1; ++i;
    		}
    	}
    	for(i=h[r];i;i=s[i]) if(!v[y=g[i]]) dac(y);
    }
    int main(){
    	R int i,x,y,z;
    	for(n=rd(),m=rd(),i=1;i<n;++i) x=rd(),y=rd(),z=rd(),add(x,y,z),add(y,x,z);
    	for(i=1;i<=m;++i) q[i]=rd();
    	for(dac(1),i=1;i<=m;++i) b[i]?printf("AYE
    "):printf("NAY
    ");
    	return 0;
    }
    

    然而,我们发现还存在更快的方法,如果统计的时候开一个值域大小的桶((1e7)不用虚),每次判断的时候就可以(O(1)),这样总复杂度就又降了一个(logn),总时间复杂度(O(mnlogn))。附上代码:

    #include<cstdio>
    #include<cctype>
    #define R register
    #define I inline
    using namespace std;
    const int S=100003,N=200003,M=103,K=10000003;
    char buf[1000000],*p1,*p2;
    I char gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,S,stdin),p1==p2)?EOF:*p1++;}
    I int rd(){
    	R int f=0; R char c=gc();
    	while(c<48||c>57) c=gc();
    	while(c>47&&c<58) f=f*10+(c^48),c=gc();
    	return f;
    }
    int h[S],s[N],g[N],w[N],t[S],v[S],a[M],b[K],p[S],q[M],l[S],c,e,n,m,u,r;
    I int max(int x,int y){return x>y?x:y;}
    I void add(int x,int y,int z){s[++c]=h[x],h[x]=c,g[c]=y,w[c]=z;}
    void gts(int x,int f){t[x]=1;
    	for(R int i=h[x],y;i;i=s[i]) if(!v[y=g[i]]&&y^f) gts(y,x),t[x]+=t[y];
    }
    void gtr(int x,int f,int a){R int m=0,i,y;
    	for(i=h[x];i;i=s[i]) if(!v[y=g[i]]&&y^f) m=max(m,t[y]),gtr(y,x,a);
    	m=max(m,a-t[x]); if(m<u) u=m,r=x;
    }
    void dfs(int x,int f,int d){p[++e]=d;
    	for(R int i=h[x],y;i;i=s[i]) if(!v[y=g[i]]&&y^f) dfs(y,x,d+w[i]);
    }
    void dac(int x){R int i,j,k,y;
    	for(b[0]=1,l[0]=0,u=n,gts(x,0),gtr(x,0,t[x]),v[r]=1,i=h[r];i;i=s[i])
    		if(!v[y=g[i]]){dfs(y,r,w[i]);
    			for(j=1;j<=e;++j)
    				for(k=1;k<=m;++k) if(q[k]>=p[j]) a[k]|=b[q[k]-p[j]];
    		   	for(;e;--e) l[++l[0]]=p[e],b[p[e]]=1;
    		}
    	for(j=l[0];j;--j) b[l[j]]=0;
    	for(i=h[r];i;i=s[i]) if(!v[y=g[i]]) dac(y);
    }
    int main(){
    	R int i,x,y,z;
    	for(n=rd(),m=rd(),i=1;i<n;++i) x=rd(),y=rd(),z=rd(),add(x,y,z),add(y,x,z);
    	for(i=1;i<=m;++i) q[i]=rd();
    	for(dac(1),i=1;i<=m;++i) printf(a[i]?"AYE
    ":"NAY
    ");
    	return 0;
    }
    
    

    写到这里,(像我这样的)初学者可能会有一个疑问,如果要判断一个长度的路径是否在树上出现过,不是应该遍历(n(n-1))条路径才能知道吗?这个算法的复杂度下界应该是(O(n^2))才对啊。

    事实上,注意我们统计过根节点的路径时的操作,我们并不需要求出所有的路径,只需要求出从根出发的路径((O(n))条),面向询问在排序的基础上通过二分判断,我们的复杂度是把这里的(O(n^2))降成了(O(mnlogn))。那我们为什么点分治呢?事实上点分治是为这样做提供了条件,只有点分治之后才能通过对于从根节点出发的路径的统计来不重不漏地判断所有的情况。

    但是你可能仍然存在疑问:开头所说的那些题解中,对于每一层递归的统计都是(O(n^2))的(可以仔细看看他们的代码递归统计答案的部分),总共递归(logn)层,那么总的复杂度不就是(O(n^2logn))吗?为什么会比(n^2)暴力跑得快呢?事实上,这些现在看似(O(n^2logn))的题解的复杂度都是(O(n^2))的!

    下面给出简单证明:考虑最坏的情况,一条链。每一层递归对(2^i)个大小为(frac{n}{2^i})的部分进行(n^2)级别的处理,总的复杂度趋近于(sumlimits_{i=1}^infty 2^i*(frac{n}{2^i})^2),即(sumlimits_{i=1}^infty frac{n^2}{2^i}),然而这个值是趋近于(2*n^2)的,所以这个算法的复杂度是(O(n^2)),而非(O(n^2logn))

    至于为什么有位同学(开头就是“暴力过淀粉质模板”的那位,抱歉,并无恶意)过得那么惊险,除了常数大,你的写法确实比较玄学,我用一条链、菊花图、扫把图、二叉树都没卡掉,但是你在随机图上的时间是在这些构造图上跑的3倍qwq。

    推荐题目:
    P4149 [IOI2011]Race 题解
    P4886 快递员 题解
    [BJOI2017]树的难题 题解

  • 相关阅读:
    关于linux curl 地址参数的问题
    mac系统安装php redis扩展
    Shell获取上一个月、星期的时间范围
    python redis使用
    python pycurl模块
    Memcached常规应用与分布式部署方案
    mysql忘记密码重置(mac)
    shell命令从目录中循环匹配关键词
    python两个文件的对比
    MySQL优化方案
  • 原文地址:https://www.cnblogs.com/cj-chd/p/10102720.html
Copyright © 2011-2022 走看看