题目链接:https://codeforces.com/contest/1354
想法:
很明显的权值线段树(值域线段树)板子题。
#include <algorithm>
#include <string>
#include <cstring>
#include <vector>
#include <map>
#include <stack>
#include <set>
#include <queue>
#include <cmath>
#include <cstdio>
#include <iomanip>
#include <ctime>
#include <bitset>
#include <sstream>
#include <iostream>
#include <unordered_map>

#define ll long long
#define ull unsigned long long
#define ls nod<<1
#define rs (nod<<1)+1
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define INF 0x3f3f3f3f3f3f3f3f
// Arguments are fully parenthesized so compound expressions expand safely.
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))

const double eps = 1e-10;
const int maxn = 1e6 + 10;
const ll MOD = 99999999999999;

// Sign of a with tolerance eps: -1, 0 or 1.
int sgn(double a) { return a < -eps ? -1 : a < eps ? 0 : 1; }

using namespace std;

// a[x] = how many elements of the initial multiset have value x.
int a[maxn];

// Value-indexed segment tree: node covering [l, r] stores how many elements
// of the multiset have a value inside [l, r].
struct val_segment_tree {
    int val;
} tree[maxn << 2];

// Builds the subtree rooted at nod over the value range [l, r] from a[].
// An empty range (l > r) is ignored instead of recursing forever.
void build(int nod, int l, int r) {
    if (l > r) return;
    if (l == r) {
        tree[nod].val = a[l];  // number of elements whose value is l
        return;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    tree[nod].val = tree[ls].val + tree[rs].val;
}

// Adds v occurrences of value t. A value outside [l, r] is ignored; without
// this check the descent would end on an unrelated leaf and corrupt its count.
void add(int nod, int l, int r, int t, int v) {
    if (t < l || t > r) return;
    if (l == r) {
        tree[nod].val += v;
        return;
    }
    int mid = (l + r) >> 1;
    if (t <= mid) {
        add(ls, l, mid, t, v);
    } else {
        add(rs, mid + 1, r, t, v);
    }
    tree[nod].val = tree[ls].val + tree[rs].val;
}

// Removes v occurrences of the num-th smallest element (1-based) stored in
// [l, r]. Does nothing when num is not in 1..count, so an out-of-range rank
// can no longer delete the largest element by accident.
void del(int nod, int l, int r, int num, int v) {
    if (num < 1 || num > tree[nod].val) return;
    if (l == r) {
        if (tree[nod].val) tree[nod].val -= v;
        return;
    }
    int mid = (l + r) >> 1;
    // Skip the left subtree when it holds fewer than num elements.
    if (num > tree[ls].val) {
        del(rs, mid + 1, r, num - tree[ls].val, v);
    } else {
        del(ls, l, mid, num, v);
    }
    tree[nod].val = tree[ls].val + tree[rs].val;
}

// Returns the value of the num-th smallest element (1-based) stored in
// [l, r], or -1 when no such element exists.
int query(int nod, int l, int r, int num) {
    if (num < 1 || num > tree[nod].val) return -1;
    if (l == r) return l;
    int mid = (l + r) >> 1;
    if (num <= tree[ls].val) return query(ls, l, mid, num);
    return query(rs, mid + 1, r, num - tree[ls].val);
}

// Returns the smallest value currently present in [l, r], or 0 when the
// multiset is empty.
int find(int nod, int l, int r) {
    if (l == r) {
        if (tree[nod].val) return l;
        return 0;
    }
    int mid = (l + r) >> 1;
    if (tree[ls].val) return find(ls, l, mid);
    return find(rs, mid + 1, r);
}

// Reads n initial values and q queries. A positive query k inserts k; a
// negative query removes the |k|-th smallest element. Prints any remaining
// element (the smallest one), or 0 if nothing is left.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, q;
    if (!(cin >> n >> q) || n < 1) {
        // No value range to work with: the multiset is empty.
        cout << 0 << '\n';
        return 0;
    }
    if (n >= maxn) {
        cerr << "n exceeds the supported maximum" << '\n';
        return 1;
    }

    for (int i = 1; i <= n; i++) {
        int v;
        cin >> v;
        // Only values inside the tree's range [1, n] can be represented.
        if (v >= 1 && v <= n) a[v]++;
    }
    build(1, 1, n);

    while (q-- > 0) {
        int k;
        if (!(cin >> k)) break;
        if (k < 0) {
            // Negate in 64 bits so the most negative int cannot overflow.
            const ll order = -static_cast<ll>(k);
            if (order <= tree[1].val) del(1, 1, n, static_cast<int>(order), 1);
        } else {
            add(1, 1, n, k, 1);
        }
    }

    cout << find(1, 1, n) << '\n';
    return 0;
}