zoukankan      html  css  js  c++  java
  • JZOJ 7036. 2021.03.30【2021省赛模拟】凌乱平衡树(平衡树单旋+权值线段树)

    JZOJ 7036. 2021.03.30【2021省赛模拟】凌乱平衡树

    题目大意

    • 给出两棵Treap,大小分别为 n , m n,m n,m,每个点的 p r i o r i t y priority priority值为子树大小(因此满足大根堆性质), Q Q Q次修改(修改是永久的),每次单旋一个节点,求修改前和每次修改后后两树合并之后的所有节点深度之和。合并按照Treap的合并方式,左树根为 x x x,右树根为 y y y时,当 s i z e x ≥ s i z e y size_xge size_y sizexsizey时以 x x x为根,否则反之。
    • 1 ≤ n , m , Q ≤ 2 ∗ 1 0 5 1le n,m,Qle2*10^5 1n,m,Q2105

    题解

    • 考虑合并的过程,记录当前深度 d p dp dp,左树根每次向右走,就加上左儿子 F + G ∗ d p F+G*dp F+Gdp,含义是所有点到左子树根的深度加上到实际的根深度差值。右边同理。这样需要在每次单旋后重新计算每个子树的大小 G G G及以该子树根为根的深度和 F F F G G G可以在常数复杂度内维护,但 F F F不行。
    • 换一种思路,记录总的深度和 s u m sum sum,每次求出合并后增加的差值。这样合并的过程中,左树根每次向右走,则加上右树根的 G G G,含义是它子树内所有点的深度都会被增加 1 1 1。右边同理。
    • 而合并时左树根始终向右,右树根始终向左,其它的节点是不会经过的,且与它相关的值也不会调用到,所以可以把左根向右和右根向左两条链(以下称为链)单独看,设链上 G G G序列左边依次为 A A A,右边为 B B B A i A_i Ai对答案的贡献次数为 ( A i , A i − 1 ] (A_i,A_{i-1}] (Ai,Ai1] B B B的个数, B i B_i Bi对答案的贡献次数为 [ B i , B i − 1 ) [B_i,B_{i-1}) [Bi,Bi1) A A A的个数,注意这里区间的开闭情况。
    • 那么可以用权值线段树维护,把初始的 A A A B B B都存进同一棵权值线段树中,在单旋时进行修改。
    • 只有两种情况需要修改:
    • 1、单旋的节点 x x x x x x的父亲都在链中;
    • 2、单旋的节点 x x x不在链中, x x x的父亲在链中。
    • 修改时因为 G G G值会改变,所以需要先删除该点及其贡献,修改完 G G G后再加入回来。修改的贡献不仅有它自己的贡献,还有 A A A B B B中它们前驱的贡献。

    代码

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    #define N 200010
    #define ll long long
    struct {
    	int p[2];
    }f[N * 4];
    int ns;
    ll ans;
    void is(int v, int l, int r, int x, int o, int c) {
    	if(l == r) {
    		f[v].p[o] += c;
    	}
    	else {
    		int mid = (l + r) / 2;
    		if(x <= mid) is(v * 2, l, mid, x, o, c); else is(v * 2 + 1, mid + 1, r, x, o, c);
    		f[v].p[o] = f[v * 2].p[o] + f[v * 2 + 1].p[o];
    	}
    }
    int get(int v, int l, int r, int x, int y, int o) {
    	if(x > y) return 0;
    	if(l == x && r == y) return f[v].p[o];
    	int mid = (l + r) / 2;
    	if(y <= mid) return get(v * 2, l, mid, x, y, o);
    	if(x > mid) return get(v * 2 + 1, mid + 1, r, x, y, o);
    	return get(v * 2, l, mid, x, mid, o) + get(v * 2 + 1, mid + 1, r, mid + 1, y, o);
    }
    int find(int v, int l, int r, int x, int y, int k, int o) {
    	if(f[v].p[o] < k || x > y) return -1;
    	if(l == r) return l;
    	int mid = (l + r) / 2;
    	if(y <= mid) return find(v * 2, l, mid, x, y, k, o);
    	if(x > mid) return find(v * 2 + 1, mid + 1, r, x, y, k, o);
    	int s = get(v * 2, l, mid, x, mid, o);
    	if(s >= k) return find(v * 2, l, mid, x, mid, k, o);
    	return find(v * 2 + 1, mid + 1, r, mid + 1, y, k - s, o);
    }
    int find0(int v, int l, int r, int x, int y, int k, int o) {
    	if(f[v].p[o] < k || x > y) return -1;
    	if(l == r) return l;
    	int mid = (l + r) / 2;
    	if(y <= mid) return find0(v * 2, l, mid, x, y, k, o);
    	if(x > mid) return find0(v * 2 + 1, mid + 1, r, x, y, k, o);
    	int s = get(v * 2 + 1, mid + 1, r, mid + 1, y, o);
    	if(s >= k) return find0(v * 2 + 1, mid + 1, r, mid + 1, y, k, o);
    	return find0(v * 2, l, mid, x, mid, k - s, o);
    }
    ll count(int x, int o) {
    	if(!o) {
    		int t = find(1, 1, ns, x, ns, 2, 0);
    		if(t == -1) t = ns;
    		return (ll)get(1, 1, ns, x + 1, t, 1) * x;
    	}
    	else {
    		int t = find(1, 1, ns, x, ns, 2, 1);
    		if(t == -1) t = ns + 1;
    		return (ll)get(1, 1, ns, x, t - 1, 0) * x;
    	}
    }
    int fr(int x, int o) {
    	if(!o) {
    		int t = find0(1, 1, ns, 1, x, 1, 1);
    		return t == -1 ? 0 : t;
    	}
    	else {
    		int t = find0(1, 1, ns, 1, x - 1, 1, 0);
    		return t == -1 ? 0 : t;
    	}
    }
    struct {
    	int s, rt, p[N];
    	ll F[N], si[N], sum;
    	struct {
    		int s[2], fa, p;	
    	}f[N];
    	void ins(int r, int l, int i) {
    		f[i].s[0] = l, f[i].s[1] = r;
    		f[l].fa = f[r].fa = i;
    		f[l].p = 0, f[r].p = 1;
    	}
    	void ro(int x, int o) {
    		int y = f[x].fa, z = f[y].fa, py = f[x].p, pz = f[y].p;
    		f[z].s[pz] = x, f[x].fa = z, f[x].p = pz;
    		f[y].s[py] = f[x].s[py ^ 1], f[f[x].s[py ^ 1]].fa = y, f[f[x].s[py ^ 1]].p = py;
    		f[x].s[py ^ 1] = y, f[y].fa = x, f[y].p = py ^ 1;
    		if(rt == y) rt = x;
    		int tp;
    		if(p[y] && p[x]) {
    			ans -= count(si[x], o) + count(si[y], o);
    			ans -= fr(si[y], o) + fr(si[x], o);
    			tp = find0(1, 1, ns, 1, si[x], 2, o);
    			if(tp > 0) ans -= count(tp, o);
    			is(1, 1, ns, si[y], o, -1);
    			is(1, 1, ns, si[x], o, -1);
    		}
    		else if(p[y] && !p[x]) {
    			ans -= count(si[y], o);
    			ans -= fr(si[y], o);
    			tp = find0(1, 1, ns, 1, si[y], 2, o);
    			if(tp > 0) ans -= count(tp, o);
    			is(1, 1, ns, si[y], o, -1);
    		}
    		
    		si[y] = si[f[y].s[0]] + si[f[y].s[1]] + 1;
    		si[x] = si[f[x].s[0]] + si[f[x].s[1]] + 1;
    		sum += si[f[y].s[py ^ 1]] - si[f[x].s[py]];
    		
    		if(p[y] && p[x]) {
    			is(1, 1, ns, si[x], o, 1);
    			ans += fr(si[x], o);
    			ans += count(si[x], o);
    			if(tp > 0) ans += count(tp, o);
    			p[y] = 0;
    		}
    		else if(p[y] && !p[x]) {
    			is(1, 1, ns, si[x], o, 1);
    			is(1, 1, ns, si[y], o, 1);
    			ans += fr(si[y], o) + fr(si[x], o);
    			ans += count(si[y], o) + count(si[x], o);
    			if(tp > 0) ans += count(tp, o);
    			p[x] = 1;
    		}
    	}
    	int find() {
    		for(int i = 1; i <= s; i++) if(f[i].fa == 0) return i;
    	}
    	void dfs(int k) {
    		F[k] = 1, si[k] = 1;
    		if(f[k].s[0]) dfs(f[k].s[0]), si[k] += si[f[k].s[0]], F[k] += F[f[k].s[0]] + si[f[k].s[0]];
    		if(f[k].s[1]) dfs(f[k].s[1]), si[k] += si[f[k].s[1]], F[k] += F[f[k].s[1]] + si[f[k].s[1]];
    	}
    }a, b;
    void solve() {
    	int x = a.rt;
    	while(x) is(1, 1, ns, a.si[x], 0, 1), a.p[x] = 1, x = a.f[x].s[1];
    	x = b.rt;
    	while(x) is(1, 1, ns, b.si[x], 1, 1), b.p[x] = 1, x = b.f[x].s[0];
    	ans = 0;
    	x = a.rt;
    	while(x) ans += count(a.si[x], 0) ,x = a.f[x].s[1];
    	x = b.rt;
    	while(x) ans += count(b.si[x], 1), x = b.f[x].s[0];
    	printf("%lld
    ", ans + a.sum + b.sum);
    }
    int read() {
    	int s = 0;
    	char x = getchar();
    	while(x < '0' || x > '9') x = getchar();
    	while(x >= '0' && x <= '9') s = s * 10 + x - 48, x = getchar();
    	return s;
    }
    int main() {
    	int Q, i;
    	scanf("%d%d", &a.s, &b.s);
    	for(i = 1; i <= a.s; i++) {
    		a.ins(read(), read(), i);
    	}
    	for(i = 1; i <= b.s; i++) {
    		b.ins(read(), read(), i);
    	}
    	ns = max(a.s, b.s) + 1;
    	a.rt = a.find(), b.rt = b.find(); 
    	a.dfs(a.rt), b.dfs(b.rt);
    	a.sum = a.F[a.rt], b.sum = b.F[b.rt];
    	scanf("%d", &Q);
    	solve();
    	while(Q--) {
    		if(read() == 1) a.ro(read(), 0); else b.ro(read(), 1);
    		printf("%lld
    ", ans + a.sum + b.sum);
    	}
    	return 0;
    }
    

    自我小结

    • 细节比较多,各条语句中的顺序很重要,需要理清楚。
    哈哈哈哈哈哈哈哈哈哈
  • 相关阅读:
    JavaScript——类型检测
    JavaScript——语法与数据类型
    .NET下使用 Seq结构化日志系统
    Vs Code搭建 TypeScript 开发环境
    Entity Framework Core一键生成实体命令
    使用TestServer测试ASP.NET Core API
    Entity Framework Core导航属性加载问题
    Autofac创建实例的方法总结
    .NET Exceptionless 日志收集框架本地环境搭建
    依赖注入和控制反转
  • 原文地址:https://www.cnblogs.com/LZA119/p/14608423.html
Copyright © 2011-2022 走看看