题目链接:http://61.187.179.132/JudgeOnline/problem.php?id=2141
题意:给出一个数列A,每次交换两个数的位置。输出交换后逆序对的个数。
思路:首先,对于交换位置x和y,对于区间[x+1,y-1]的数字,小于A[x]的要减去,大于A[x]的要加上,大于A[y]的要减去,小于A[y] 的要加上。然后A[x]<A[y]则答案加1,A[x]<A[y]答案要减去1。查找修改用树状数组套一个treap即可。
struct node
{
int val,size,pri,L,R,cnt;
};
node a[N*300];
int e;
int newNode(int val)
{
int x=++e;;
a[x].val=val;
a[x].size=1;
a[x].cnt=1;
a[x].L=a[x].R=0;
a[x].pri=rand();
return x;
}
void pushUp(int x)
{
if(x==0) return;
a[x].size=a[x].cnt+a[a[x].L].size+a[a[x].R].size;
}
void rotL(int &x)
{
int y=a[x].R;
a[x].R=a[y].L;
a[y].L=x;
pushUp(x);
pushUp(y);
x=y;
}
void rotR(int &x)
{
int y=a[x].L;
a[x].L=a[y].R;
a[y].R=x;
pushUp(x);
pushUp(y);
x=y;
}
void insert(int &k,int val)
{
if(k==0) k=newNode(val);
else if(val<a[k].val)
{
insert(a[k].L,val);
if(a[a[k].L].pri>a[k].pri) rotR(k);
}
else if(val>a[k].val)
{
insert(a[k].R,val);
if(a[a[k].R].pri>a[k].pri) rotL(k);
}
else a[k].cnt++;
pushUp(k);
}
void del(int val,int &k)
{
if(k==0) return;
else if(val<a[k].val) del(val,a[k].L);
else if(val>a[k].val) del(val,a[k].R);
else
{
a[k].cnt--;
if(a[k].cnt<=0)
{
if(a[k].L==0&&a[k].R==0) k=0;
else if(a[k].L==0) k=a[k].R;
else if(a[k].R==0) k=a[k].L;
else
{
if(a[a[k].L].pri<a[a[k].R].pri) rotL(k),del(val,a[k].L);
else rotR(k),del(val,a[k].R);
}
}
}
pushUp(k);
}
int d[N];
int getCnt(int t,int x)
{
if(t==0) return 0;
if(a[t].val==x)
{
return a[a[t].L].size+a[t].cnt;
}
if(a[t].val>x) return getCnt(a[t].L,x);
return a[t].cnt+a[a[t].L].size+getCnt(a[t].R,x);
}
int n,m;
int A[N];
void Set(int x,int k)
{
while(x<=n)
{
insert(A[x],k);
x+=x&-x;
}
}
void erase(int x,int k)
{
while(x<=n)
{
del(k,A[x]);
x+=x&-x;
}
}
int get(int x,int k)
{
int ans=0;
while(x)
{
ans+=getCnt(A[x],k);
x-=x&-x;
}
return ans;
}
pair<int,int> cal(int L,int R,int x)
{
i64 a=get(R,x)-get(L-1,x);
i64 b=get(R,x-1)-get(L-1,x-1);
return MP(a-b,b);
}
int p[N];
int find(int low,int high,int x)
{
int M;
while(low<=high)
{
M=(low+high)>>1;
if(p[M]==x) return M;
if(p[M]>x) high=M-1;
else low=M+1;
}
}
void getInt(int &x)
{
char c=getchar();
while(!isdigit(c)) c=getchar();
x=0;
while(isdigit(c)) x=x*10+c-'0',c=getchar();
}
int main()
{
getInt(n);
int i;
FOR1(i,n) getInt(d[i]),p[i]=d[i];
sort(p+1,p+n+1);
int M=unique(p+1,p+n+1)-(p+1);
FOR1(i,n) d[i]=find(1,M,d[i]);
i64 sum=0;
pair<int,int> temp;
FOR1(i,n)
{
Set(i,d[i]);
temp=cal(1,i-1,d[i]);
sum+=i-1-temp.first-temp.second;
}
PR(sum);
RD(m);
int x,y;
while(m--)
{
getInt(x);
getInt(y);
if(x>y) swap(x,y);
if(d[x]==d[y])
{
PR(sum);
continue;
}
if(d[x]>d[y]) sum--;
else sum++;
if(y-x>1)
{
temp=cal(x+1,y-1,d[x]);
sum-=temp.second;
sum+=y-x-1-temp.first-temp.second;
temp=cal(x+1,y-1,d[y]);
sum+=temp.second;
sum-=y-x-1-temp.first-temp.second;
}
erase(x,d[x]);
erase(y,d[y]);
swap(d[x],d[y]);
Set(x,d[x]);
Set(y,d[y]);
PR(sum);
}
}