Description
Solution
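用 set 按 DFS 序维护当前所有宝物所在的点；答案为按 DFS 序相邻两点之间的距离之和，再加上首尾两点之间的距离（按 DFS 序依次走过这些点再回到起点，每条必经的边恰好经过两次，因此这就是最短回路长度）。(Keep the treasure nodes in a set ordered by DFS entry time; the answer is the sum of distances between consecutive nodes in that order plus the distance between the first and the last node — visiting the nodes in DFS order and returning to the start traverses every required edge exactly twice, which is the shortest closed tour.)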
Code
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <set>
// Sentinel keys kept permanently in `q` (as Inf and -Inf) so that every real
// DFS timestamp always has a predecessor and a successor in the set.
constexpr int Inf=0x7fffffff;
using ll=long long;
// Capacity for node-indexed arrays (node count plus a little slack).
constexpr int N=100010;
using namespace std;
// Adjacency-list edge: target node, index of the next edge out of the same
// source node (0 terminates the list), and edge weight.
struct info{int to,nex,w;}e[N*2];
// n, m: node and query counts. tot: edge counter while reading, then reused
// as the DFS timestamp counter. head[u]: first edge index of u.
// dfn[u]: DFS entry time of u; id[t]: node whose entry time is t.
// dep[u]: depth of u (root = 0). fa[u][k]: 2^k-th ancestor of u.
// _log: highest binary-lifting level in use.
int n,m,tot,head[N],dfn[N],dep[N],fa[N][20],_log,id[N];
// dis[u]: weighted distance from the root to u. Ans: sum of distances between
// treasure nodes adjacent in DFS order. s_t: distance between the first and
// last treasure node in DFS order.
ll dis[N],Ans,s_t;
// b[x]: whether node x currently holds a treasure.
bool b[N];
// DFS timestamps of treasure nodes, plus the two sentinels.
set<int> q;
// Reads the next (optionally negative) decimal integer from stdin.
// Any characters before the first digit are skipped; a '-' among them makes
// the result negative.
inline int read(){
    bool negative=false;
    char c=getchar();
    for(;c<'0'||c>'9';c=getchar())
        if(c=='-') negative=true;
    int value=0;
    for(;c>='0'&&c<='9';c=getchar())
        value=value*10+c-'0';
    return negative?-value:value;
}
inline void Link(int u,int v,int w){
e[++tot].to=v;e[tot].nex=head[u];head[u]=tot;e[tot].w=w;
}
void dfs(int u,int pre){
dfn[u]=++tot;id[tot]=u;
for(int i=1;i<=_log;++i)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[u];i;i=e[i].nex){
int v=e[i].to;
if(v==pre) continue;
dep[v]=dep[u]+1;
dis[v]=dis[u]+e[i].w;
fa[v][0]=u;
dfs(v,u);
}
head[u]=0;
}
// Lowest common ancestor of u and v via binary lifting over fa[][0.._log].
int LCA(int u,int v){
    // Make v the deeper of the two endpoints.
    if(dep[u]>dep[v]) swap(u,v);
    // Lift v by the depth gap, one binary digit at a time.
    int gap=dep[v]-dep[u];
    for(int k=0;k<=_log;++k)
        if(gap>>k&1) v=fa[v][k];
    if(u==v) return u;
    // Raise both nodes as long as their ancestors still differ;
    // they end up as distinct children of the LCA.
    for(int k=_log;k>=0;--k){
        if(fa[u][k]==fa[v][k]) continue;
        u=fa[u][k];
        v=fa[v][k];
    }
    return fa[u][0];
}
// Weighted tree distance between a and b: both root distances minus twice
// the root distance of their lowest common ancestor.
ll Dis(int a,int b){
    return dis[a]+dis[b]-2*dis[LCA(a,b)];
}
int main(){
n=read(),m=read();_log=log(n)/log(2);
for(int i=1;i<n;++i){
int u=read(),v=read(),w=read();
Link(u,v,w);Link(v,u,w);
}
tot=0;dfs(1,0);
q.insert(Inf),q.insert(-Inf);
while(m--){
int x=read(),f,l,r;
if(!b[x]) q.insert(dfn[x]),f=1;
else q.erase(dfn[x]),f=-1;
b[x]^=1;
l=*--q.lower_bound(dfn[x]),r=*q.upper_bound(dfn[x]);
if(l!=-Inf) Ans+=(ll)f*Dis(id[l],x);
if(r!=Inf) Ans+=(ll)f*Dis(id[r],x);
if(l!=-Inf&&r!=Inf) Ans-=(ll)f*Dis(id[l],id[r]);
if(q.size()>3) s_t=Dis(id[*q.upper_bound(-Inf)],id[*--q.lower_bound(Inf)]);else s_t=0;
printf("%lld
",Ans+s_t);
}
return 0;
}