题目描述
H国有 n个城市,这 n个城市用n-1条双向道路相互连通构成一棵树,1号城市是首都,也是树中的根节点。
H国的首都爆发了一种危害性极高的传染病。当局为了控制疫情,不让疫情扩散到边境城市(叶子节点所表示的城市),决定动用军队在一些城市建立检查点,使得从首都到边境城市的每一条路径上都至少有一个检查点,边境城市也可以建立检查点。但特别要注意的是,首都是不能建立检查点的。
现在,在 H国的一些城市中已经驻扎有军队,且一个城市可以驻扎多个军队。一支军队可以在有道路连接的城市间移动,并在除首都以外的任意一个城市建立检查点,且只能在一个城市建立检查点。一支军队经过一条道路从一个城市移动到另一个城市所需要的时间等于道路的长度(单位:小时)。
请问最少需要多少个小时才能控制疫情。注意:不同的军队可以同时移动。
输入格式
第一行一个整数n,表示城市个数。
接下来的 n-1行,每行3个整数,u,v,w,每两个整数之间用一个空格隔开,表示从城市 u到城市v有一条长为 w的道路。数据保证输入的是一棵树,且根节点编号为 1。
接下来一行一个整数 mm,表示军队个数。
接下来一行 m个整数,每两个整数之间用一个空格隔开,分别表示这 m个军队所驻扎的城市的编号。
输出格式
一个整数,表示控制疫情所需要的最少时间。如果无法控制疫情则输出-1。
保证军队不会驻扎在首都。
对于 20%的数据,2≤ n≤ 10;
对于 40%的数据,2 ≤n≤50,0<w <10^5;
对于 60%的数据,2 ≤ n≤1000,0<w <10^6;
对于 80%的数据,2 ≤ n≤10,000;
对于 100%的数据,2≤m≤n≤50,000,0<w <10^9。
首先明确一点,军队显然只会往上走而不会往下走,因为往上走能控制更多的城市,而往下走相反。所以我们有了一个贪心策略:让所有军队尽量地往上走,但不要走到首都的位置。
由于最后一支军队驻扎的时间就是答案,显然我们需要让所有军队驻扎的时间的最大值最小化。这里我们可以用到二分答案。首先二分出一个mid作为判定合法的limitation。然后我们根据上面的贪心策略,让所有军队在lim的时间内尽量地往上走。然后我们会发现会有军队走到了首都的儿子节点处,但是时间还没到lim,此时让它越过首都到达首都的其它未被驻扎的儿子节点可能会更优。
所以我们的下一步就是——处理出第一步所有军队尽量往上走之后还未被控制的首都的子节点。注意根据题意,这里的'控制'一词含义为——这个点有军队驻扎或者它的所有儿子都被控制。根据这个定义,我们只需要dfs就可以求出来首都的哪些子节点未被控制。
这时我们明确思路,我们设rest[x]表示第x支军队到达首都之后(注意是首都)与lim相差多少时间,也就是剩下的时间。然后我们显然让剩余时间最多的军队去驻扎离首都最远的子节点是最优的。所以我们把首都的未被控制的子节点按照与首都的距离从大到小排序,再把所有时间没到lim的军队按照rest从大到小排序,然后判断能否在lim内把所有地方控制完即可。
然而我们会发现一个问题,我们让有剩余时间的军队越过首都去驻扎其它城市可能不是最优的。如果这个城市本来就走上来了一个时间没到lim的军队,我们显然让这支军队驻扎在这座城市是最优的。但如果这个城市有多支时间没到lim的军队呢?我们显然是让rest最小的一支留下来。所以我们还需要记下首都的子节点的rest最小的军队。
然而如果直接这样做,我们的时间复杂度为——
[O((MN+N+MlogM+Degree(1)*log_{2}Degree(1)+Degree(1))*log_{2}Ans)
]
degree表示这个点的度数。解释一下复杂度是怎么来的:MN是每支军队往上走的复杂度,N是dfs寻找没被占领的城市的复杂度,MlogM是对军队排序的复杂度,Degree(1) * logDegree(1)是对首都的子节点排序的复杂度,最后的degree(1)表示判合法的复杂度
其中MN这一项我们是吃不消的,可以考虑优化。我们的做法是一个一个地往上爬,太慢了,这里我们可以用倍增来加速。那么每次二分的复杂度就从N^2级别降为了NlogN级别,总复杂度就是Nloglog,50000的数据是可以过的。
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define maxn 50001
using namespace std;
struct edge{
int to,dis,next;
edge(){}
edge(const int &_to,const int &_dis,const int &_next){ to=_to,dis=_dis,next=_next; }
}e[maxn<<1];
int head[maxn],k;
struct node{
int id; long long rest;
bool operator<(const node &x)const{ return rest>x.rest; }
}a[maxn],b[maxn];
int da,db,army[maxn];
int minrest[maxn],min_id[maxn];
bool vis[maxn],use[maxn];
int fa[maxn][20],dep[maxn],maxdep;
long long dis[maxn][20];
int n,m;
inline int read(){
register int x(0),f(1); register char c(getchar());
while(c<'0'||'9'<c){ if(c=='-') f=-1; c=getchar(); }
while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
return x*f;
}
inline void add(const int &u,const int &v,const int &w){ e[k]=edge(v,w,head[u]),head[u]=k++; }
void dfs(int u){
for(register int i=head[u];~i;i=e[i].next){
int v=e[i].to;
if(v==fa[u][0]) continue;
fa[v][0]=u,dis[v][0]=e[i].dis,dep[v]=dep[u]+1;//这里的dep似乎是没有用到的变量
for(register int i=1;i<=maxdep;i++) fa[v][i]=fa[fa[v][i-1]][i-1],dis[v][i]=dis[v][i-1]+dis[fa[v][i-1]][i-1];
dfs(v);
}
}
bool dfs_check(int u){
if(vis[u]) return true;
register bool flag1=false,flag2=true;
for(register int i=head[u];~i;i=e[i].next){
int v=e[i].to;
if(v==fa[u][0]) continue;
flag1=true;
if(!dfs_check(v)){
flag2=false;
if(u==1) b[++db].rest=e[i].dis,b[db].id=v;
else return false;
}
}
return flag1&&flag2;
}
inline bool check(const long long &lim){
da=db=0,memset(min_id,false,sizeof min_id),memset(use,false,sizeof use),memset(vis,false,sizeof vis);
for(register int i=1;i<=m;i++){
int u=army[i]; long long len=0;
for(register int i=maxdep;i>=0;i--) if(fa[u][i]>1&&len+dis[u][i]<=lim) len+=dis[u][i],u=fa[u][i];
if(fa[u][0]==1&&len+dis[u][0]<=lim){
a[++da].rest=lim-len-dis[u][0],a[da].id=i;
if(!min_id[u]||a[da].rest<minrest[u]) min_id[u]=i,minrest[u]=a[da].rest;
}else vis[u]=true;
}
if(dfs_check(1)) return true;
sort(a+1,a+1+da),sort(b+1,b+1+db);
use[0]=true;
for(register int i=1,j=1;i<=db;i++){
if(!use[min_id[b[i].id]]){ use[min_id[b[i].id]]=true; continue; }
while(j<=da&&(use[a[j].id]||a[j].rest<b[i].rest)) j++;
if(j>da) return false; use[a[j].id]=true;
}
return true;
}
int main(){
memset(head,-1,sizeof head);
n=read(),maxdep=(int)log(n)/log(2)+1;
for(register int i=1;i<n;i++){
int u=read(),v=read(),w=read();
add(u,v,w),add(v,u,w);
}
dfs(1);//处理树上倍增的信息
m=read();
for(register int i=1;i<=m;i++) army[i]=read();
long long l=0,r=1e18,mid,ans=-1;
while(l<=r){
mid=l+r>>1;
if(check(mid)) ans=mid,r=mid-1;
else l=mid+1;
}
printf("%lld
",ans);
return 0;
}