题目
题目链接:https://www.luogu.com.cn/problem/P6477
题目描述
给定一个长度为 \(n\) 的正整数序列 \(A_1\), \(A_2\), \(\cdots\), \(A_n\)。定义一个函数 \(f(l,r)\) 表示:序列中下标在 \([l,r]\) 范围内的子区间中,不同的整数个数。换句话说,\(f(l,r)\) 就是集合 \(\{A_l,A_{l+1},\cdots,A_r\}\) 的大小,这里的集合是不可重集,即集合中的元素互不相等。
现在,请你求出 \(\sum_{l=1}^n\sum_{r=l}^n (f(l,r))^2\)。由于答案可能很大,请输出答案对 \(10^9 +7\) 取模的结果。
思路
考试时居然先写了这题再写 T1 的。。。主要是 T1 一眼没看出结论。
我们考虑枚举右端点 \(r\),对于第 \(l\) 个数,我们记录 \(f[l]\) 表示 \([l,r]\) 中不同的数字个数。
假设我们已经通过某种奇妙的方法求出了 \(r=i\) 时的 \(f\)。接下来我们求 \(r=i+1\) 时的 \(f\)。
- 如果 \(a[i+1]\) 在之前没出现过,那么显然每一个区间都多了一个不同的数。\(f[1\sim i+1]\) 全部加一。
- 如果 \(a[i+1]\) 在第 \(j\) 位出现过 \((j<i)\),那么 \([1,j]\) 出现的数不变,\([j+1,i+1]\) 出现的数加一。也就是 \(f[j+1\sim i+1]\) 全部加一。
维护平方?套路性拆开,依旧维护区间平方和、区间和即可。
线段树就可以轻松解决这些问题。
时间复杂度 \(O(n\log n)\)。
代码
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=1000010,MOD=1e9+7;
int n,a[N],b[N],last[N];
ll ans;
inline int read()
{
int d=0; char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
return d;
}
struct SegTree
{
int l[N*4],r[N*4],len[N*4],sum[N*4],lazy[N*4];
ll mul[N*4];
void build(int x,int ql,int qr)
{
l[x]=ql; r[x]=qr; len[x]=qr-ql+1;
if (ql==qr) return;
register int mid=(ql+qr)>>1;
build(x*2,ql,mid); build(x*2+1,mid+1,qr);
}
void update(int x,int ql,int qr)
{
if (l[x]==ql && r[x]==qr)
{
mul[x]=(mul[x]+2LL*sum[x]+len[x])%MOD;
sum[x]=(sum[x]+len[x])%MOD;
lazy[x]++;
return;
}
if (lazy[x])
{
ll p=lazy[x]; register int lc=x*2,rc=x*2+1;
lazy[lc]+=p; lazy[rc]+=p;
mul[lc]=(mul[lc]+2LL*p*sum[lc]+len[lc]*p*p)%MOD;
sum[lc]=(sum[lc]+len[lc]*p)%MOD;
mul[rc]=(mul[rc]+2LL*p*sum[rc]+len[rc]*p*p)%MOD;
sum[rc]=(sum[rc]+len[rc]*p)%MOD;
lazy[x]=0;
}
register int mid=(l[x]+r[x])>>1;
if (qr<=mid) update(x*2,ql,qr);
else if (ql>mid) update(x*2+1,ql,qr);
else update(x*2,ql,mid),update(x*2+1,mid+1,qr);
sum[x]=(sum[x*2]+sum[x*2+1])%MOD;
mul[x]=(mul[x*2]+mul[x*2+1])%MOD;
}
}seg;
int main()
{
n=read();
for (register int i=1;i<=n;i++)
a[i]=b[i]=read();
sort(b+1,b+1+n);
register int tot=unique(b+1,b+1+n)-b-1;
for (register int i=1;i<=n;i++)
a[i]=lower_bound(b+1,b+1+n,a[i])-b;
seg.build(1,1,n);
for (register int i=1;i<=n;i++)
{
if (last[a[i]])
seg.update(1,last[a[i]]+1,i);
else
seg.update(1,1,i);
last[a[i]]=i;
ans=(ans+seg.mul[1])%MOD;
}
printf("%lld",ans);
return 0;
}