#define LL long long
// Midpoint of [a, b]. Each argument is parenthesized so expressions such as
// mid(x ? p : q, r) expand with the intended precedence.
#define mid(a,b) (((a)+(b))>>1)
#define eps 1e-8
#define maxn 55000
#define mod 1000000007
#define inf 0x3f3f3f3f
// Debug helper: redirect stdin from in.txt (written as `IN;`, currently commented out in main).
#define IN freopen("in.txt","r",stdin);
using namespace std;
int n;                  // length of the current test case

LL num[maxn];           // input values; main() overwrites them with their 1-based ranks
LL tmpa[maxn];          // sorted copy of the input, used to build the rank map

LL c[maxn];             // Fenwick tree storage, indexed 1..n by update()/get_sum()

// Per-position counts of strictly larger / smaller elements on each side.
LL left_large[maxn];
LL left_small[maxn];
LL right_small[maxn];
LL right_large[maxn];

// Per-rank totals: occurrences, and counts of elements with strictly smaller / larger rank.
LL tol_cnt[maxn];
LL tol_small[maxn];
LL tol_large[maxn];
// Isolates the lowest set bit of x (e.g. 12 -> 4): the Fenwick-tree step size.
LL lowbit(LL x)
{
    const LL negated = -x;
    return negated & x;
}
// Adds d at index x of the Fenwick tree c, visiting every node whose range
// covers x; stops once the index exceeds n.
void update(LL x, LL d)
{
    for (LL pos = x; pos <= n; pos += lowbit(pos))
        c[pos] += d;
}
// Returns the Fenwick prefix sum over indices 1..x (0 when x <= 0).
LL get_sum(LL x)
{
    LL total = 0;
    for (LL pos = x; pos > 0; pos -= lowbit(pos))
        total += c[pos];
    return total;
}
/*
 * Reads test cases until input ends: n, then n integers A[1..n]. For each case
 * prints one number, computed from these per-position counts:
 *   ls[i] / ll[i] = elements left of i that are strictly smaller / larger,
 *   rs[i] / rl[i] = elements right of i that are strictly smaller / larger,
 *   desc = sum_i rs[i]  (strictly decreasing pairs),
 *   asc  = sum_i ls[i]  (strictly increasing pairs).
 * The final loop sums ls[i] * (desc - ll[i] - rs[i]) - rl[i] * (ll[i] + rs[i]),
 * which equals asc * desc - sum_i (ls[i] + rl[i]) * (ll[i] + rs[i]): the number
 * of (increasing pair, decreasing pair) combinations that share no index
 * (two such pairs can share at most one index).
 *
 * NOTE(review): array bounds assume n + 1 < maxn; confirm against the input limits.
 */
