题解
LCT的模板题之一。
与《线段树 2》一样,同时维护乘法标记和加法标记,并约定“先乘后加”:对值 $x$ 施加一组标记后得到 $x \times mul + add$。
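注意 $51061^2 = 2607225721$,已超过 `int` 的上限 $2^{31}-1 = 2147483647$,相乘会溢出。改用 `unsigned int`(上限 $2^{32}-1$)只能容下单次乘积,而更新区间和时的 `sum * mul + size * add` 仍可能超过 $2^{32}$,所以推荐直接使用 `long long`。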
代码
#include<bits/stdc++.h>
using namespace std;
using ll = long long;              // 64-bit type so products of two residues cannot overflow
constexpr int maxn = 100000 + 5;   // array capacity; vertices are 1-indexed and slot 0 is the null sentinel
constexpr int mod = 51061;         // all stored values, sums and tags are kept modulo this
namespace LCT
{
// ---------------------------------------------------------------------------
// Link-cut tree over vertices 1..n; index 0 is a null sentinel whose sum/size
// stay 0 because it is never pushed up, pushed down or tagged.
// Supported path operations (via split): x -> x * mul + add on every value of
// a path, and the sum of a path's values, all modulo `mod`.
//
// Lazy-tag convention (tags live on the node, not yet on its children):
//   * flip[o]           - o's two children have NOT been swapped yet.
//   * tagsum/tagmul[o]  - already applied to val[o]/sum[o]; still pending for
//                         o's children. The identity tag is (mul = 1, add = 0).
// ---------------------------------------------------------------------------

// PartI. Splay
int fa[maxn];      // splay parent, or path-parent when o is the root of its splay
int ch[2][maxn];   // ch[0][o] / ch[1][o]: left / right splay child (0 = none)
int size[maxn];    // number of nodes in o's splay subtree
int st[maxn];      // scratch stack used by splay() to push tags top-down
int sz;            // current height of st
ll val[maxn];      // value stored at vertex o
ll sum[maxn];      // sum of val over o's splay subtree
ll tagsum[maxn];   // pending additive tag for o's children
ll tagmul[maxn];   // pending multiplicative tag for o's children
bool flip[maxn];   // pending left/right swap of o's children

// True when o is a real child of fa[o] (so o is not the root of its splay);
// false when fa[o] is only a path-parent pointer or 0.
inline bool notroot(int o) {
    return ch[0][fa[o]] == o || ch[1][fa[o]] == o;
}

// Apply x -> x * mul + add to every value in o's splay subtree: o's own value
// and aggregate are updated now, the transform is composed into o's tags
// ("multiply first, then add") for later propagation to its children.
inline void settag(int o, ll add, ll mul) {
    val[o] = (val[o] * mul + add) % mod;
    sum[o] = (sum[o] * mul + size[o] * add) % mod;
    tagsum[o] = (tagsum[o] * mul + add) % mod;
    tagmul[o] = tagmul[o] * mul % mod;
}

// Push o's pending flip and arithmetic tags one level down to its children.
inline void pushdown(int o) {
    if (flip[o]) {
        if (ch[0][o]) flip[ch[0][o]] ^= 1;
        if (ch[1][o]) flip[ch[1][o]] ^= 1;
        std::swap(ch[0][o], ch[1][o]);
        flip[o] = false;
    }
    // Propagating the identity tag would be a no-op (every stored value is
    // already reduced mod `mod`), so skip it.
    if (tagmul[o] != 1 || tagsum[o] != 0) {
        if (ch[0][o]) settag(ch[0][o], tagsum[o], tagmul[o]);
        if (ch[1][o]) settag(ch[1][o], tagsum[o], tagmul[o]);
        tagmul[o] = 1;
        tagsum[o] = 0;
    }
}

// Recompute o's aggregates from its children (child 0 contributes nothing).
inline void pushup(int o) {
    sum[o] = (sum[ch[0][o]] + sum[ch[1][o]] + val[o]) % mod;
    size[o] = size[ch[0][o]] + size[ch[1][o]] + 1;
}

// Rotate x above its parent y, preserving the in-order sequence.
// Caller must have pushed down x and y beforehand.
inline void rotate(int x) {
    int y = fa[x];
    int z = fa[y];
    int d = (ch[1][y] == x);   // side of y on which x hangs
    // Put x where y was. If y was a splay root, z is only a path-parent and
    // must not gain x as a real child -- but x always inherits the pointer.
    if (notroot(y)) ch[ch[1][z] == y][z] = x;
    fa[x] = z;
    // x's inner subtree moves over to y.
    ch[d][y] = ch[d ^ 1][x];
    if (ch[d][y]) fa[ch[d][y]] = y;
    // y becomes x's child on the opposite side.
    ch[d ^ 1][x] = y;
    fa[y] = x;
    pushup(y);
    pushup(x);
}

// Splay x to the root of its splay tree. Tags on the path from the splay
// root down to x are pushed first so every rotation sees correct children.
inline void splay(int x) {
    int o = x;
    st[sz = 1] = o;
    while (notroot(o)) st[++sz] = o = fa[o];
    while (sz) pushdown(st[sz--]);
    while (notroot(x)) {
        int y = fa[x];
        int z = fa[y];
        if (notroot(y)) {
            // zig-zig: rotate the parent first; zig-zag: rotate x twice.
            bool sameSide = (ch[0][z] == y) == (ch[0][y] == x);
            rotate(sameSide ? y : x);
        }
        rotate(x);
    }
    pushup(x);
}

// PartII. LCT
// Make the path from the tree root to x a single splay tree that contains
// no node deeper than x.
inline void access(int x) {
    for (int y = 0; x; y = x, x = fa[x]) {
        splay(x);
        ch[1][x] = y;   // replace x's deeper part with the path built so far
        pushup(x);
    }
}

// Re-root x's tree at x by reversing the root-to-x path.
inline void makeroot(int x) {
    access(x);
    splay(x);
    flip[x] ^= 1;
}

// Expose the path x..y as one splay tree rooted at y, so sum[y] and
// settag(y, ...) cover exactly that path. x and y are presumably connected;
// this is not checked here.
inline void split(int x, int y) {
    makeroot(x);
    access(y);
    splay(y);
}

// Add edge (x, y): x becomes its tree's root and is hung under y through a
// path-parent pointer. Does not check that x and y are in different trees.
inline void link(int x, int y) {
    makeroot(x);
    fa[x] = y;
}

// Remove edge (x, y). Assumes the edge exists, so after split(x, y) the path
// is exactly {x, y} with x as y's left child; this is not verified here.
inline void cut(int x, int y) {
    split(x, y);
    fa[x] = 0;
    ch[0][y] = 0;
    pushup(y);
}
}
int n, m;
// Read the vertex count n and operation count m, give every vertex 1..n the
// value 1 with a single-node aggregate and an identity multiplicative tag,
// then read the n - 1 edges and link them into the link-cut tree.
inline void Init()
{
    scanf("%d %d", &n, &m);
    for (int id = 1; id <= n; ++id) {
        LCT::val[id] = LCT::sum[id] = 1;
        LCT::size[id] = 1;
        LCT::tagmul[id] = 1;
    }
    for (int edgesLeft = n - 1; edgesLeft > 0; --edgesLeft) {
        int a, b;
        scanf("%d %d", &a, &b);
        LCT::link(a, b);
    }
}
char opt[5]; int u, v, c, u2, v2;
using namespace LCT;
void inline Solve()
{
while(m--) {
scanf("%s %d %d", opt, &u, &v);
if(opt[0] == '+') {
scanf("%d", &c);
split(u, v);
settag(v, c, 1);
} else if(opt[0] == '*') {
scanf("%d", &c);
split(u, v);
settag(v, 0, c);
} else if(opt[0] == '-') {
scanf("%d %d", &u2, &v2);
cut(u, v);
link(u2, v2);
} else {
split(u, v);
printf("%lld
", sum[v]);
}
}
}
// Program entry: build the tree from stdin, then run every operation.
// Falling off the end of main returns 0.
int main()
{
    Init();
    Solve();
}