点分治好题
统计距离正常点分治统计即可,我们只需考虑何时达到最优
有两种情况:
第一:代价最大的询问两个端点在不同的两个子树中
因为这种情况下,无论根向那个子树移动都会等价地增加到达另一个端点的代价,因此此时总代价已经达到最小
第二:代价最大的询问有多组,且这些点不在同一棵子树中
同情况一,如果我们把根偏向其中某一棵子树移动的话,那么一定会等价地增大另一端的代价,因此总代价已经达到最小
这样的话直接统计就好了
贴代码:
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> using namespace std; struct Edge { int nxt; int to; int val; }edge[200005]; int q[100005][2]; int siz[100005]; int rt,s; int head[100005]; int maxp[100005]; int ans=0x3f3f3f3f; bool vis[100005]; int dis[100005]; int bel[100005]; int ccl[100005]; int col=0,typ=0; int cnt=1,maxx=0; int n,m; void add(int l,int r,int w) { edge[cnt].nxt=head[l]; edge[cnt].to=r; edge[cnt].val=w; head[l]=cnt++; } void get_rt(int x,int fx) { siz[x]=1,maxp[x]=0; for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; if(to==fx||vis[to])continue; get_rt(to,x); siz[x]+=siz[to]; maxp[x]=max(maxp[x],siz[to]); } maxp[x]=max(maxp[x],s-siz[x]); if(maxp[x]<maxp[rt])rt=x; } void dfs(int x,int fx,int dep) { dis[x]=dep; if(!fx)bel[x]=0; for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; if(to==fx)continue; if(!fx)col++; bel[to]=col; dfs(to,x,dep+edge[i].val); } } int calc(int x) { dis[x]=maxx=0,dfs(x,0,0); typ++; for(int i=1;i<=m;i++) { ccl[i]=0; if(dis[q[i][0]]+dis[q[i][1]]>maxx)maxx=dis[q[i][0]]+dis[q[i][1]],ccl[i]=++typ; else if(dis[q[i][0]]+dis[q[i][1]]==maxx)ccl[i]=typ; } return typ; } void solve(int x) { vis[x]=1; int k=calc(x); int ori=0; for(int i=1;i<=m;i++) { if((bel[q[i][0]]!=bel[q[i][1]])&&ccl[i]==k){ori=-1;break;} else if(ccl[i]==k&&!ori)ori=bel[q[i][0]]; else if(ccl[i]==k)if(bel[q[i][0]]!=ori){ori=-1;break;} } ans=min(ans,maxx); for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; if(vis[to])continue; if(bel[to]==ori)rt=0,s=siz[to],get_rt(to,x),solve(rt); } } inline int read() { int f=1,x=0;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int main() { n=read(),m=read(); for(int i=1;i<n;i++) { int x=read(),y=read(),z=read(); add(x,y,z),add(y,x,z); } maxp[0]=0x3f3f3f3f; for(int i=1;i<=m;i++)q[i][0]=read(),q[i][1]=read(); rt=0,s=n,get_rt(1,0),solve(rt); printf("%d ",ans); return 0; }