int main(int argc, char const *argv[])
{
    //IN;
    // `== 1` rather than `!= EOF`: a malformed token ends the loop instead of
    // re-processing the previous case forever.
    while(scanf("%d", &n) == 1)
    {
        // num/tmpa are long long, so %lld is required (%d into a long long* is UB
        // and mis-reads negative values).
        for(int i=1; i<=n; i++)
            scanf("%lld", &num[i]), tmpa[i] = num[i];

        // Coordinate compression: each distinct value -> its 1-based rank, so the
        // Fenwick tree can be indexed by rank.
        sort(tmpa+1, tmpa+1+n);
        map<LL, LL> mymap;
        LL sz = 0;
        for(int i=1; i<=n; i++) {
            if(!mymap.count(tmpa[i])) {
                sz++;
                mymap.insert(make_pair(tmpa[i], sz));
            }
        }
        for(int i=1; i<=n; i++) {
            num[i] = mymap[num[i]];
        }

        memset(left_large, 0, sizeof(left_large));
        memset(right_large, 0, sizeof(right_large));
        memset(right_small, 0, sizeof(right_small));
        memset(left_small, 0, sizeof(left_small));
        memset(tol_small, 0, sizeof(tol_small));
        memset(tol_large, 0, sizeof(tol_large));
        memset(tol_cnt, 0, sizeof(tol_cnt));
        memset(c, 0, sizeof(c));

        // tol_cnt[r]: how many elements have rank r.
        for(int i=1; i<=n; i++) {
            tol_cnt[num[i]]++;
        }
        // tol_small[r] / tol_large[r]: elements with rank strictly below / above r.
        for(int i=1; i<=sz; i++) {
            tol_small[i] = tol_small[i-1] + tol_cnt[i-1];
        }
        for(int i=sz; i>=1; i--) {
            tol_large[i] = tol_large[i+1] + tol_cnt[i+1];
        }

        // Left-to-right sweep: before update(), the tree holds ranks of A[1..i-1].
        LL all_less = 0;  // desc: total number of strictly decreasing pairs
        for(int i=1; i<=n; i++) {
            left_large[i] = get_sum(sz) - get_sum(num[i]);
            right_large[i] = tol_large[num[i]] - left_large[i];
            left_small[i] = get_sum(num[i]-1);
            right_small[i] = tol_small[num[i]] - left_small[i];
            all_less += right_small[i];
            update(num[i], 1);
        }

        // See the header comment for why this sum is the answer.
        LL ans = 0;
        for(int i=1; i<=n; i++) {
            ans += left_small[i] * (all_less - left_large[i] - right_small[i]);
            ans -= right_large[i] * (left_large[i] + right_small[i]);
        }
        // %lld is the standard long long specifier (%I64d is an MSVCRT extension).
        printf("%lld\n", ans);
    }
    return 0;
}
```

#### TLE代码：三次线段树操作（TLE version: three segment-tree passes）
``` cpp
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <map>
#include <set>
#include <vector>
#define LL long long
// Midpoint of [a, b]. Each argument is parenthesized so expressions such as
// mid(x ? p : q, r) expand with the intended precedence.
#define mid(a,b) (((a)+(b))>>1)
#define eps 1e-8
#define maxn 55000
#define mod 1000000007
#define inf 0x3f3f3f3f
// Debug helper: redirect stdin from in.txt (written as `IN;`, currently commented out in main).
#define IN freopen("in.txt","r",stdin);
using namespace std;
int n;              // length of the current test case
LL num[maxn];       // input values; main() overwrites them with their 1-based ranks
LL tmpa[maxn];      // sorted copy of the input, used to build the rank map
LL pre[maxn];       // pre[i]: earlier positions holding a strictly larger rank
LL last[maxn];      // last[i]: later positions holding a strictly smaller rank
// Segment-tree node covering ranks [left, right]. `sum` counts inserted
// elements in that range; `cur` accumulates the weights passed to update().
struct Tree
{
    int left;
    int right;
    LL cur;
    LL sum;
} tree[maxn << 2];  // usual 4x node bound for a recursive segment tree
void build(int i,int left,int right)
{
tree[i].left=left;
tree[i].right=right;
if(left==right){
tree[i].sum = 0;
tree[i].cur = 0;
return ;
}
int mid=mid(left,right);
build(i<<1,left,mid);
build(i<<1|1,mid+1,right);
tree[i].sum=tree[i<<1].sum+tree[i<<1|1].sum;
tree[i].cur=tree[i<<1].cur+tree[i<<1|1].cur;
}
void update(int i,int x,LL d)
{
if(tree[i].left==tree[i].right){
tree[i].sum+=1;
tree[i].cur+=d;
return;
}
int mid=mid(tree[i].left,tree[i].right);
if(x<=mid) update(i<<1,x,d);
else update(i<<1|1,x,d);
tree[i].sum=tree[i<<1].sum+tree[i<<1|1].sum;
tree[i].cur=tree[i<<1].cur+tree[i<<1|1].cur;
}
// Range-sum walk shared by query() and query2(). It returns the total of the
// node member `field` over [left, right] within the subtree rooted at i.
// Requires tree[i].left <= left <= right <= tree[i].right; main() only passes
// ranges inside [1, sz].
static LL query_field(int i, int left, int right, LL Tree::*field)
{
    if (tree[i].left == left && tree[i].right == right)
        return tree[i].*field;
    const int m = mid(tree[i].left, tree[i].right);
    if (right <= m) return query_field(i << 1, left, right, field);
    if (left > m) return query_field(i << 1 | 1, left, right, field);
    return query_field(i << 1, left, m, field)
         + query_field(i << 1 | 1, m + 1, right, field);
}

