[计蒜之道2019 复赛 A]外教 Michale 变身大熊猫
Online Judge:2019计蒜之道 复赛 A
Label:LIS+线段树、树状数组+快速幂(模逆元)
题目描述
题解:
pre.关于本题中模逆元的提示:
对于一个质数mod,q的模逆元是(q^{mod-2})。也就是说对于本题,平时一般用形如(G=frac{p}{q})的分式表示概率,现在,我们利用模逆元来表示这个概率,即(G=p*q^{mod-2})。
那这道题里(mod=998244353)是个质数,那就可以利用上面的提示(G=p*q^{mod-2}),很明显要用到快速幂,而这个(q)就是LIS的个数,那我们只要求每个位置对应的p即可。
①.70%数据的(O(N^2))做法:
想到LIS的O(NlogN)做法——用树状数组或线段树维护前缀最小值优化一下dp。然后可以求出最大上升子序列的长度mal ,接着分别从左到右,再从右到左求一下每个点会在几条LIS上,然后左右的方案数乘一下就是p了。预处理很明显要先离散化一下。离线赛时的70分代码:
#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
const int N=1e5+10;
const int mod=998244353;
int r[N],l[N];
int qsm(int x,int b){
int res=1;
while(b){
if(b&1)res=1LL*res*x%mod;
x=1LL*x*x%mod;
b>>=1;
}
return res;
}
int B[N];
int n,num,a[N],b[N],dp[N];
int lowbit(int x){return x&-x;}
void update(int x,int d){
while(x<=num)B[x]=max(B[x],d),x+=lowbit(x);
}
int query(int x){
int res=0;
while(x)res=max(res,B[x]),x-=lowbit(x);
return res;
}
int cnt[N],all;
signed main(){
// freopen("lis.in","r",stdin),freopen("lis.out","w",stdout);
scanf("%lld",&n);
for(int i=1;i<=n;i++)scanf("%lld",&a[i]),b[i]=a[i];
sort(b+1,b+n+1);
num=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;i++)a[i]=lower_bound(b+1,b+num+1,a[i])-b;
int mal=0;
for(int i=1;i<=n;i++){
int now=query(a[i]-1);
dp[i]=now+1;
if(dp[i]>mal)mal=dp[i];
update(a[i],dp[i]);
}
a[0]=0,dp[0]=0,l[0]=1;
a[n+1]=num+1,dp[n+1]=mal+1,r[n+1]=1;n++;
for(int i=0;i<=n;i++){
for(int j=i+1;j<=n;j++){
if((dp[j]!=dp[i]+1)||a[j]<=a[i])continue;
l[j]=(l[j]+l[i])%mod;
}
}
for(int i=n;i>=0;i--){
for(int j=i-1;j>=0;j--){
if((dp[j]!=dp[i]-1)||a[j]>=a[i])continue;
r[j]=(r[j]+r[i])%mod;
}
}
int all=l[0]*r[0]%mod,q=qsm(all,mod-2);
for(int i=1;i<n;i++)printf("%lld ",r[i]*l[i]%mod*q%mod);
}
②.(O(NlogN))正解
上面的代码也可以优化,但是我们可以直接在求LIS时候维护方案数。
也是一样的思路,从左到右扫一遍,再从右到左扫一遍,然后乘一下得到方案数。然后下面代码利用树状数组维护到目前为止,以离散化后的数字i结尾的LIS最长为(c[i]),以及,以其结尾的LIS这么长有(ti[i])条。
由于一次从左到右,一次从右到左,所以打两组树状数组。大致思路同上面的代码:
//快速幂->模逆元,树状数组、线段树维护
#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
const int N=1e5+10;
const int mod=998244353;
int num,a[N],b[N];
int dp[N],li[N],ans[N];
int presum,c[N],ti[N];
int add1(int x,int mal,int d){
while(x<=num){
if(mal>c[x])c[x]=mal,ti[x]=d;
else if(mal==c[x])ti[x]=(ti[x]+d)%mod;
x+=x&(-x);
}
}
int sum1(int x){
int res=0;
while(x){
if(res<c[x])res=c[x],presum=ti[x];
else if(res==c[x])presum+=ti[x];
x-=x&(-x);
}
return res;
}
void add2(int x,int mal,int d){
while(x){
if(mal>c[x])c[x]=mal,ti[x]=d;
else if(mal==c[x])ti[x]=(ti[x]+d)%mod;
x-=x&(-x);
}
}
int sum2(int x){
int res=0;
while(x<=num){
if(res<c[x])res=c[x],presum=ti[x];
else if(res==c[x])presum+=ti[x];
x+=x&(-x);
}
return res;
}
int ksm(int x,int b){
int res=1;
while(b){
if(b&1)res=res*x%mod;
x=x*x%mod;
b>>=1;
}
return res;
}
signed main(){
int n;scanf("%lld",&n);
for(int i=1;i<=n;i++)scanf("%lld",&a[i]),b[i]=a[i];
sort(b+1,b+n+1);
num=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;i++)a[i]=lower_bound(b+1,b+num+1,a[i])-b;
int lis=0;
for(int i=1;i<=n;i++){
presum=0;
int now=sum1(a[i]-1);presum%=mod;
dp[i]=now+1;
lis=max(lis,dp[i]);
if(dp[i]==1)li[i]=1;
else li[i]=presum;
add1(a[i],dp[i],li[i]);
}
memset(c,0,sizeof(c));
memset(ti,0,sizeof(ti));
int tot=0;
for(int i=n;i>=1;i--){
presum=0;
int now=sum2(a[i]+1);presum%=mod;
int ri=(now==0?1:presum);
if((dp[i]+now)==lis)ans[i]=li[i]*ri%mod;
if(dp[i]==lis)tot=(tot+ans[i])%mod;
add2(a[i],now+1,ri);
}
int q=ksm(tot,mod-2);
for(int i=1;i<=n;i++)printf("%lld ",ans[i]*q%mod);
return 0;
}