zoukankan      html  css  js  c++  java
  • 【题解】 「NOI2017」整数 线段树+二分+压位 LOJ2302

    Legend

    Link to LOJ

    请维护一个高精二进制数 \(s\),支持操作 \(n\ (0 \le n \le 10^6)\) 次:

    • 加或减 \(a\times 2^{b}\)\((|a|\le 10^9,0\le b\le 30n)\)
    • 查询 \(s \operatorname{and} 2^b\) 的结果转化为 \(\textrm{bool}\) 后是否为真。

    时空 \(\textrm{2s/512MB}\)

    Editorial

    作为 \(\textrm{NOI2017}\) 的第一题,必定是一道良心送温暖题,让我们一起为出题人松松松鼓掌。

    brute

    容易看到底下有一些部分分,映入眼帘的便是 \(|a|=1\),于是我们就想到把每一个加减操作看成 \(O(\log a)\) 次加减单个二进制位。怎么样?是不是看起来简单一点了?

    考虑直接模拟。假设现在做加法的是位置 \(l\),如果这一位是 \(0\) 就直接改成 \(1\),否则即找到之后第一个为 \(0\) 的位置 \(p\ (l < p)\),把 \([l,p-1]\) 都改成 \(0\),并把位置 \(p\) 改成 \(1\)

    减法同理。如果这一位是 \(1\) 就直接改成 \(0\),否则即找到之后第一个为 \(1\) 的位置 \(p\ (l < p)\),把 \([l,p-1]\) 都改成 \(1\),并把位置 \(p\) 改成 \(0\)

    以上两个操作可以直接在线段树上二分找到,打区间覆盖标记。查询则可以直接使用线段树单点查询。

    于是就得到了一个复杂度为 \(O(n \log n\log a)\) 的做法。

    optimization

    上述做法的瓶颈在于:

    • 数组长度是 \(30n\),凭空多出来一个常数。
    • 要进行拆位,\(1\) 个操作变成了 \(\log a\) 个。

    不妨往反方向考虑,把数组压位,连续 \(32\) 个数字用一个 \(\textrm{unsigned int}\) 存储。

    这样子对于一个修改操作我们最多只要拆成两个。而查询连续 \(1\) 段和连续 \(0\) 段依然可以用线段树实现,代码相差无几。

    但这样就可以把复杂度优化到 \(O\left(\dfrac{n \log n \log a}{\omega}\right)\),其中 \(\omega\) 为压位大小。

    Code

    写的时候有点犯迷糊,最开始用 \(\textrm{unsigned int}\) 存了读入的 \(a\),后来又没写线段树的 \(\textrm{pushup pushdown}\),最后发现线段树二分写错了……白白浪费了一个下午+晚上。

    就这样修修补补写出了下面这些东西,有点繁琐了,但还可以看。

    LOJ 上这破烂可以在 \(\textrm{800ms}\) 内跑过。

    // Author : Imakf
    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define LL long long
    #define debug(...) fprintf(stderr ,__VA_ARGS__)
    #define __FILE(x)\
    	freopen(#x".in" ,"r" ,stdin);\
    	freopen(#x".out" ,"w" ,stdout)
    
    LL read(){
    	char k = getchar(); LL x = 0 ,flg = 1;
    	while(k < '0' || k > '9')
    		flg *= k == '-' ? -1 : 1 ,k = getchar();
    	while(k >= '0' && k <= '9')
    		x = x * 10 + k - '0' ,k = getchar();
    	return x * flg;
    }
    
    
    const int MX = 1e6 + 233;
    
    struct node{
    	int l ,r ,c;
    	unsigned int num;
    	bool zero ,all ,cov;
    	node *lch ,*rch;
    }*root;
    
    void pushup(node *x){
    	x->zero = x->lch->zero & x->rch->zero;
    	x->all  = x->lch->all  & x->rch->all;
    }
    
    node *build(int l ,int r){
    	node *x = new node;
    	x->l       = l;
    	x->r       = r;
    	x->zero    = true;
    	x->all     = false;
    	x->cov     = false;
    	x->c       = 0;
    	x->num     = 0;
    	if(l == r){
    		x->lch = nullptr;
    		x->rch = nullptr;
    	}
    	else{
    		int mid = (l + r) >> 1;
    		x->lch = build(l ,mid);
    		x->rch = build(mid + 1 ,r);
    		pushup(x);
    	}return x;
    }
    void docov(node *x ,bool v){
    	x->cov  = true;
    	x->c    = v;
    	x->zero = !v;
    	x->all  = v;
    	x->num  = v ? UINT_MAX : 0;
    }
    void pushdown(node *x){
    	if(x->cov){
    		x->cov = false;
    		docov(x->lch ,x->c);
    		docov(x->rch ,x->c);
    	}
    }
    
    void cov(node *x ,int l ,int r ,bool val){
    	if(l <= x->l && x->r <= r) return docov(x ,val);
    	pushdown(x);
    	if(l <= x->lch->r) cov(x->lch ,l ,r ,val);
    	if(r > x->lch->r) cov(x->rch ,l ,r ,val);
    	return pushup(x);
    }
    
    void add(node *x ,LL v){
    	x->num += v;
    	x->all  = x->num == UINT_MAX;
    	x->zero = x->num == 0;
    }
    
    int __add(node *x ,int l ,int r){ // 找到最小的不是全 1 的 pos
    	if(x->r < l || x->l > r) return 0;
    	if(x->all) return 0;
    	if(x->l == x->r){
    		return add(x ,1) ,x->l;
    	}
    	pushdown(x);
    	int ret = 0;
    	if(x->lch->all || !(ret = __add(x->lch ,l ,r))){
    		ret = __add(x->rch ,l ,r);
    	}
    	pushup(x);
    	return ret;
    }
    
    void add(node *x ,int p ,LL val){
    	if(x->l == x->r){
    		if(x->num + val > UINT_MAX){
    			x->num = (x->num + val) & UINT_MAX;
    			add(x ,0);
    			int pos = __add(root ,p + 1 ,MX);
    			if(pos - 1 >= p + 1) cov(root ,p + 1 ,pos - 1 ,0);
    		}
    		else add(x ,val);
    		return ;
    	}
    	pushdown(x);
    	if(p <= x->lch->r) add(x->lch ,p ,val);
    	else add(x->rch ,p ,val);
    	return pushup(x);
    }
    
    void add(LL a ,LL b){
    	// add a*(2^b)
    	int bit32 = b / 32 ,bit = b % 32;
    	LL f = a << bit;
    	if(f > UINT_MAX){
    		add((f & UINT_MAX) >> bit ,b);
    		add(f >> 32 ,(bit32 + 1) * 32);
    		return ;
    	}
    	// debug("%lld %lld\n" ,a ,b);
    	add(root ,bit32 ,f);
    }
    
    int __del(node *x ,int l ,int r){ // 找到最小的不是全 0 的 pos
    	// debug("Find [%d ,%d] ,allzero = %d\n" ,x->l ,x->r ,x->zero);
    	if(x->r < l || x->l > r) return 0;
    	if(x->zero) return 0;
    	if(x->l == x->r){
    		return add(x ,-1) ,x->l;
    	}
    	pushdown(x);
    	int ret = 0;
    	if(x->lch->zero || !(ret = __del(x->lch ,l ,r))){
    		ret = __del(x->rch ,l ,r);
    	}
    	pushup(x);
    	return ret;
    }
    
    void del(node *x ,int p ,LL val){
    	if(x->l == x->r){
    		if(x->num - val < 0){
    			x->num = x->num - val + UINT_MAX + 1;
    			add(x ,0);
    			int pos = __del(root ,p + 1 ,MX);
    			if(pos - 1 >= p + 1) cov(root ,p + 1 ,pos - 1 ,1);
    		}
    		else add(x ,-val);
    		return ;
    	}
    	pushdown(x);
    	if(p <= x->lch->r) del(x->lch ,p ,val);
    	else del(x->rch ,p ,val);
    	return pushup(x);
    }
    
    void sub(LL a ,LL b){
    	int bit32 = b / 32 ,bit = b % 32;
    	LL f = a << bit;
    	if(f > UINT_MAX){
    		sub((f & UINT_MAX) >> bit ,b);
    		sub(f >> 32 ,(bit32 + 1) * 32);
    		return ;
    	}
    	del(root ,bit32 ,f);
    }
    
    LL query(node *x ,int p){
    	if(x->l == x->r) return x->num;
    	pushdown(x);
    	if(p <= x->lch->r) return query(x->lch ,p);
    	return query(x->rch ,p);
    }
    
    int query(int pos){
    	int bit32 = pos / 32 ,bit = pos % 32;
    	return (query(root ,bit32) >> bit) & 1;
    }
    
    void output(node *x){
    	if(x->l == x->r){
    		for(int i = 0 ; i < 32 ; ++i){
    			debug("%u" ,(x->num >> i) & 1);
    		}
    		return;
    	}
    	pushdown(x);
    	output(x->lch) ,output(x->rch);
    }
    
    int main(){
    	__FILE([NOI2017]整数);
    	
    	int n = read(); read() ,read() ,read();
    	root = build(0 ,MX);
    	for(LL i = 1 ,op ,a ,b ; i <= n ; ++i){
    		// debug("%d\n" ,i);
    		op = read();
    		if(op == 1){
    			a = read() ,b = read();
    			// assert(a >= 0);
    			if(a > 0) add(a ,b);
    			else sub(-a ,b);
    		}
    		else{
    			a = read();
    			printf("%d\n" ,query(a));;
    		}
    		// output(root);
    		// debug("\n");
    	}
    }
    
  • 相关阅读:
    Spring框架学习09——基于AspectJ的AOP开发
    Spring框架学习08——自动代理方式实现AOP
    Spring框架学习07——基于传统代理类的AOP实现
    Spring框架学习06——AOP底层实现原理
    Spring框架学习05——AOP相关术语详解
    SpringMVC框架09——@ResponseBody的用法详解
    Spring框架学习04——复杂类型的属性注入
    Spring框架学习03——Spring Bean 的详解
    Spring框架学习01——使用IDEA开发Spring程序
    sqlserver 迁移
  • 原文地址:https://www.cnblogs.com/imakf/p/13623627.html
Copyright © 2011-2022 走看看