// Number of inserted elements whose rank lies in [left, right] (node member `sum`).
LL query(int i,int left,int right)
{
    return query_field(i, left, right, &Tree::sum);
}

// Sum of the weights passed to update() for ranks in [left, right] (node member `cur`).
LL query2(int i,int left,int right)
{
    return query_field(i, left, right, &Tree::cur);
}
map<LL, LL> mymap;
/*
 * Segment-tree variant. Reads test cases until input ends (n, then A[1..n])
 * and prints one number per case. The computation makes three passes, each
 * starting from a freshly built tree over ranks [1, sz]:
 *   1. left to right:  pre[i]  = earlier elements with a strictly larger rank;
 *   2. right to left:  last[i] = later elements with a strictly smaller rank;
 *      all_less = sum of last[i] (strictly decreasing pairs, "desc");
 *   3. left to right:  position i is inserted with weight w[i] = pre[i] + last[i];
 *      pre_num = earlier strictly smaller elements, pre_sum = sum of their w.
 * Summed over i, pass 3 yields asc * desc - sum_i (ls[i] + rl[i]) * w[i]. Here
 * asc is the number of strictly increasing pairs, and ls/rl are the smaller-left /
 * larger-right counts: the (increasing, decreasing) pair combinations sharing
 * no index.
 *
 * NOTE(review): array/tree bounds assume n < maxn; confirm against the input limits.
 */
int main(int argc, char const *argv[])
{
    //IN;
    // `== 1` rather than `!= EOF`: a malformed token ends the loop instead of
    // re-processing the previous case forever.
    while(scanf("%d", &n) == 1)
    {
        if(n <= 0) {
            // build(1, 1, 0) never reaches a leaf (left > right) and would recurse
            // without bound; an empty sequence contributes 0.
            printf("0\n");
            continue;
        }
        // num/tmpa are long long, so %lld is required (%d into a long long* is UB
        // and mis-reads negative values).
        for(int i=1; i<=n; i++)
            scanf("%lld", &num[i]), tmpa[i] = num[i];

        // Coordinate compression: each distinct value -> its 1-based rank.
        sort(tmpa+1, tmpa+1+n);
        LL sz = 0; mymap.clear();
        for(int i=1; i<=n; i++) {
            if(!mymap.count(tmpa[i])) {
                sz++;
                mymap.insert(make_pair(tmpa[i], sz));
            }
        }
        for(int i=1; i<=n; i++) {
            num[i] = mymap[num[i]];
        }
        fill(pre, pre+n+1, 0);
        fill(last, last+n+1, 0);

        // Pass 1: after inserting A[1..i], count ranks above num[i].
        build(1, 1, sz);
        for(int i=1; i<=n; i++) {
            update(1, num[i], 1);
            if(num[i] == sz) continue;  // nothing can be larger than the top rank
            pre[i] = query(1, num[i]+1, sz);
        }

        // Pass 2: after inserting A[i..n], count ranks below num[i].
        build(1, 1, sz);
        for(int i=n; i>=1; i--) {
            update(1, num[i], 1);
            if(num[i] == 1) continue;   // nothing can be smaller than rank 1
            last[i] = query(1, 1, num[i]-1);
        }
        LL all_less = 0;
        for(int i=1; i<=n; i++) {
            all_less += last[i];
        }

        // Pass 3: combine counts (see header comment).
        build(1, 1, sz);
        LL ans = 0;
        for(int i=1; i<=n; i++) {
            update(1, num[i], pre[i]+last[i]);
            if(num[i] == 1) continue;   // no smaller earlier elements, contributes 0
            LL pre_num = query(1, 1, num[i]-1);
            LL pre_sum = query2(1, 1, num[i]-1);
            ans += pre_num * (all_less - pre[i] - last[i]) - pre_sum;
        }
        // %lld is the standard long long specifier (%I64d is an MSVCRT extension).
        printf("%lld\n", ans);
    }
    return 0;
}