题目大意:
一颗树 有一个点的集合
对于每个集合的答案为 从集合内一个点遍历集合内所有点再返回的距离最小值
每次可以选择一个点 若在集合外便加入集合 若在集合内就删除
求每次操作后这个集合的答案
思路:
对于每个集合
它的答案一定为从dfs序最小的开始依次遍历再回来
当加入一个点x的时候 可以找到它dfs序的前驱与后继 画图可得 ans+=dis(pre,x)+dis(x,sub)-dis(pre,sub) 删除的时候为ans-=
特别地 当x没有前驱或后继时 前驱为最大值 后继为最小值(当做一个环
因此我们维护一颗平衡树搞一下即可
1 #include<iostream> 2 #include<cstdio> 3 #include<cmath> 4 #include<cstdlib> 5 #include<cstring> 6 #include<algorithm> 7 #include<vector> 8 #include<queue> 9 #include<set> 10 #define inf 2139062143 11 #define ll long long 12 #define MAXN 100100 13 using namespace std; 14 inline int read() 15 { 16 int x=0,f=1;char ch=getchar(); 17 while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();} 18 while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();} 19 return x*f; 20 } 21 int ch[MAXN][2],fa[MAXN],sz,cnt[MAXN],val[MAXN],size[MAXN],rt; 22 int which(int x) {return x==ch[fa[x]][1];} 23 int find_pre() 24 { 25 int pos=ch[rt][0]; 26 while(ch[pos][1]) pos=ch[pos][1]; 27 return pos; 28 } 29 int find_max() 30 { 31 int pos=ch[rt][1]; 32 while(ch[pos][1]) pos=ch[pos][1]; 33 return pos; 34 } 35 int find_min() 36 { 37 int pos=ch[rt][0]; 38 while(ch[pos][0]) pos=ch[pos][0]; 39 return pos; 40 } 41 int find_sub() 42 { 43 int pos=ch[rt][1]; 44 while(ch[pos][0]) pos=ch[pos][0]; 45 return pos; 46 } 47 void upd(int x) 48 { 49 if(!x) return ; 50 size[x]=cnt[x]+size[ch[x][1]]+size[ch[x][0]]; 51 } 52 void rotate(int pos) 53 { 54 int f=fa[pos],ff=fa[f],k=which(pos); 55 ch[f][k]=ch[pos][k^1],fa[ch[f][k]]=f,fa[f]=pos,ch[pos][k^1]=f,fa[pos]=ff; 56 if(ff) ch[ff][ch[ff][1]==f]=pos; 57 upd(f);upd(pos); 58 } 59 void splay(int x) 60 { 61 for(int f;f=fa[x];rotate(x)) 62 if(fa[f]) rotate((which(x)==which(f)?f:x)); 63 rt=x; 64 } 65 void Insert(int x) 66 { 67 int pos=rt,f=0; 68 while(1) 69 { 70 if(val[pos]==x) {cnt[pos]++,upd(pos);upd(f);splay(pos);return ;} 71 f=pos,pos=ch[pos][x>val[pos]]; 72 if(!pos) 73 { 74 ch[++sz][0]=ch[sz][1]=0,fa[sz]=f,val[sz]=x,cnt[sz]=size[sz]=1,ch[f][x>val[f]]=sz; 75 upd(f);splay(sz);return ; 76 } 77 } 78 } 79 void insert(int x) 80 { 81 if(!rt) {val[++sz]=x,ch[sz][0]=ch[sz][1]=fa[sz]=0,cnt[sz]=size[sz]=1,rt=sz;return;} 82 Insert(x); 83 } 84 int find_rank(int x) 85 { 86 int res=0,pos=rt; 87 while(1) 88 { 89 if(x<val[pos]) pos=ch[pos][0]; 90 else 91 { 92 res+=size[ch[pos][0]]; 93 if(val[pos]==x) {splay(pos);return res+1;} 94 res+=cnt[pos],pos=ch[pos][1]; 95 } 96 } 97 } 98 void dlt(int x) 99 { 100 if(cnt[rt]>1) {cnt[rt]--;return ;} 101 if(!ch[rt][0]&&!ch[rt][1]) {rt=0;return ;} 102 if(!ch[rt][0]||!ch[rt][1]) 103 { 104 int k=!ch[rt][1]?0:1; 105 rt=ch[rt][k],fa[rt]=0; 106 return ; 107 } 108 int k=find_pre(),tmp=rt; 109 splay(k);fa[ch[tmp][1]]=rt; 110 ch[rt][1]=ch[tmp][1],rt=k; 111 } 112 int n,m,nxt[MAXN<<1],fst[MAXN],to[MAXN<<1],Val[MAXN<<1],Cnt; 113 int f[MAXN][20],dep[MAXN],s[MAXN],k[MAXN],tot,hsh[MAXN],HSH[MAXN],vis[MAXN]; 114 ll ans,dis[MAXN]; 115 void add(int u,int v,int w) {nxt[++Cnt]=fst[u],fst[u]=Cnt,to[Cnt]=v,Val[Cnt]=w;} 116 void dfs(int x) 117 { 118 for(int i=1;(1<<i)<=dep[x];i++) f[x][i]=f[f[x][i-1]][i-1]; 119 hsh[x]=++tot,HSH[tot]=x; 120 for(int i=fst[x];i;i=nxt[i]) 121 if(to[i]!=f[x][0]) 122 { 123 dis[to[i]]=dis[x]+Val[i],dep[to[i]]=dep[x]+1; 124 f[to[i]][0]=x;dfs(to[i]); 125 } 126 } 127 int lca(int u,int v) 128 { 129 int t; 130 if(dep[u]<dep[v]) swap(u,v); 131 t=dep[u]-dep[v]; 132 for(int i=0;i<20;i++) 133 if((1<<i)&t) u=f[u][i]; 134 if(u==v) return u; 135 for(int i=19;i>=0;i--) 136 if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i]; 137 return f[u][0]; 138 } 139 inline ll calc(int u,int v) {return dis[HSH[u]]+dis[HSH[v]]-(dis[lca(HSH[u],HSH[v])]<<1);} 140 int main() 141 { 142 n=read(),m=read();int a,b,c; 143 for(int i=1;i<n;i++) {a=read(),b=read(),c=read();add(a,b,c);add(b,a,c);} 144 dfs(1); 145 while(m--) 146 { 147 c=hsh[read()],vis[c]^=1; 148 if(vis[c]) 149 { 150 if(!rt) {puts("0");insert(c);continue;}insert(c); 151 a=val[find_pre()],b=val[find_sub()]; 152 if(!a) a=val[find_max()];if(!b) b=val[find_min()]; 153 if(!a) a=val[rt];if(!b) b=val[rt]; 154 ans+=calc(a,c)+calc(b,c)-calc(a,b); 155 } 156 else 157 { 158 a=find_rank(c); 159 a=val[find_pre()],b=val[find_sub()]; 160 if(!a) a=val[find_max()];if(!b) b=val[find_min()]; 161 if(!a) a=val[rt];if(!b) b=val[rt]; 162 ans-=calc(a,c)+calc(b,c)-calc(a,b); 163 dlt(c); 164 } 165 printf("%lld ",ans); 166 } 167 }