题意:
有一个长度为n的序列a,a[i]在[li,ri]中独立均匀随机生成。求期望的逆序对个数。
题解:
显然由于独立生成,所以可以每对逆序对单独考虑。
我们将每一块[i,i+1](i∈Z)的区间称之为“第i块”。那么假设a[i]有pi的概率选到第x块,a[j]有pj的概率选到第y块(i<j,x>y),那么贡献就是pi*pj。如果是x<y,贡献显然为0,如果x=y,那么贡献为1/2*pi*pj。这样我们可以得到O(n^2)的做法。
当然这样不够快,我们可以分开考虑x>y和x=y的贡献。
由于是均匀随机的,所以每一个块的概率是一定的。那么我们用一个线段树来表示分到第i块的概率总和,那么就可以解决x=y的贡献,就是1/2*pi*sum,sum表示从第l[i]块到第r[i]-1块的概率总和(线段树维护)。
x>y不太好弄,如果我们对于i,找所有的块来贡献答案,并不能做,因为系数不一样。那我们倒过来考虑,对于i,它能对前面哪些块造成多大的贡献。计算一下,第l[i]块贡献为0,第l[i]+1块贡献为pi,第l[i]+2块贡献为2*pi,这恰好是一个等差数列!我们可以使用线段树维护等差数列,然后即可解决这道题目。
如果不知道怎么线段树维护等差数列的,可以看一看下面这一段文字(会就不用看了)
如果我们要插入一段l,r的区间,公差为a的等差数列,那么第i个位置的值就为a*(i-l)(本题首项为0),a*(i-l)=a*i-a*l,所以可以维护两颗线段树,一颗维护a*i的值(即维护系数),一颗区间加上(-a*l),查询的时候只需要用第一颗线段树加上第二颗线段树的值即可。
#include<cstdio> #include<algorithm> #include<cstdlib> using namespace std; const int mod=998244353,INF=110000000; int n,l[100002],r[100002],cnt=1,ct=1,ct2=1; long long ans; typedef struct{ long long sum,f; int ls,rs; }P; P p[10000002]; typedef struct{ long long sum,f; int ls,rs; }PP; PP q[10000002]; typedef struct{ long long sum,f; int ls,rs; }PPP; PPP t[10000002]; long long ccj(long long x,long long y){ long long ans=1; while(y) { if (y&1)ans=ans*x%mod; x=x*x%mod;y>>=1; } return ans; } void pushdown(int root,int begin,int mid,int end){ if (p[root].f) { if (!p[root].ls)p[root].ls=++cnt; if (!p[root].rs)p[root].rs=++cnt; p[p[root].ls].sum=(p[p[root].ls].sum+p[root].f*(mid-begin+1)%mod)%mod; p[p[root].rs].sum=(p[p[root].rs].sum+p[root].f*(end-mid)%mod)%mod; p[p[root].ls].f=(p[p[root].ls].f+p[root].f)%mod; p[p[root].rs].f=(p[p[root].rs].f+p[root].f)%mod; p[root].f=0; } } void gengxin(int root,int begin,int end,int begin2,int end2,long long z){ if (begin2>end2)return; if (begin>=begin2 && end<=end2) { p[root].sum=(p[root].sum+z*(end-begin+1))%mod;p[root].f=(p[root].f+z)%mod; return; } int mid=(begin+end)/2;pushdown(root,begin,mid,end); if (!(begin>end2 || mid<begin2)) { if (!p[root].ls)p[root].ls=++cnt; gengxin(p[root].ls,begin,mid,begin2,end2,z); } if (!(mid+1>end2 || end<begin2)) { if (!p[root].rs)p[root].rs=++cnt; gengxin(p[root].rs,mid+1,end,begin2,end2,z); } p[root].sum=(p[p[root].ls].sum+p[p[root].rs].sum)%mod; } long long chaxun(int root,int begin,int end,int begin2,int end2){ if (begin2>end2 || begin>end2 || end<begin2 || !root)return 0; if (begin>=begin2 && end<=end2)return p[root].sum; int mid=(begin+end)/2;pushdown(root,begin,mid,end); return (chaxun(p[root].ls,begin,mid,begin2,end2)+chaxun(p[root].rs,mid+1,end,begin2,end2))%mod; } long long js(long long a,long long b){ return ((a+b)*(b-a+1)/2)%mod; } void pd(int root,int begin,int mid,int end){ if (q[root].f) { if (!q[root].ls)q[root].ls=++ct; if (!q[root].rs)q[root].rs=++ct; q[q[root].ls].sum=(q[q[root].ls].sum+q[root].f*js(begin,mid)%mod)%mod; q[q[root].rs].sum=(q[q[root].rs].sum+q[root].f*js(mid+1,end)%mod)%mod; q[q[root].ls].f=(q[q[root].ls].f+q[root].f)%mod; q[q[root].rs].f=(q[q[root].rs].f+q[root].f)%mod; q[root].f=0; } } void gx1(int root,int begin,int end,int begin2,int end2,long long z){ if (begin2>end2)return; if (begin>=begin2 && end<=end2) { q[root].sum=(q[root].sum+z*js(begin,end)%mod)%mod;q[root].f=(q[root].f+z)%mod; return; } int mid=(begin+end)/2;pd(root,begin,mid,end); if (!(begin>end2 || mid<begin2)) { if (!q[root].ls)q[root].ls=++ct; gx1(q[root].ls,begin,mid,begin2,end2,z); } if (!(mid+1>end2 || end<begin2)) { if (!q[root].rs)q[root].rs=++ct; gx1(q[root].rs,mid+1,end,begin2,end2,z); } q[root].sum=(q[q[root].ls].sum+q[q[root].rs].sum)%mod; } void pd2(int root,int begin,int mid,int end){ if (t[root].f) { if (!t[root].ls)t[root].ls=++ct2; if (!t[root].rs)t[root].rs=++ct2; t[t[root].ls].sum=(t[t[root].ls].sum+t[root].f*(mid-begin+1)%mod)%mod; t[t[root].rs].sum=(t[t[root].rs].sum+t[root].f*(end-mid)%mod)%mod; t[t[root].ls].f=(t[t[root].ls].f+t[root].f)%mod; t[t[root].rs].f=(t[t[root].rs].f+t[root].f)%mod; t[root].f=0; } } void gx2(int root,int begin,int end,int begin2,int end2,long long z){ if (begin2>end2)return; if (begin>=begin2 && end<=end2) { t[root].sum=(t[root].sum+z*(end-begin+1)%mod)%mod;t[root].f=(t[root].f+z)%mod; return; } int mid=(begin+end)/2;pd2(root,begin,mid,end); if (!(begin>end2 || mid<begin2)) { if (!t[root].ls)t[root].ls=++ct2; gx2(t[root].ls,begin,mid,begin2,end2,z); } if (!(mid+1>end2 || end<begin2)) { if (!t[root].rs)t[root].rs=++ct2; gx2(t[root].rs,mid+1,end,begin2,end2,z); } t[root].sum=(t[t[root].ls].sum+t[t[root].rs].sum)%mod; } long long cx(int r1,int r2,int begin,int end,int begin2,int end2){ if (begin2>end2 || begin>end2 || end<begin2)return 0; if (begin>=begin2 && end<=end2)return (q[r1].sum+t[r2].sum)%mod; int mid=(begin+end)/2;pd(r1,begin,mid,end);pd2(r2,begin,mid,end); return (cx(q[r1].ls,t[r2].ls,begin,mid,begin2,end2)+cx(q[r1].rs,t[r2].rs,mid+1,end,begin2,end2))%mod; } int main() { scanf("%d",&n); for (int i=1;i<=n;i++)scanf("%d%d",&l[i],&r[i]); for (int i=n;i>=1;i--) { int len=r[i]-l[i];long long ny=ccj(len,mod-2); ans=(ans+ny*cx(1,1,0,INF,l[i],r[i]-1)%mod+ccj(2,mod-2)*ny%mod*chaxun(1,0,INF,l[i],r[i]-1)%mod)%mod; gengxin(1,0,INF,l[i],r[i]-1,ny); gx1(1,0,INF,l[i],r[i],ny);gx2(1,0,INF,l[i],r[i],((-(ny*l[i]%mod))+mod)%mod); gx2(1,0,INF,r[i]+1,INF,1); } printf("%lld ",ans); return 0; }