题目
题目链接:https://www.luogu.com.cn/problem/U138580
帮助统治者解决问题之后,统治者准备奖励你两把剑,让你去打怪。
具体的来说,两把剑分别代表了两个长度为 (n) 的序列 (a,b)。
你什么方面都强,所以你可以分别重新锻造这两把剑,锻造就相当于重新排列这两个序列。
合并这两把剑,让它变成一把新剑(对应序列 (c)),合并相当于把对应位置上的数加起来 (c[i]=a[i]+b[i])。
最后你准备拿着这把新剑去找大 Boss,造成的伤害是众数出现的次数。
问怎么排列才能使得伤害最大化,输出最大伤害。
思路
设 (a[i],b[i]) 分别表示序列一 / 二中,数字 (i) 出现的次数。(ans[i]) 表示中位数为 (i) 的答案。
那么显然有
[ans[i]=sum^{i}_{j=1}min(a[j],b[i-j])
]
我们发现这个式子很像卷积,但是并不可以直接把它们卷起来。
发现 (min(a,b)=sum_{k}[ageq k][bgeq k]),所以我们可以得出
[ans[i]=sum^{i}_{j=1}sum_{k} [a[j]geq k][b[i-j]geq k]
]
直接实现是 (O(nmlog n)) 的,其中 (m) 是值域,这样还没有暴力优秀。
但是可以发现值超过 (T) 的 (a[i],b[i]) 数量不超过 (lfloorfrac{n}{T}
floor) 个。所以我们可以考虑分段乱搞。
- 当 (ileq T) 的时候,我们就跑 FFT,时间复杂度 (O(Tnlog n))。
- 当 (i>T) 的时候,我们直接跑暴力,时间复杂度 (O(frac{n^2}{T^2}))。
所以总时间复杂度 (O(Tnlog n+frac{n^2}{T^2}))。考虑到 FFT 的常数巨大,这里取 (T=5)。
代码
#include <bits/stdc++.h>
#define cp complex<double>
using namespace std;
typedef long long ll;
const int N=400010,T=5;
const double pi=acos(-1);
int n,m1,m2,lim,a[N],b[N],c[N],d[N],rev[N];
ll maxn,ans[N];
cp f[N],g[N];
int read()
{
int d=0; char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
return d;
}
void FFT(cp *f,int tag)
{
for (int i=0;i<lim;i++)
if (i<rev[i]) swap(f[i],f[rev[i]]);
for (int mid=1;mid<lim;mid<<=1)
{
cp temp(cos(pi/mid),tag*sin(pi/mid));
for (int i=0;i<lim;i+=(mid<<1))
{
cp w(1,0);
for (int j=0;j<mid;j++,w*=temp)
{
cp x=f[i+j],y=w*f[i+j+mid];
f[i+j]=x+y; f[i+j+mid]=x-y;
}
}
}
}
int main()
{
n=read();
for (int i=1;i<=n;i++) a[read()]++;
for (int i=1;i<=n;i++) b[read()]++;
lim=1;
while (lim<=2e5) lim<<=1;
for (int i=0;i<lim;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)?(lim>>1):0);
for (int j=1;j<=T;j++)
{
for (int i=0;i<lim;i++)
f[i]=cp(a[i]>=j,0),g[i]=cp(b[i]>=j,0);
FFT(f,1); FFT(g,1);
for (int i=0;i<lim;i++) f[i]*=g[i];
FFT(f,-1);
for (int i=0;i<lim;i++)
ans[i]+=(ll)(f[i].real()/lim+0.499999);
}
for (int i=0;i<lim;i++)
{
if (a[i]>T) c[++m1]=i;
if (b[i]>T) d[++m2]=i;
}
for (int i=1;i<=m1;i++)
for (int j=1;j<=m2;j++)
ans[c[i]+d[j]]+=min(a[c[i]],b[d[j]])-T;
for (int i=0;i<=lim;i++) maxn=max(maxn,ans[i]);
printf("%lld",maxn);
return 0;
}