问题描述
给你一个长度为 (N) 的序列 (a_i),(1leq ileq N),和 (q) 组询问,每组询问读入 (l_1,r_1,l_2,r_2),需输出
(sumlimits_{x=0}^infty ext{get}(l_1,r_1,x) imes ext{get}(l_2,r_2,x))。
( ext{get}(l,r,x)) 表示计算区间 ([l,r]) 中,数字 (x) 出现了多少次。
输入格式
第一行,一个数字 (N),表示序列长度。
第二行,(N) 个数字,表示 (a_1sim a_N)。
第三行,一个数字 (Q),表示询问个数。
第 (4sim Q+3) 行,每行四个数字 (l_1,r_1,l_2,r_2),表示询问。
输出格式
对于每组询问,输出一行一个数字,表示答案。
样例输入
5
1 1 1 1 1
2
1 2 3 4
1 1 4 4
样例输出
4
1
说明
对于 (20\%) 的数据,(1leq N,Qleq 1000);
对于另外 (30\%) 的数据,(1leq a_ileq 50);
对于 (100\%) 的数据,(N,Qleq 50000),(1leq a_ileq N),(1leq l_1leq r_1leq N),(1leq l_2leq r_2leq N)。
数据范围与原题相同,但测试数据由 LibreOJ 自制,并非原数据。
注意:答案有可能超过 int
的最大值。
解析
区间问题往往可以想到前缀和。如果想用前缀和的形式表示这个式子,那么可以按照如下过程化简:
这样,我们就把一个询问拆成了4个可以用莫队维护的询问。对于每一个询问,维护 (num[0][x]) 表示在区间 ([1,l]) 中 (x) 出现了多少次,(num[1][x]) 表示区间 ([1,r]) 中 (x) 出现了多少次。那么答案就是 (sumlimits_{x=0}^infty num[0][x]*num[1][x]) 。至于修改操作,可以这么看:假设一个数原来是 (a*b) ,现在要把它变成 ((a+1)*b) ,其实就相当于在原来的基础上加上一个 (b) 。这里也是同理。
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#define int long long
#define N 50002
using namespace std;
struct query{
int l,r,id;
}q[4*N];
int n,m,i,a[N],b[N],gap,cnt,num[2][N],sum,ans[4*N],l,r;
int read()
{
char c=getchar();
int w=0;
while(c<'0'||c>'9') c=getchar();
while(c<='9'&&c>='0'){
w=w*10+c-'0';
c=getchar();
}
return w;
}
int my_comp(const query &x,const query &y)
{
if(b[x.l]==b[y.l]) return x.r<y.r;
return x.l<y.l;
}
void add(int op,int x)
{
num[op][a[x]]++;
sum+=num[op^1][a[x]];
}
void del(int op,int x)
{
num[op][a[x]]--;
sum-=num[op^1][a[x]];
}
signed main()
{
n=read();
gap=sqrt(1.0*n);
for(i=1;i<=n;i++) a[i]=read();
for(i=1;i<=n;i++) b[i]=(i-1)/gap+1;
m=read();
for(i=1;i<=m;i++){
int l1=read(),r1=read(),l2=read(),r2=read();
q[++cnt]=(query){min(r1,r2),max(r1,r2),cnt};
q[++cnt]=(query){min(r1,l2-1),max(r1,l2-1),cnt};
q[++cnt]=(query){min(l1-1,r2),max(l1-1,r2),cnt};
q[++cnt]=(query){min(l1-1,l2-1),max(l1-1,l2-1),cnt};
}
sort(q+1,q+cnt+1,my_comp);
for(i=1;i<=cnt;i++){
while(l<q[i].l) add(0,++l);
while(l>q[i].l) del(0,l--);
while(r<q[i].r) add(1,++r);
while(r>q[i].r) del(1,r--);
ans[q[i].id]=sum;
}
for(i=1;i<=cnt;i+=4) printf("%lld
",ans[i]-ans[i+1]-ans[i+2]+ans[i+3]);
return 0;
}