[题目链接]
http://codeforces.com/contest/992/problem/E
[算法]
线段树 + 二分
时间复杂度 : O(N log^2 N)
[代码]
#include <bits/stdc++.h>

// Maintains an array under point assignments and, after every assignment,
// prints one index i such that value[i] equals value[1] + ... + value[i-1]
// (or -1 when no such index exists).
//
// Approach: a segment tree storing range maximum and range sum.  Starting
// with prefix sum `pre = 0`, repeatedly jump to the first position after the
// current one whose value is >= pre; only such a position can be an answer,
// because every skipped element is smaller than the prefix sum before it.

typedef long long ll;

const int MAXN = 2e5 + 10;

// One segment-tree node covering the closed index range [l, r].
struct Node {
    int l, r;
    ll mx;   // maximum of value[l..r]
    ll sum;  // sum of value[l..r]
} Tree[MAXN << 2];

// Current array contents, 1-indexed.
ll value[MAXN];

// Reads one (optionally negative) decimal integer from stdin into x.
// Every '-' seen while skipping non-digit characters flips the sign.
// Returns false when EOF is reached before any digit was found (x is then 0),
// true otherwise.
template <typename T>
inline bool read(T &x) {
    int sign = 1;
    x = 0;
    // Keep the character in an int: getchar() returns int so that EOF is
    // distinguishable, and isdigit() is only defined for EOF/unsigned char.
    int c = std::getchar();
    for (; c != EOF && !std::isdigit(c); c = std::getchar()) {
        if (c == '-') sign = -sign;
    }
    if (c == EOF) return false;
    for (; c != EOF && std::isdigit(c); c = std::getchar()) {
        x = x * 10 + (c - '0');
    }
    x *= sign;
    return true;
}

// Recomputes mx and sum of node `index` from its two children.
inline void update(int index) {
    Tree[index].mx = std::max(Tree[index << 1].mx, Tree[index << 1 | 1].mx);
    Tree[index].sum = Tree[index << 1].sum + Tree[index << 1 | 1].sum;
}

// Builds the subtree rooted at `index` over [l, r] from the global `value`.
void build(int index, int l, int r) {
    Tree[index].l = l;
    Tree[index].r = r;
    if (l == r) {
        Tree[index].mx = value[l];
        Tree[index].sum = value[l];
        return;
    }
    const int mid = (l + r) >> 1;
    build(index << 1, l, mid);
    build(index << 1 | 1, mid + 1, r);
    update(index);
}

// Sets the leaf for position `pos` to `val` and refreshes its ancestors.
// `pos` must lie inside the range covered by node `index`.
void modify(int index, int pos, ll val) {
    if (Tree[index].l == Tree[index].r) {
        Tree[index].mx = Tree[index].sum = val;
        return;
    }
    const int mid = (Tree[index].l + Tree[index].r) >> 1;
    if (mid >= pos) {
        modify(index << 1, pos, val);
    } else {
        modify(index << 1 | 1, pos, val);
    }
    update(index);
}

// Returns the smallest position p in [l, r] with value[p] >= val, or -1 if
// there is none.  [l, r] must lie inside the range covered by node `index`.
int query(int index, int l, int r, ll val) {
    const int mid = (Tree[index].l + Tree[index].r) >> 1;
    if (Tree[index].l == l && Tree[index].r == r) {
        // The node's maximum decides immediately whether a hit exists here.
        if (Tree[index].mx < val) return -1;
        if (l == r) return l;
        // A hit exists; prefer the left half to get the smallest position.
        if (Tree[index << 1].mx >= val) {
            return query(index << 1, l, mid, val);
        }
        return query(index << 1 | 1, mid + 1, r, val);
    }
    if (mid >= r) return query(index << 1, l, r, val);
    if (mid + 1 <= l) return query(index << 1 | 1, l, r, val);
    // The range straddles both children: try the left part first.
    const int left_hit = query(index << 1, l, mid, val);
    if (left_hit != -1) return left_hit;
    return query(index << 1 | 1, mid + 1, r, val);
}

// Returns value[l] + ... + value[r].  [l, r] must lie inside the range
// covered by node `index`.
ll query_sum(int index, int l, int r) {
    if (Tree[index].l == l && Tree[index].r == r) return Tree[index].sum;
    const int mid = (Tree[index].l + Tree[index].r) >> 1;
    if (mid >= r) return query_sum(index << 1, l, r);
    if (mid + 1 <= l) return query_sum(index << 1 | 1, l, r);
    return query_sum(index << 1, l, mid) + query_sum(index << 1 | 1, mid + 1, r);
}

// Reads n, q and the initial array, then processes q assignments
// "value[x] = y", printing one answer (followed by a space) after each.
// Exits with status 1 on truncated or out-of-range input instead of hanging
// or writing outside the arrays.
int main() {
    int n = 0, q = 0;
    if (!read(n) || !read(q)) return 1;
    if (n < 1 || n >= MAXN) return 1;  // value[] and Tree[] are sized by MAXN
    for (int i = 1; i <= n; i++) {
        if (!read(value[i])) return 1;
    }
    build(1, 1, n);

    while (q-- > 0) {
        int x = 0, y = 0;
        if (!read(x) || !read(y)) return 1;
        if (x < 1 || x > n) return 1;  // would corrupt value[] / the tree
        value[x] = y;
        modify(1, x, y);

        // Index 1 has an empty prefix (sum 0), so it qualifies iff it is 0.
        if (value[1] == 0) {
            std::printf("1 ");
            continue;
        }

        int cur = 0;   // last position examined
        int ans = -1;
        ll pre = 0;    // value[1] + ... + value[cur]
        // cur strictly increases, so the loop ends after at most n rounds;
        // with non-negative values `pre` at least doubles per round.
        while (cur < n) {
            const int pos = query(1, cur + 1, n, pre);
            if (pos == -1) break;
            cur = pos;
            pre = query_sum(1, 1, pos);
            // pre includes value[cur]; compare it with the sum before cur.
            if (pre - value[cur] == value[cur]) {
                ans = pos;
                break;
            }
        }
        std::printf("%d ", ans);
    }
    return 0;
}