正题
题目链接:http://www.ybtoj.com.cn/contest/123/problem/1
题目大意
给出\(3\)个长度为\(n\)的排列\(A,B,C\)。然后一个下标集合\(S\)的三元组是
\[(max\{A_i\},max\{B_i\},max\{C_i\})(i\in S)
\]
求所有下标集合不同的三元组数量
\(1\leq n\leq 10^5\)
解题思路
所有下标集合的三元组都能用一个\(|S|\leq 3\)的集合代替,所以我们只考虑\(|S|\leq 3\)的就好了。
\(|S|=1\)的个数就是\(n\),直接累加即可。
\(|S|=2\)的话,那就代表某个下标霸占了两个最大值,而另一个一定是另一个下标的,如果是\(a,b\)最大,那么我们就要找满足\(a_i> a_j,b_i> a_j,c_i< c_j\)的方案,用三维偏序就好了。
然后\(a,c\)和\(b,c\)的情况也都要做
\(|S|=3\)的话很麻烦,考虑容斥,总方案\(\binom n 3\)减去有一个下标是至少两个的最大值。
同样和上面,先考虑\(a,b\),假设下标\(i\)满足\(a_i>a_j,b_i>b_j\)的情况有\(k\)种,那么就好有\(\binom{k}{2}\)种情况使得\(i\)占据了至少两个最大值。
同理\(a,c\)和\(b,c\)也要做,这是二维偏序,直接树状数组就好了。
但是发现对于\(i\)占据了三个最大值的情况我们统计了三次,需要加回多余的两次,那么统计\(a_i>a_j,b_i>b_j,c_i>c_j\)的个数\(k\),然后加回\(k(k-1)\)的方案就好了,这个也要三维偏序
代码里三维偏序用的是\(CDQ\)分治+树状数组
时间复杂度\(O(n\log^2 n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define lowbit(x) (x&-x)
using namespace std;
const ll N=1e5+10;
struct node{
ll a,b,c;
}w[N],a[N],b[N];
ll n,ans,sum,t[N],g[N];
void Change(ll x,ll val){
while(x<=n){
t[x]+=val;
x+=lowbit(x);
}
return;
}
ll Ask(ll x){
ll ans=0;
while(x){
ans+=t[x];
x-=lowbit(x);
}
return ans;
}
void Merge(ll l,ll mid,ll r){
ll p=l,q=mid+1;
for(ll i=1;i<=r-l+1;i++){
if(p<=mid&&w[p].b<=w[q].b||q>r)b[i]=w[p],p++;
else b[i]=w[q],q++;
}
for(ll i=1;i<=r-l+1;i++)w[l+i-1]=b[i];
return;
}
void CDQ(ll l,ll r,bool op){
if(l==r)return;
ll mid=(l+r)>>1;
CDQ(l,mid,op);CDQ(mid+1,r,op);
ll p=l,tmp;
for(ll i=mid+1;i<=r;i++){
while(p<=mid&&w[p].b<w[i].b)
Change(w[p].c,1),p++;
sum+=(tmp=Ask(w[i].c));
g[w[i].a]+=(op?tmp:0);
}
for(ll i=l;i<p;i++)Change(w[i].c,-1);
Merge(l,mid,r);return;
}
bool cmp(node x,node y)
{return x.a<y.a;}
void solve(){
sort(w+1,w+1+n,cmp);
for(ll i=1;i<=n;i++){
ll tmp=Ask(w[i].b);
ans-=tmp*(tmp-1)/2;
Change(w[i].b,1);
}
memset(t,0,sizeof(t));
return;
}
signed main()
{
freopen("subset.in","r",stdin);
freopen("subset.out","w",stdout);
scanf("%lld",&n);ans=n;
for(ll i=1;i<=n;i++)scanf("%lld",&a[i].a);
for(ll i=1;i<=n;i++)scanf("%lld",&a[i].b);
for(ll i=1;i<=n;i++)scanf("%lld",&a[i].c);
for(ll i=1;i<=n;i++)
w[i].a=a[i].a,w[i].b=a[i].b,w[i].c=n-a[i].c+1;
sort(w+1,w+1+n,cmp);CDQ(1,n,0);
for(ll i=1;i<=n;i++)
w[i].a=a[i].a,w[i].b=a[i].c,w[i].c=n-a[i].b+1;
sort(w+1,w+1+n,cmp);CDQ(1,n,0);
for(ll i=1;i<=n;i++)
w[i].a=a[i].b,w[i].b=a[i].c,w[i].c=n-a[i].a+1;
sort(w+1,w+1+n,cmp);CDQ(1,n,0);
ans+=sum;ans+=n*(n-1)*(n-2)/6;
for(ll i=1;i<=n;i++)w[i].a=a[i].a,w[i].b=a[i].b;solve();
for(ll i=1;i<=n;i++)w[i].a=a[i].b,w[i].b=a[i].c;solve();
for(ll i=1;i<=n;i++)w[i].a=a[i].a,w[i].b=a[i].c;solve();
for(ll i=1;i<=n;i++)
w[i].a=a[i].a,w[i].b=a[i].b,w[i].c=a[i].c;
sort(w+1,w+1+n,cmp);
CDQ(1,n,1);
for(ll i=1;i<=n;i++)ans+=g[i]*(g[i]-1);
printf("%lld\n",ans);
}