http://poj.org/problem?id=1741
题意:
给出一棵树,求出树上满足两点权值之和不大于k的点对数。
思路:
直接暴力就是$O(n^2)$,显然不行。这里有一篇论文可以推荐大家看一下。
因为一条路径要么过根结点,要么不过,即在一棵子树中,然后我们可以用分治法。每次确定一个根u,u有几棵子树,就将u的子树分成了几部分,每一部分就一直递归下去。
我们设dis[i]为i到根结点的距离,于是我们可以得到以下的算法:$dis[i]+dis[j]<=k的(i,j)对数-dis[i]+dis[j]<=k的(i,j)对数且(i,j)在一棵子树中$。这个部分其实不难算,我们只需要先把所有子节点的dis值求出,然后排个序,利用尺取法可以快速求出个数。
还有一点要注意的是,根结点的选择很重要,这里需要选择重心,因为删去重心后最大子树的节点数最小。
1 #include<iostream> 2 #include<algorithm> 3 #include<cstring> 4 #include<cstdio> 5 #include<vector> 6 #include<stack> 7 #include<queue> 8 #include<cmath> 9 #include<map> 10 #include<set> 11 using namespace std; 12 typedef long long ll; 13 typedef pair<int,int> pll; 14 const int INF = 0x3f3f3f3f; 15 const int maxn=10000+5; 16 17 int n,k,num,root,sz,ans,mi; 18 int vis[maxn], son[maxn], dis[maxn], mx[maxn]; 19 vector<pll> G[maxn]; 20 21 22 void getroot(int u, int fa, int n) //计算重心 23 { 24 son[u]=0; 25 int balance=0; 26 for(int i=0;i<G[u].size();i++) 27 { 28 int v=G[u][i].first; 29 if(v==fa || vis[v]) continue; 30 getroot(v,u,n); 31 son[u]+=son[v]+1; 32 balance=max(balance,son[v]+1); 33 } 34 balance=max(balance,n-son[u]-1); 35 if(balance<sz) 36 { 37 sz=balance; 38 root=u; 39 } 40 } 41 42 43 void getdis(int u, int fa, int d) //求子节点到根结点的距离 44 { 45 dis[num++]=d; 46 for(int i=0;i<G[u].size();i++) 47 { 48 int v=G[u][i].first; 49 int w=G[u][i].second; 50 if(v==fa || vis[v]) continue; 51 getdis(v,u,d+w); 52 } 53 } 54 55 int getnum() 56 { 57 int ret=0; 58 sort(dis,dis+num); 59 int l=0,r=num-1; 60 while(l<r) //尺取法快速求对数 61 { 62 while(dis[l]+dis[r]>k && l<r) r--; 63 ret+=r-l; 64 l++; 65 } 66 return ret; 67 } 68 69 void solve(int u, int n) 70 { 71 sz=INF; 72 getroot(u,-1,n); 73 num=0; 74 getdis(root,-1,0); 75 ans+=getnum(); 76 vis[root]=1; 77 int tmp=root; 78 for(int i=0;i<G[tmp].size();i++) 79 { 80 int v=G[tmp][i].first; 81 int w=G[tmp][i].second; 82 if(!vis[v]) 83 { 84 num=0; 85 getdis(v,-1,w); //减去在同一棵子树的情况 86 ans-=getnum(); 87 solve(v,son[v]+1); 88 } 89 } 90 } 91 92 int main() 93 { 94 //freopen("in.txt","r",stdin); 95 while(~scanf("%d%d",&n,&k)) 96 { 97 if(!n && !k) break; 98 memset(vis,0,sizeof(vis)); 99 for(int i=1;i<=n;i++) G[i].clear(); 100 for(int i=1;i<n;i++) 101 { 102 int u,v,w; 103 scanf("%d%d%d",&u,&v,&w); 104 G[u].push_back(make_pair(v,w)); 105 G[v].push_back(make_pair(u,w)); 106 } 107 ans=0; 108 solve(1,n); 109 printf("%d ",ans); 110 } 111 return 0; 112 }