给 n 个操作：第一种是在点 (x, y, z) 处 +1；第二种询问以 (x1, y1, z1) 和 (x2, y2, z2) 为对角顶点的长方体内所有点的总值。
把操作时间也看作一维，问题共有四维（时间, x, y, z）。第一次 cdq 分治（按时间）去掉一维，第二次 cdq 分治（按 x）再去掉一维，剩下的两维按 y 排序后用树状数组维护 z。对于每个询问，用容斥原理把它拆成 8 个带正负号的前缀查询点。
注意第二次 cdq 分治时区间可能为空（l 可能大于 r，例如调用 cdqy(1, 0)），所以这里的返回条件是 l >= r 而不是 l == r。找了好久……
// Offline point insertions and axis-aligned box-count queries in 3-D.
//
// Every query box is split by inclusion-exclusion into 8 signed prefix
// ("dominance") queries.  Operation order is treated as an extra dimension:
// cdqx() runs CDQ divide and conquer over operation order, cdqy() runs a
// second CDQ over x, and the remaining (y, z) part is solved by sorting on y
// and keeping a Fenwick tree over compressed z.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

const int maxn = 5e4 + 5;  // max operations per test case (plus slack)

// Input operation types (read into node::sign of a[]).
const int kInsert = 1;  // "1 x y z": add one point at (x, y, z)
const int kQuery = 2;   // "2 x1 y1 z1 x2 y2 z2": count points in the box

// Event kinds stored in node::sign of c[] / s[] / ss[].
// Insertions keep kInsert; each query becomes 8 corners of these two kinds.
const int kQueryAdd = 2;  // corner whose prefix count is added to the answer
const int kQuerySub = 3;  // corner whose prefix count is subtracted

// z coordinates to compress.  Filled per test case in main(), sorted and
// deduplicated in solve(); it must be emptied before every test case.
std::vector<int> b;

struct node {
    int x1, y1, z1, x2, y2, z2;
    int sign, id;  // operation/event kind and 1-based index of the input op
    node() {}
    node(int x, int y, int z, int sign, int id)
        : x1(x), y1(y), z1(z), sign(sign), id(id) {}
} a[maxn],        // raw input operations, 1-based
    s[maxn * 8],  // events of one cdqx merge step, sorted by x (1-based)
    ss[maxn * 8], // scratch buffer for one cdqy merge step (0-based)
    c[maxn * 8];  // expanded events in operation order (1-based)

// Orders by x, then by operation index.  Every insertion that reaches s[]
// has a smaller id than every query corner there, so on equal x an insertion
// sorts before the query corner, which makes the x bound inclusive.
bool cmpx(const node& p, const node& q) {
    if (p.x1 != q.x1) return p.x1 < q.x1;
    return p.id < q.id;
}

// Same as cmpx but on y; keeps the y bound inclusive.
bool cmpy(const node& p, const node& q) {
    if (p.y1 != q.y1) return p.y1 < q.y1;
    return p.id < q.id;
}

int len;             // number of distinct compressed z values (Fenwick size)
int sum[maxn * 8];   // Fenwick tree over compressed z, 1-based
int ans[maxn];       // answer for the query whose input index is the subscript

int lowbit(int x) { return x & (-x); }

// Adds val at compressed position x (1 <= x <= len).
void update(int x, int val) {
    for (; x <= len; x += lowbit(x)) sum[x] += val;
}

// Returns the prefix sum over compressed positions [1, x].
int query(int x, int ret = 0) {
    for (; x > 0; x -= lowbit(x)) ret += sum[x];
    return ret;
}

// Second-level CDQ over s[l..r], which is sorted by x.  For every query
// corner in the right half it accumulates the insertions in the left half
// whose y and z do not exceed the corner's.  The range can be empty
// (l > r, e.g. cdqy(1, 0)), hence the `>=` test.
void cdqy(int l, int r) {
    if (l >= r) return;
    int m = (l + r) >> 1, cnt = 0;
    cdqy(l, m);
    cdqy(m + 1, r);
    // ss is shared with the recursive calls, so fill it only after they return.
    for (int i = l; i <= m; i++)
        if (s[i].sign == kInsert) ss[cnt++] = s[i];
    for (int i = m + 1; i <= r; i++)
        if (s[i].sign != kInsert) ss[cnt++] = s[i];
    std::sort(ss, ss + cnt, cmpy);
    for (int i = 0; i < cnt; i++) {
        if (ss[i].sign == kInsert)
            update(ss[i].z1, 1);
        else if (ss[i].sign == kQueryAdd)
            ans[ss[i].id] += query(ss[i].z1);
        else
            ans[ss[i].id] -= query(ss[i].z1);
    }
    // Undo this step's insertions so the Fenwick tree is all zero again.
    for (int i = 0; i < cnt; i++)
        if (ss[i].sign == kInsert) update(ss[i].z1, -1);
}

