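// Approach: enumerate i, the number of chosen items that both people like. That fixes how many items each person must additionally take from the items only they like (max(0, k - i) each). A balanced tree (treap) maintains the sum of the smallest remaining costs, which fill the selection up to m items.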
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2e5 + 7;
const LL INF = 0x3f3f3f3f3f3f3f3f;

// n items with costs d[1..n]; m = number of items to pick; k = per-person quota.
// state[i]: bit 0 set if the first person likes item i, bit 1 if the second does.
int n, m, k, state[N], d[N];
// a: items only the first person likes, b: only the second, c: both.
// Element 0 of each array holds its current length; entries start at index 1.
int a[N], b[N], c[N];
// Prefix sums of the sorted a, b, c arrays (suma[j] = a[1] + ... + a[j]).
LL suma[N], sumb[N], sumc[N];

// Treap node. Equal keys share one node whose multiplicity is `cnt`.
struct node {
    node *ch[2];  // ch[0]: subtree of smaller keys, ch[1]: subtree of larger keys
    int key;      // stored value
    int fix;      // random priority; treap is a max-heap on fix
    int sz;       // number of values in this subtree, counting multiplicity
    int cnt;      // multiplicity of `key`
    LL sum;       // sum of all values in this subtree, counting multiplicity

    // Recompute sz and sum from the children; call after any structural change.
    void update() {
        sz = ch[0]->sz + ch[1]->sz + cnt;
        sum = ch[0]->sum + ch[1]->sum + 1ll * cnt * key;
    }
};
typedef node *P_node;

// Multiset of ints backed by a treap, supporting "sum of the smallest r values".
// Nodes are taken from the fixed pool `base` and never recycled, so the number of
// node creations (inserts of a key not currently present) must stay below N.
struct Treap {
    node base[N], nil;
    P_node root, null, len;

    Treap() {
        root = null = &nil;
        null->key = null->fix = 1e9;
        null->sz = null->cnt = 0;
        null->sum = 0;  // explicit, instead of relying on static zero-initialisation
        null->ch[0] = null->ch[1] = null;
        len = base;
    }

    // Take a fresh node from the pool holding a single copy of tkey.
    P_node newnode(int tkey) {
        len->key = tkey;
        len->fix = rand();
        len->ch[0] = len->ch[1] = null;
        len->sz = len->cnt = 1;
        len->sum = tkey;
        return len++;
    }

    // Make p->ch[d ^ 1] the new root of this subtree (d = 0: rotate left, d = 1: rotate right).
    void rot(P_node &p, int d) {
        P_node t = p->ch[d ^ 1];
        p->ch[d ^ 1] = t->ch[d];
        t->ch[d] = p;
        p->update();
        t->update();
        p = t;
    }

    void _Insert(P_node &p, int tkey) {
        if (p == null) {
            p = newnode(tkey);
        } else if (p->key == tkey) {
            p->cnt++;
        } else {
            int d = tkey > p->key;
            _Insert(p->ch[d], tkey);
            // Restore the max-heap property on priorities.
            if (p->ch[d]->fix > p->fix) rot(p, d ^ 1);
        }
        p->update();
    }

    // Remove one copy of tkey; does nothing if tkey is absent.
    void _Delete(P_node &p, int tkey) {
        if (p == null) return;
        if (p->key == tkey) {
            if (p->cnt > 1) {
                p->cnt--;
            } else if (p->ch[0] == null) {
                p = p->ch[1];
            } else if (p->ch[1] == null) {
                p = p->ch[0];
            } else {
                // Rotate the higher-priority child up, then keep pushing the key down.
                int d = p->ch[0]->fix > p->ch[1]->fix;
                rot(p, d);
                _Delete(p->ch[d], tkey);
            }
        } else {
            _Delete(p->ch[tkey > p->key], tkey);
        }
        p->update();
    }

    // k-th smallest value (1-based, counting multiplicity); 0 if k is out of range.
    int _Kth(P_node p, int k) {
        if (p == null || k < 1 || k > p->sz) return 0;
        if (k < p->ch[0]->sz + 1) return _Kth(p->ch[0], k);
        if (k > p->ch[0]->sz + p->cnt) return _Kth(p->ch[1], k - p->ch[0]->sz - p->cnt);
        return p->key;
    }

    // 1-based rank of the first copy of tkey; -1 if tkey is absent.
    int _Rank(P_node p, int tkey, int res) {
        if (p == null) return -1;
        if (p->key == tkey) return p->ch[0]->sz + res + 1;
        if (tkey < p->key) return _Rank(p->ch[0], tkey, res);
        return _Rank(p->ch[1], tkey, res + p->ch[0]->sz + p->cnt);
    }

    // Largest stored value strictly less than tkey; -1e9 if none.
    int _Pred(P_node p, int tkey) {
        if (p == null) return -1e9;
        if (tkey <= p->key) return _Pred(p->ch[0], tkey);
        return max(p->key, _Pred(p->ch[1], tkey));
    }

    // Smallest stored value strictly greater than tkey; 1e9 if none.
    int _Succ(P_node p, int tkey) {
        if (p == null) return 1e9;
        if (tkey >= p->key) return _Succ(p->ch[1], tkey);
        return min(p->key, _Succ(p->ch[0], tkey));
    }

    // Sum of the `res` smallest values in subtree p (counting multiplicity).
    // Stops at the null sentinel, so if res exceeds the subtree size the whole
    // subtree sum is returned instead of recursing forever.
    LL _Query(P_node p, int res) {
        if (res <= 0 || p == null) return 0;
        if (p->ch[0]->sz >= res) return _Query(p->ch[0], res);
        int upto = p->ch[0]->sz + p->cnt;  // values covered by left subtree + this node
        if (upto < res) {
            return p->ch[0]->sum + 1ll * p->key * p->cnt + _Query(p->ch[1], res - upto);
        }
        return p->ch[0]->sum + 1ll * p->key * (res - p->ch[0]->sz);
    }

    void Insert(int tkey) { _Insert(root, tkey); }
    void Delete(int tkey) { _Delete(root, tkey); }
    int Kth(int k) { return _Kth(root, k); }
    int Rank(int tkey) { return _Rank(root, tkey, 0); }
    int Pred(int tkey) { return _Pred(root, tkey); }
    int Succ(int tkey) { return _Succ(root, tkey); }
    LL Query(int res) { return _Query(root, res); }
    // Number of values currently stored (counting multiplicity).
    int Size() { return root->sz; }
} tp;  // the "free pool": items that may fill the remaining slots

int main() {
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= n; i++) scanf("%d", &d[i]);

    // Two lists of liked item indices: first person sets bit 0, second sets bit 1.
    int num;
    scanf("%d", &num);
    for (int i = 1; i <= num; i++) {
        int x;
        scanf("%d", &x);
        state[x] |= 1;
    }
    scanf("%d", &num);
    for (int i = 1; i <= num; i++) {
        int x;
        scanf("%d", &x);
        state[x] |= 2;
    }

    // Items nobody likes go straight into the pool; the rest are split by who likes them.
    for (int i = 1; i <= n; i++) {
        if (state[i] == 0) tp.Insert(d[i]);
        else if (state[i] == 1) a[++a[0]] = d[i];
        else if (state[i] == 2) b[++b[0]] = d[i];
        else c[++c[0]] = d[i];
    }

    sort(a + 1, a + 1 + a[0]);
    sort(b + 1, b + 1 + b[0]);
    sort(c + 1, c + 1 + c[0]);
    for (int i = 1; i <= a[0]; i++) suma[i] = suma[i - 1] + a[i];
    for (int i = 1; i <= b[0]; i++) sumb[i] = sumb[i - 1] + b[i];
    for (int i = 1; i <= c[0]; i++) sumc[i] = sumc[i - 1] + c[i];

    LL ans = INF;
    // Initially every "both" item is in the pool; the i cheapest are pulled out as
    // i grows.
    for (int i = 1; i <= c[0]; i++) tp.Insert(c[i]);
    for (int i = 0; i <= c[0]; i++) {
        if (i) tp.Delete(c[i]);  // c[1..i] are now forced picks
        // Each person still needs this many items from their "only" list.
        int res1 = max(0, k - i);
        if (a[0] < res1 || b[0] < res1) continue;
        // res1 never increases with i, so surplus expensive a/b items move into the
        // pool exactly once; a[0]/b[0] end up equal to res1 (the cheapest ones kept).
        while (a[0] > res1) {
            tp.Insert(a[a[0]]);
            a[0]--;
        }
        while (b[0] > res1) {
            tp.Insert(b[b[0]]);
            b[0]--;
        }
        if (i + 2 * res1 > m) continue;
        int res2 = m - i - 2 * res1;  // slots left for the cheapest pool items
        if (res2 > tp.Size()) continue;  // pool cannot fill the remaining slots
        ans = min(ans, sumc[i] + suma[a[0]] + sumb[b[0]] + tp.Query(res2));
    }
    printf("%lld ", ans == INF ? -1 : ans);
    return 0;
}