点分治
拾知识点了!!!
点分治是什么
点分治,是一种处理树上路径问题的工具,举个经典的点分治例子
给定一棵树和一个整数k,求树上长度等于k的路径有多少条
先来考虑暴力
枚举不同的两个点,然后dfs计算距离统计一下(O(n^3))
复杂度有点爆表 大概10分?
枚举不同的两个点,然后用倍增求两两点之间的距离(O(n^2logn))
n大一点还是不行 大概20?
做n次dfs,求每个点与所有点之间的距离(O(n^2))
(1e5)不太行啊 大概有30?
这个时候我们就需要把(n^2)里的一个(n)通过分治变成(logn)怎么样?
点分治,顾名思义,就是把树上的节点拆开来进行分治,每次通过点分把一棵树拆成多棵子树,然后继续递归下去,最后综合起来求出答案。
如何分治
既然要分治,我们肯定是每次选择一个点,从它开始分治下去,我们如果考虑选择这个点才能保证复杂度,我们发现每次处理完一个点之后,我们都要递归进它的子树,那么这个分治点的时间复杂度就会受到分治点的最大的子树的大小影响。
举一个极端的例子,如果原树是一条链,我们选择每次链首那么我们的复杂度就退化到了(O(n^2)),所以我们需要保证每次选到的分治点的最大子树最小。
说到这里那么我们每次选择的分治点就肯定是当前树的重心了
复杂度分析
我们把复杂度分析伴随着大概的思路顺一遍。
我们先来看重心的性质:每个子树的大小都不会超过总大小的一半
那么我们最多向下分治(logn)层,到达最后一层返回答案然后合并一下复杂度(O(nlogn))
求重心
直接dfs搞一下
void getroot(int x,int fa) {
siz[x]=1;int num=0;
for(int i=link[x];i;i=a[i].next) {
int v=a[i].y;
if(v==fa||vis[v]) continue;
getroot(v,x);siz[x]+=siz[v];
num=max(num,siz[v]);
}
num=max(num,size-siz[x]);
if(num<mx) mx=num,root=x;
}
合并求答案
void getdis(int x,int fa) {
q[++r]=d[x];
for(int i=link[x];i;i=a[i].next) {
int v=a[i].y;
if(v==fa||vis[v]) continue;
d[v]=d[x]+a[i].v;
getdis(v,x);
}
}
ll work(int x,int v) {
r=0;d[x]=v;
getdis(x,0);
ll sum=0;l=1;
sort(q+1,q+r+1);
while(l<r) {//双指针扫一遍
if(q[l]+q[r]<=k) sum+=r-l,++l;
return --r;
}
return sum;
}
ll ans=0;
void dfs(int x) {
vis[x]=1;
work(x,0,1);//当前为根求答案
for(int i=link[x];i;i=a[i].next) {
int v=a[i].y;
if(vis[v]) continue;
work(v,a[i].v,0);//容斥一下减去不可取的答案
mx=inf,rt=0,size=sz[v];
getrt(v,0); dfs(rt);//继续分治
}
}
例题
bzoj2152 聪聪可可
题目描述
聪聪和可可是兄弟俩,他们俩经常为了一些琐事打起来,例如家中只剩下最后一根冰棍而两人都想吃、两个人都想玩儿电脑(可是他们家只有一台电脑)……遇到这种问题,一般情况下石头剪刀布就好了,可是他们已经玩儿腻了这种低智商的游戏。
他们的爸爸快被他们的争吵烦死了,所以他发明了一个新游戏:由爸爸在纸上画n个“点”,并用n-1条“边”把这n个“点”恰好连通(其实这就是一棵树)。并且每条“边”上都有一个数。接下来由聪聪和可可分别随即选一个点(当然他们选点时是看不到这棵树的),如果两个点之间所有边上数的和加起来恰好是3的倍数,则判聪聪赢,否则可可赢。
聪聪非常爱思考问题,在每次游戏后都会仔细研究这棵树,希望知道对于这张图自己的获胜概率是多少。现请你帮忙求出这个值以验证聪聪的答案是否正确。
输入输出格式
输入格式:
输入的第1行包含1个正整数n。后面n-1行,每行3个整数x、y、w,表示x号点和y号点之间有一条边,上面的数是w。
输出格式:
以即约分数形式输出这个概率(即“a/b”的形式,其中a和b必须互质。如果概率为1,输出“1/1”)。
输入输出样例
输入样例#1:
复制
5
1 2 1
1 3 2
1 4 1
2 5 3
输出样例#1:
复制
13/25
说明
【样例说明】
13组点对分别是(1,1) (2,2) (2,3) (2,5) (3,2) (3,3) (3,4) (3,5) (4,3) (4,4) (5,2) (5,3) (5,5)。
【数据规模】
对于100%的数据,n<=20000。
很明显,思路就是统计长度为33的倍数的路径的条数,然后除以路径总和就是答案
先贴一句话题解:先用点分计算出路径长度,把路径长度对33取模,然后用sum[1],sum[2],sum[0]表示模数是1,2,3的情况的总数,那么就是ans+=sum[1]∗sum[2]∗2+sum[0]∗sum[0]],最后答案就是ans/(n∗n)
用人话说的话,我们可以先考虑一个点,用sum[1,2,3]sum[1,2,3]分别表示从以这一个点为根,往下的长度对33取模余数是1,2,31,2,3的路径条数,那么所有经过这一个点的路径有多少条呢?所有长度为1和2的路径可以两两拼起来成为一条,反着也可以,长度为3的路径可以两两拼。所以答案就加上上面那个式子
然后进行点分,不断递归就可以了
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int inf=0x7777777f;
char buf[1<<15],*fs,*ft;
inline char getc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++;}
inline int read() {
int n=1,num=0; char ch=getchar();
while(!isdigit(ch)) {n=(ch=='-')?-1:1;ch=getchar();}
while(isdigit(ch)) {num=(num<<1)+(num<<3)+(ch^48);ch=getchar();}
return n*num;
}
char sr[1<<21],z[20];int C=-1,Z;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
inline void print(int x){
if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
while(z[++Z]=x%10+48,x/=10);
while(sr[++C]=z[Z],--Z);
}
const int N=20005,mod=3;
int head[N],Next[N<<1],edge[N<<1],ver[N<<1];ll ans=0;
int sz[N],son[N],sum[4],vis[N];
int size,mx,rt,n,tot;
inline void add(int u,int v,int e){
ver[++tot]=v,Next[tot]=head[u],head[u]=tot,edge[tot]=e;
ver[++tot]=u,Next[tot]=head[v],head[v]=tot,edge[tot]=e;
}
void getrt(int u,int fa){
sz[u]=1,son[u]=0;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(vis[v]||v==fa) continue;
getrt(v,u);
sz[u]+=sz[v];
son[u]=max(son[u],sz[v]);
}
son[u]=max(son[u],size-sz[u]);
if(son[u]<mx) mx=son[u],rt=u;
}
void query(int u,int fa,int d){
++sum[d%mod];
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(vis[v]||v==fa) continue;
query(v,u,(d+edge[i])%mod);
}
}
ll solve(int rt,int d){
sum[0]=sum[1]=sum[2]=0;
query(rt,0,d);
ll res=1ll*sum[1]*sum[2]*2+1ll*sum[0]*sum[0];
return res;
}
void divide(int u){
ans+=solve(u,0);
vis[u]=1;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(vis[v]) continue;
ans-=solve(v,edge[i]);
mx=inf,rt=0,size=sz[v];
getrt(v,0);
divide(rt);
}
}
inline ll gcd(ll a,ll b){
while(b^=a^=b^=a%=b);
return a;
}
int main(){
n=read();
for(int i=1;i<n;++i){
int u=read(),v=read(),e=read();
add(u,v,e%3);
}
mx=inf,size=n,ans=0,rt=0;
getrt(1,0),divide(rt);
ll p=n*n,GCD=gcd(ans,p);
print(ans/GCD),sr[++C]='/',print(p/GCD);
Ot();
return 0;
}
洛谷3806 【模板】点分治1
#include<bits/stdc++.h>
#define inf 0x7777777f
#define ll long long
const int N=100010;
using namespace std;
char buf[1<<15],*fs,*ft;
inline char getc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++;}
inline int read() {
int n=1,num=0; char ch=getc();
while(!isdigit(ch)) {n=(ch=='-')?-1:1;ch=getc();}
while(isdigit(ch)) {num=(num<<1)+(num<<3)+(ch^48);ch=getc();}
return n*num;
}
int ans[10000005],link[100010];
struct gg {
int y,next,v;
}a[1000010];
int sz[N],son[N],st[N];bool vis[N];
int n,m,size,mx,rt,tot,top;
inline void add(int x,int y,int v) {
a[++tot].y=y;a[tot].next=link[x];link[x]=tot;a[tot].v=v;
a[++tot].y=x;a[tot].next=link[y];link[y]=tot;a[tot].v=v;
}
void getrt(int x,int fa) {
sz[x]=1,son[x]=0;
for(int i=link[x];i;i=a[i].next) {
int v=a[i].y;
if(vis[v]||v==fa) continue;
getrt(v,x);
sz[x]+=sz[v],son[x]=max(son[x],sz[v]);
}
son[x]=max(son[x],size-sz[x]);
if(son[x]<mx) mx=son[x],rt=x;
}
void ask(int x,int fa,int d) {
st[++top]=d;
for(int i=link[x];i;i=a[i].next) {
int v=a[i].y;
if(vis[v]||v==fa) continue;
ask(v,x,d+a[i].v);
}
}
void work(int x,int d,int f) {
top=0;
ask(x,0,d);
if(f) {
for(int i=1;i<top;++i)
for(int j=i+1;j<=top;++j) ++ans[st[i]+st[j]];
}
else {
for(int i=1;i<top;++i)
for(int j=i+1;j<=top;++j) --ans[st[i]+st[j]];
}
}
void dfs(int x) {
vis[x]=1;
work(x,0,1);
for(int i=link[x];i;i=a[i].next) {
int v=a[i].y;
if(vis[v]) continue;
work(v,a[i].v,0);
mx=inf,rt=0,size=sz[v];
getrt(v,0); dfs(rt);
}
}
int main(){
n=read(),m=read();
for(int i=1;i<n;++i){
int u=read(),v=read(),e=read();
add(u,v,e);
}
rt=0,mx=inf,size=n;
getrt(1,0),dfs(rt);
while(m--){
int k=read();
puts(ans[k]?"AYE":"NAY");
}
return 0;
}