传送门
这道题难点不在 lct,而是传递 lazy 标记,和之前做的那道维护数列一样,核心都是传递平衡树懒标记。
之前那道做对了的,这道没注意乘 0 的情况,调了很久。Orz。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=1e5+10;
const int mod=51061;
int n,m;
struct LCT{
LL val[N],sum[N],add[N],mul[N];
int rev[N],ch[N][2],fa[N],siz[N];
#define ls(x) ch[x][0]
#define rs(x) ch[x][1]
inline int ident(int p,int f){return ch[f][1]==p;}
inline void connect(int p,int f,int k){ch[f][k]=p;fa[p]=f;}
inline int ntroot(int p){return ls(fa[p])==p||rs(fa[p])==p;}
void pushup(int p){
sum[p]=(sum[ls(p)]+sum[rs(p)]+val[p])%mod;
siz[p]=siz[ls(p)]+siz[rs(p)]+1;
}
void rotate(int p){
int f=fa[p],ff=fa[f],k=ident(p,f);
connect(ch[p][k^1],f,k);
fa[p]=ff;
if(ntroot(f)) ch[ff][ident(f,ff)]=p;
connect(f,p,k^1);
pushup(f),pushup(p);
}
void Rev(int p){swap(ls(p),rs(p));rev[p]^=1;}
void Mul(int p,LL x){(sum[p]*=x)%=mod;(add[p]*=x)%=mod;(val[p]*=x)%=mod;(mul[p]*=x)%=mod;}
void Add(int p,LL x){(sum[p]+=siz[p]*x)%=mod;(val[p]+=x)%=mod;(add[p]+=x)%=mod;}
void pushdw(int p){
if(rev[p]) Rev(ls(p)),Rev(rs(p));
if(mul[p]!=1) Mul(ls(p),mul[p]),Mul(rs(p),mul[p]);
if(add[p]) Add(ls(p),add[p]),Add(rs(p),add[p]);
rev[p]=add[p]=0;mul[p]=1;
}
void pushall(int p){if(ntroot(p)) pushall(fa[p]);pushdw(p);}
void splay(int p){
pushall(p);
while(ntroot(p)){
int f=fa[p],ff=fa[f];
if(ntroot(f)) ident(p,f)^ident(f,ff)?rotate(p):rotate(f);
rotate(p);
}
pushup(p);
}
void access(int p){for(int t=0;p;t=p,p=fa[p]) splay(p),rs(p)=t,pushup(p);}
void makert(int p){access(p);splay(p);Rev(p);}
int findrt(int p){access(p);splay(p);while(ls(p)) p=ls(p),pushdw(p);splay(p);return p;}
void link(int p,int q){makert(p);fa[p]=q;}
void split(int p,int q){makert(p);access(q);splay(q);}
void cut(int p,int q){split(p,q);if(!rs(p)&&ls(q)==p) fa[p]=ls(q)=0;}
}lct;
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) lct.mul[i]=lct.val[i]=lct.siz[i]=1;
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
lct.link(u,v);
}
for(int i=1,x,y,z,k;i<=m;i++){
char opt[10];scanf("%s",opt);
if(opt[0]=='+'){
scanf("%d%d%d",&x,&y,&z);
lct.split(x,y);lct.Add(y,z);
}
else if(opt[0]=='-'){
scanf("%d%d%d%d",&x,&y,&z,&k);
if(lct.findrt(x)==lct.findrt(y)) lct.cut(x,y);
if(lct.findrt(z)!=lct.findrt(k)) lct.link(z,k);
}
else if(opt[0]=='*'){
scanf("%d%d%d",&x,&y,&z);
lct.split(x,y);lct.Mul(y,z);
}
else if(opt[0]=='/'){
scanf("%d%d",&x,&y);
lct.split(x,y);
printf("%lld
",lct.sum[y]);
}
}
return 0;
}