啥是点分治?
点分治一般来说用于解决大规模的树上路径问题。比如说最经典的一道题,给定一棵树,计算一共有多少点对满足之间距离<=k。
这种题一般数据范围在10^4~10^5,直接暴力求是n^2的肯定会超时。
那怎么办?我们考虑分治。
先说一下点分治的基本思想,就是对于每一棵树,先找到这棵树的重心。
啥叫重心?重心就是一棵树中最大的子树节点最少的那个点。求树的重心就是直接暴力DFS,但是毕竟人家是O(n)的。
先看一眼代码。
void getroot(int x,int fa) { size[x] = 1,maxs[x] = 0; for(int i = head[x];i;i = e[i].next) { int t = e[i].to; if(t == fa || vis[t]) continue; getroot(t,x); size[x] += size[t]; maxs[x] = max(maxs[x],size[t]); } maxs[x] = max(maxs[x],sum - size[x]); if(maxs[x] < maxs[root]) root = x; }
其中size记录节点子树大小,maxs记录最大子树节点大小。
最后一步就是因为还要计算自己的父亲和父亲之上的一些树,所以要进行更新。
这样我们就成功的O(n)求出了重心。
好接着上面的说,我们求出重心之后,首先统计这棵子树之内的所有答案,之后再分别递归到重心分割开的所有子树里面去统计答案。
怎么统计呢?首先我们对于每棵子树,我们从根开始向下进行dfs,更新到达每个点需要的距离,并且把出现过的点的距离全部加入当前统计数组里面。之后把统计数组排序,从两头开始找,只要当前两点之间的距离小于等于k,那么就直接加上r-l个答案,直到l>r为止。
不过这里有一些问题要注意,就是如果直接这么统计会出问题,因为我们只计算那些经过重心的道路,而这种计算方法它会在本次计算中重复计算一些子树中的情况,这样再向下递归的时候答案就会算重。所以对于每次计算,我们要再减去所有子树中的可能情况。之后向下递归求子树重心,继续求解。
这就是大致的操作了。然后注意的是每次我们需要不断更改当前树大小(这个好做直接用size赋值)
还有就是我们要每次在getroot之前把根结点的值赋成0,毕竟各个过程是相对独立的,可以避免很多不必要的麻烦。
总复杂度大概是O(nlog^2n),更多的细节看一下代码。
#include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #include<iostream> #include<queue> #include<set> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar(' ') using namespace std; typedef long long ll; const int M = 100005; int read() { int ans = 0,op = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') op = -1; ch = getchar(); } while(ch >= '0' && ch <= '9') { ans *= 10; ans += ch - '0'; ch = getchar(); } return ans * op; } struct edge { int next,to,v; }e[M]; int n,k,head[M],dis[M],ecnt,size[M],maxs[M],root,x,y,z,sum,tot,cur[M],ans; bool vis[M]; void add(int x,int y,int z) { e[++ecnt].to = y; e[ecnt].v = z; e[ecnt].next = head[x]; head[x] = ecnt; } void getroot(int x,int fa)//求树的重心 { size[x] = 1,maxs[x] = 0; for(int i = head[x];i;i = e[i].next) { int t = e[i].to; if(t == fa || vis[t]) continue; getroot(t,x); size[x] += size[t]; maxs[x] = max(maxs[x],size[t]); } maxs[x] = max(maxs[x],sum - size[x]); if(maxs[x] < maxs[root]) root = x; } void getdis(int x,int fa,int leng)//leng记录当前的长度 { cur[++tot] = leng; for(int i = head[x];i;i = e[i].next) { int t = e[i].to; if(t == fa || vis[t]) continue; getdis(t,x,leng + e[i].v); } } int calc(int x,int leng) { tot = 0; getdis(x,0,leng); sort(cur+1,cur+1+tot); int l = 1,r = tot,temp = 0; while(l < r)//排序之后从两头开始找。因为对于每两个点如果其符合,那么对于l节点,从l+1~r全部是合法的点对。 { if(cur[l] + cur[r] <= k) temp += r - l,l++; else r--; } rep(i,1,tot) cur[i] = 0;//注意这里不能使用memset,否则会超时 return temp; } void solve(int x) { vis[x] = 1;ans += calc(x,0);//计算当前子的答案 for(int i = head[x];i;i = e[i].next) { int t = e[i].to; if(vis[t]) continue; ans -= calc(t,e[i].v);//减去每棵子树的答案,注意这里必须把初始的长度传进去,否则你相当于统计的路径长度少了一截 sum = size[t],maxs[root = 0] = n; getroot(t,0),solve(root);//继续找重心之后递归求解 } } int main() { n = read(); rep(i,1,n-1) x = read(),y = read(),z = read(),add(x,y,z),add(y,x,z); k = read(); sum = maxs[root] = n,getroot(1,0);//找到树的重心 solve(root);//开始递归求解 printf("%d ",ans); return 0; }