思路:先对x离散化,用树状数组维护x这一维;对y排序后做分治(CDQ分治);z这一维在分治过程中通过排序来处理。
具体做法:先按y排序,solve(l,r)依次处理solve(l,mid)、[l,mid]对[mid+1,r]的贡献、solve(mid+1,r)。因为已经按y排序,所以区间[l,mid]内的y值都不大于区间[mid+1,r]内的y值。计算贡献时,把两个区间分别按z排序,再按z从小到大把左区间的点以x为下标插入树状数组,这样就可以一次性算出[l,mid]对[mid+1,r]的影响。
// Longest chain of 3-D points that is non-decreasing in x, y and z, together
// with the number of such longest chains modulo 2^30.
//
// Approach (CDQ divide and conquer):
//   * x is compressed to ranks 1..n and used as the Fenwick-tree index;
//   * points are ordered by y, and solve(l, r) splits that order in halves;
//   * inside solve both halves are sorted by z, and left-half points are fed
//     into the Fenwick tree in increasing z so that every right-half point can
//     query the best (length, count) among left points with x <= its x and
//     z <= its z.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

const int maxn = 1e5 + 9;  // capacity of the static arrays (1-based indexing)
const int mod = 1 << 30;   // chain counts are reported modulo 2^30

int n;  // number of points in the current test case

// One input point; x is replaced by its rank after compression, id is the
// 1-based input position and indexes ans[].
struct P {
    int x, y, z, id;
} point[maxn], now[maxn];

// (best chain length, number of chains of that length mod `mod`).
// ans[id] is the result for chains ending at point id; tr[] is the Fenwick tree.
struct A {
    int max, sum;
} ans[maxn], tr[maxn];

// Orders points by x only (used for coordinate compression).
bool cmpx(const P& a, const P& b) {
    return a.x < b.x;
}

// Orders points by y, breaking ties by x and then z.
// The tie-break matters: solve() only propagates results from earlier to
// later positions, so a point that may precede another in a chain must also
// come first in this order, even when their y values are equal.
bool cmpy(const P& a, const P& b) {
    if (a.y != b.y) return a.y < b.y;
    if (a.x != b.x) return a.x < b.x;
    return a.z < b.z;
}

// Orders points by z only (used inside solve()).
bool cmpz(const P& a, const P& b) {
    return a.z < b.z;
}

// Lowest set bit of x.
int lowbit(int x) {
    return (x & -x);
}

// Folds `from` into `into`: a strictly longer chain replaces the stored value,
// an equally long one adds its count modulo `mod`, a shorter one is ignored.
// Both counts are below 2^30, so their sum still fits in an int.
static void absorb(A& into, const A& from) {
    if (from.max > into.max) {
        into = from;
    } else if (from.max == into.max) {
        into.sum = (into.sum + from.sum) % mod;
    }
}

// Adds the value `tmp` at Fenwick position x (1 <= x <= n).
void insert(int x, A tmp) {
    for (int i = x; i <= n; i += lowbit(i)) {
        absorb(tr[i], tmp);
    }
}

// Returns the combined (max, count) over Fenwick positions 1..x.
// With an empty tree the result is {0, 0}.
A getsum(int x) {
    A ret;
    ret.max = -1;
    ret.sum = 0;
    for (int i = x; i >= 1; i -= lowbit(i)) {
        absorb(ret, tr[i]);
    }
    return ret;
}

// Resets every Fenwick node that an insert at position x could have touched.
void clear(int x) {
    for (int i = x; i <= n; i += lowbit(i)) {
        tr[i].max = 0;
        tr[i].sum = 0;
    }
}

// Finalises ans[] for the points stored in point[l..r], which must be in
// cmpy order on entry. An empty or single-element range is a no-op.
void solve(int l, int r) {
    if (l >= r) return;
    const int mid = (l + r) >> 1;

    // Left half first: its answers must be final before they are propagated.
    solve(l, mid);

    // Remember the y-order of the right half; it is needed again for the
    // recursive call below after the z-sort has scrambled it.
    std::copy(point + mid + 1, point + r + 1, now + mid + 1);

    std::sort(point + l, point + mid + 1, cmpz);
    std::sort(point + mid + 1, point + r + 1, cmpz);

    int top = l;  // next left-half point not yet inserted into the tree
    for (int i = mid + 1; i <= r; i++) {
        while (top <= mid && point[top].z <= point[i].z) {
            insert(point[top].x, ans[point[top].id]);
            top++;
        }
        // Best chain ending at a left point with x <= and z <= this point,
        // extended by this point itself.
        A best = getsum(point[i].x);
        best.max++;
        absorb(ans[point[i].id], best);
    }

    // Leave the tree empty for the next call.
    for (int i = l; i <= mid; i++) clear(point[i].x);

    std::copy(now + mid + 1, now + r + 1, point + mid + 1);
    solve(mid + 1, r);
}

// Reads T test cases (n followed by n triples "x y z") from stdin and prints
// "<longest length> <count mod 2^30>" on its own line for each case.
// Returns 1 if the input is truncated or n does not fit the static arrays.
int main() {
    // freopen("in.txt", "r", stdin);  // uncomment to read a local test file
    int T;
    if (scanf("%d", &T) != 1) return 1;
    while (T--) {
        if (scanf("%d", &n) != 1 || n < 0 || n >= maxn) return 1;
        for (int i = 1; i <= n; i++) {
            if (scanf("%d %d %d", &point[i].x, &point[i].y, &point[i].z) != 3) return 1;
            point[i].id = i;
        }

        // Coordinate-compress x to ranks 1..(number of distinct x values).
        std::sort(point + 1, point + 1 + n, cmpx);
        int rank = 0;
        int prev = 0;
        for (int i = 1; i <= n; i++) {
            if (i == 1 || point[i].x != prev) {
                rank++;
                prev = point[i].x;
            }
            point[i].x = rank;
        }

        std::sort(point + 1, point + 1 + n, cmpy);

        // Every point alone is a chain of length 1, counted once.
        for (int i = 1; i <= n; i++) {
            ans[i].max = 1;
            ans[i].sum = 1;
        }
        solve(1, n);

        // Combine the per-point results; stays {0, 0} when there are no points.
        A ret;
        ret.max = 0;
        ret.sum = 0;
        for (int i = 1; i <= n; i++) {
            absorb(ret, ans[i]);
        }
        printf("%d %d\n", ret.max, ret.sum);
    }
    return 0;
}