题面
小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。
小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物
分析
本质上是在一棵树上取出若干节点,询问把这几个节点访问一遍的距离
可以发现如果我们按照dfs序将节点排序,然后将排序后的相邻节点距离相加,最后再加上序列首尾距离,就能求出答案
如序列为{1,3,4,5},则答案为dist(1,3)+dist(3,4)+dist(4,5)+dist(5,1)
因为我们访问节点一定是像dfs一样访问才能得到最短路径,所以正确性显然
我们用一个set维护这个排序后的节点序列
可以发现,每次新加入一个节点之后只会改变节点的前驱和后继相关的距离,维护一下即可
代码
//https://www.luogu.org/problemnew/show/P3320
#include<iostream>
#include<cstdio>
#include<set>
#include<cmath>
#define INF 0x3f3f3f3f
#define maxn 100005
#define maxlogn 20
using namespace std;
int n,m;
struct edge{
int from;
int to;
int next;
int len;
}E[maxn<<1];
int head[maxn];
int sz=1;
void add_edge(int u,int v,int w){
sz++;
E[sz].from=u;
E[sz].to=v;
E[sz].next=head[u];
E[sz].len=w;
head[u]=sz;
}
int logn;
int cnt=0;
int dfn[maxn];
int hash_dfn[maxn];//存储dfs序为i的节点编号
int anc[maxn][maxlogn];
int deep[maxn];
long long dist[maxn];
void dfs(int x,int fa){
dfn[x]=++cnt;
hash_dfn[dfn[x]]=x;
deep[x]=deep[fa]+1;
anc[x][0]=fa;
for(int i=1;i<=logn;i++) anc[x][i]=anc[anc[x][i-1]][i-1];
for(int i=head[x];i;i=E[i].next){
int y=E[i].to;
if(y!=fa){
dist[y]=dist[x]+E[i].len;
dfs(y,x);
}
}
}
int lca(int x,int y){
if(deep[x]<deep[y]) swap(x,y);
for(int i=logn;i>=0;i--){
if(deep[anc[x][i]]>=deep[y]){
x=anc[x][i];
}
}
if(x==y) return x;
for(int i=logn;i>=0;i--){
if(anc[x][i]!=anc[y][i]){
x=anc[x][i];
y=anc[y][i];
}
}
return anc[x][0];
}
long long get_dist(int x,int y){
return dist[x]+dist[y]-2*dist[lca(x,y)];
}
set<int>seq; //set里实际上存的是节点的dfs序值而不是编号
long long sum=0,ans;
int main(){
int u,v,w,x;
scanf("%d %d",&n,&m);
logn=log2(n);
for(int i=1;i<n;i++){
scanf("%d %d %d",&u,&v,&w);
add_edge(u,v,w);
add_edge(v,u,w);
}
dfs(1,0);
seq.insert(INF);//防止越界
seq.insert(-INF);
for(int i=1;i<=m;i++){
scanf("%d",&x);
if(!seq.count(dfn[x])){
int l=*(--seq.lower_bound(dfn[x]));
int r=*seq.upper_bound(dfn[x]);
//注意边界判断
if(l!=-INF) sum+=get_dist(hash_dfn[l],x);
if(r!=INF) sum+=get_dist(hash_dfn[r],x);
if(l!=-INF&&r!=INF) sum-=get_dist(hash_dfn[l],hash_dfn[r]);
seq.insert(dfn[x]);
}else{
int l=*(--seq.lower_bound(dfn[x]));
int r=*seq.upper_bound(dfn[x]);
if(l!=-INF) sum-=get_dist(hash_dfn[l],x);
if(r!=INF) sum-=get_dist(hash_dfn[r],x);
if(l!=-INF&&r!=INF) sum+=get_dist(hash_dfn[l],hash_dfn[r]);
seq.erase(seq.lower_bound(dfn[x]));
}
ans=sum;
if(seq.size()-2>=2) ans+=get_dist(hash_dfn[*(++seq.lower_bound(-INF))],hash_dfn[*(--seq.lower_bound(INF))]);
//记得加上首尾距离,注意我们一开始插入了-INF和+INF,所以set的大小>=4时才会有序列首尾
printf("%lld
",ans);
}
}