又是历史遗留题,收藏了好久才做的(
https://www.luogu.com.cn/problem/P5559
考虑如何计算一个点 (u) 到链 ((x,y)) 的距离,设 (operatorname{LCA}(x,y)=lca),则距离即 (dis(u,lca)) 减去到 (lca) 的路径上与这条链重合的部分的长度
进一步的,就是 (dis(u,root)+dis(lca,root)-2dis(operatorname{LCA}(u,lca),root)) 减去 (u) 到 (lca) 的路径上与这条链重合的部分的长度
那么这个重合长度如何求?其实就是 (u) 到 (root) 的路径,与这个链重合的长度,为了方便的维护这个值,可以使用树剖,就是从 (u) 到 (root) 每个边(其实是点,边信息下放到点)都加一,但这个加一并不是值加一,而是每个点对应的边的长度乘的系数加一,也就是没个点分别加上它所对应的边的长度,用线段树维护
设这个长度为 (len),每一个点 (u) 的答案贡献就是 (dis(u,root)+dis(lca,root)-2dis(operatorname{LCA}(u,lca),root)-len),这个 (dis(u,root)) 很好求,就用一个全局变量维护就行,(dis(lca,root)) 自然也不用说
对于 (dis(operatorname{LCA}(u,lca),root)),可以通过上面说的维护重合长度的方式求出,就是用线段树访问 (root) 到 (lca) 的距离和即可求出 (sum dis(operatorname{LCA}(u,lca),root))
可以画图理解一下,会发现从不同位置到它和 (lca) 的 (operatorname{LCA}) 的点,所走过的路径,总会“拼成” (root) 到 (lca) 的路径
还有就是要开 long long
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<iomanip>
#include<cstring>
#define reg register
#define EN puts("")
inline int read(){
register int x=0;register int y=1;
register char c=std::getchar();
while(c<'0'||c>'9'){if(c=='-') y=0;c=std::getchar();}
while(c>='0'&&c<='9'){x=x*10+(c^48);c=std::getchar();}
return y?x:-x;
}
#define N 200006
#define M 400006
int n;
struct graph{
int fir[N],nex[M],to[M],w[M],tot;
inline void add(int u,int v,int W){
to[++tot]=v;w[tot]=W;
nex[tot]=fir[u];fir[u]=tot;
}
}G;
int size[N],son[N],val[N],deep[N];
long long sum[N];
int fa[22][N],top[N];
int dfn[N],rank[N],dfscnt;
int yes[N];
void dfs(int u){
size[u]=1;
for(reg int v,i=G.fir[u];i;i=G.nex[i]){
v=G.to[i];
if(v==fa[0][u]) continue;
fa[0][v]=u;deep[v]=deep[u]+1;
sum[v]=sum[u]+G.w[i];
val[v]=G.w[i];
for(reg int j=1;j<20;j++) fa[j][v]=fa[j-1][fa[j-1][v]];
dfs(v);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
void dfs2(int u,int topnow){
top[u]=topnow;
dfn[u]=++dfscnt;rank[dfscnt]=u;
if(!son[u]) return;
dfs2(son[u],topnow);
for(reg int i=G.fir[u];i;i=G.nex[i])if(!dfn[G.to[i]]) dfs2(G.to[i],G.to[i]);
}
inline int getlca(int u,int v){
if(deep[u]<deep[v]) u^=v,v^=u,u^=v;
for(reg int i=19;~i;i--)if(deep[fa[i][u]]>=deep[v]) u=fa[i][u];
if(u==v) return u;
for(reg int i=19;~i;i--)if(fa[i][u]^fa[i][v]) u=fa[i][u],v=fa[i][v];
return fa[0][u];
}
struct tr{
tr *ls,*rs;
int tag;
long long sum,sub;
}dizhi[N*2],*root=&dizhi[0];
int tot;
inline void pushup(tr *tree){
tree->sum=tree->ls->sum+tree->rs->sum;
tree->sub=tree->ls->sub+tree->rs->sub;
}
void build(tr *tree,int l,int r){
if(l==r) return tree->sum=val[rank[l]],void();
int mid=(l+r)>>1;
tree->ls=&dizhi[++tot];tree->rs=&dizhi[++tot];
build(tree->ls,l,mid);build(tree->rs,mid+1,r);
pushup(tree);
}
inline void pushdown(tr *tree){
if(!tree->tag) return;
reg int tag=tree->tag;tree->tag=0;
tree->ls->tag+=tag;tree->rs->tag+=tag;
tree->ls->sub+=tag*tree->ls->sum;
tree->rs->sub+=tag*tree->rs->sum;
}
long long qsub(tr *tree,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr) return tree->sub;
int mid=(l+r)>>1;
pushdown(tree);
long long ret=0;
if(ql<=mid) ret+=qsub(tree->ls,l,mid,ql,qr);
if(qr>mid) ret+=qsub(tree->rs,mid+1,r,ql,qr);
return ret;
}
void change(tr *tree,int l,int r,int ql,int qr,int k){
if(ql<=l&&r<=qr){
tree->tag+=k;
tree->sub+=k*tree->sum;
return;
}
int mid=(l+r)>>1;
pushdown(tree);
if(ql<=mid) change(tree->ls,l,mid,ql,qr,k);
if(qr>mid) change(tree->rs,mid+1,r,ql,qr,k);
pushup(tree);
}
inline long long getsub(reg int x,reg int y){
long long ret=0;
while(top[x]^top[y]){
if(deep[top[x]]<deep[top[y]]) x^=y,y^=x,x^=y;
ret+=qsub(root,1,n,dfn[top[x]],dfn[x]);
x=fa[0][top[x]];
}
if(dfn[x]>dfn[y]) x^=y,y^=x,x^=y;
if(x^y) ret+=qsub(root,1,n,dfn[x]+1,dfn[y]);
return ret;
}
inline void update(reg int x,reg int y,int k){
while(top[x]^top[y]){
if(deep[top[x]]<deep[top[y]]) x^=y,y^=x,x^=y;
change(root,1,n,dfn[top[x]],dfn[x],k);
x=fa[0][top[x]];
}
if(dfn[x]>dfn[y]) x^=y,y^=x,x^=y;
if(x^y) change(root,1,n,dfn[x]+1,dfn[y],k);
}
long long S;
int main(){
n=read();int q=read();read();
for(reg int u,v,w,i=1;i<n;i++){
u=read();v=read();w=read();
G.add(u,v,w);G.add(v,u,w);
}
deep[1]=1;
dfs(1);dfs2(1,1);
int now=0;
build(root,1,n);
for(reg int i=1;i<=n;i++)if(yes[i]=read()) S+=sum[i],now++,update(1,i,1);
reg int op,x,y;
while(q--){
op=read();x=read();
if(op==1){
yes[x]^=1;
if(yes[x]) S+=sum[x],update(1,x,1),now++;
else S-=sum[x],update(1,x,-1),now--;
}
else{
y=read();
int lca=getlca(x,y);
// printf("qsize : %d lca : %d S : %d
",qsize(root,1,n,dfn[lca]),lca,S);
printf("%lld
",S+now*sum[lca]-getsub(1,lca)*2-getsub(x,y));
}
}
return 0;
}