// First-level CDQ over c[l..r] (operation order).  Pairs insertions from the
// left half with query corners from the right half, sorts them by x into
// s[1..top] and hands them to cdqy().  Returns on empty ranges too, so an
// empty event list (n == 0) does not recurse forever.
void cdqx(int l, int r) {
    if (l >= r) return;
    int m = (l + r) >> 1, top = 0;
    cdqx(l, m);
    cdqx(m + 1, r);
    // s is shared with the recursive calls, so fill it only after they return.
    for (int i = l; i <= m; i++)
        if (c[i].sign == kInsert) s[++top] = c[i];
    for (int i = m + 1; i <= r; i++)
        if (c[i].sign != kInsert) s[++top] = c[i];
    std::sort(s + 1, s + 1 + top, cmpx);
    cdqy(1, top);
}

// Solves one test case: a[1..n] holds the operations and b their z values.
// Compresses z, expands each query into 8 signed prefix corners in c[], and
// accumulates the answers into ans[] (which the caller must have zeroed).
void solve(int n, int num = 0) {
    std::sort(b.begin(), b.end());
    b.erase(std::unique(b.begin(), b.end()), b.end());
    len = static_cast<int>(b.size());
    for (int i = 1; i <= n; i++) {
        const node& op = a[i];
        if (op.sign == kInsert) {
            c[++num] = op;
            continue;
        }
        // Inclusion-exclusion over [x1,x2] x [y1,y2] x [z1,z2]: a corner that
        // uses an odd number of lower bounds (coordinate - 1) is subtracted.
        c[++num] = node(op.x2, op.y2, op.z2, kQueryAdd, op.id);
        c[++num] = node(op.x1 - 1, op.y1 - 1, op.z1 - 1, kQuerySub, op.id);
        c[++num] = node(op.x2, op.y2, op.z1 - 1, kQuerySub, op.id);
        c[++num] = node(op.x2, op.y1 - 1, op.z2, kQuerySub, op.id);
        c[++num] = node(op.x1 - 1, op.y2, op.z2, kQuerySub, op.id);
        c[++num] = node(op.x1 - 1, op.y1 - 1, op.z2, kQueryAdd, op.id);
        c[++num] = node(op.x1 - 1, op.y2, op.z1 - 1, kQueryAdd, op.id);
        c[++num] = node(op.x2, op.y1 - 1, op.z1 - 1, kQueryAdd, op.id);
    }
    // Every z used above (z of inserts, z2 and z1-1 of queries) is in b.
    for (int i = 1; i <= num; i++)
        c[i].z1 = static_cast<int>(
            std::lower_bound(b.begin(), b.end(), c[i].z1) - b.begin()) + 1;
    cdqx(1, num);
}

int main() {
    int t;
    if (std::scanf("%d", &t) != 1) return 0;
    while (t--) {
        int n;
        if (std::scanf("%d", &n) != 1) return 0;
        // Per-test state.  b used to survive across test cases, growing the
        // Fenwick size `len` without bound (eventually past sum[]).
        b.clear();
        std::memset(sum, 0, sizeof(sum));
        std::memset(ans, 0, sizeof(ans));
        for (int i = 1; i <= n; i++) {
            std::scanf("%d%d%d%d", &a[i].sign, &a[i].x1, &a[i].y1, &a[i].z1);
            if (a[i].sign == kQuery) {
                std::scanf("%d%d%d", &a[i].x2, &a[i].y2, &a[i].z2);
                b.push_back(a[i].z2);
                b.push_back(a[i].z1 - 1);
            } else {
                b.push_back(a[i].z1);
            }
            a[i].id = i;
        }
        solve(n);
        // One answer per line, in input order.
        for (int i = 1; i <= n; i++)
            if (a[i].sign == kQuery) std::printf("%d\n", ans[i]);
    }
    return 0;
}