题目大意
给出一个长度为(n)的序列A(A1,A2...AN)
如果序列(A)不是非降的,你必须从中删去一个数(随便删都行)
反复执行这种操作,直到(A)非降为止
求有多少种不同的操作方案,答案模(10^9+7)
(剩下的数相同,操作顺序不同算不同方案)
思路1
自己想了下,如果最后剩下一个长度为(i)的合法状态,那么它上一步只能是
长度为(i)的非降序列加了一个 降点
用(f[len][0][mx],f[len][1][mx])表示当前长度,是否存在降点,当前最大值mx来dp
但发现有些dp转移难以优化
思路2
先不考虑中途非法的限制,再进行容斥
做法
记(g[i])为长度为(i)的非降序列有多少个
这个可以通过(f[len][mx])+树状数组优化简单求出
然后(g[i]*(n-i)!)得到最终串为长度(i)的全部方案数
那么我们减掉最终串为长度(i)的非法方案数即可
根据思路1,到达长度(i)的所有非法方案,上一步都是来自长度(i+1)的非降串
(否则有一个降点保底,前面顺序不会导致非法)
而分析一下可以发现
每个长度(i+1)的非降串,都会有(i+1)条边,且全都是指到长度(i)的非降串的
所以非法路径数为([n-(i+1)]!*g[i+1]*(i+1))
solution
#include <cstdio>
#include <cstdlib>
#include <cctype>
#include <cmath>
#include <algorithm>
#include <cstring>
using namespace std;
const int M=2e3+7;
const int Q=1e9+7;
typedef long long LL;
const LL INF=9223372036854775807;
inline int pls(int x,int y){return ((LL)x+y)%Q;}
inline int mul(int x,int y){return 1LL*x*y%Q;}
inline int mns(int x,int y){return pls(x,Q-y);}
inline int ri(){
int x=0;bool f=1;char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
for(;isdigit(c);c=getchar()) x=x*10+c-48;
return f?x:-x;
}
inline LL rl(){
LL x=0;bool f=1;char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
for(;isdigit(c);c=getchar()) x=x*10+c-48;
return f?x:-x;
}
int n,m;
int a[M];
LL val[M],b[M];
struct Bitarr{
int c[M];
Bitarr(){memset(c,0,sizeof c);}
inline int lb(int x){return x&-x;}
inline int add(int x,int d){for(;x<=m;x+=lb(x)) c[x]=pls(c[x],d);}
inline int sum(int x){
int res=0;
for(;x>0;x-=lb(x)) res=pls(res,c[x]);
return res;
}
inline int sum(int x,int y){return mns(sum(y),sum(x-1));}
}f[M];
int g[M],fac[M];
int main(){
int i,j,tp;
n=ri();
for(i=1;i<=n;i++) val[i]=b[i]=ri();
b[n+1]=-INF;
sort(b+1,b+n+2); m=unique(b+1,b+n+2)-(b+1);
for(i=1;i<=n;i++) a[i]=lower_bound(b+1,b+m+1,val[i])-b;
f[0].add(1,1);
for(i=1;i<=n;i++){
for(j=i-1;j>=0;j--){
tp=f[j].sum(1,a[i]);
f[j+1].add(a[i],tp);
}
}
for(i=1;i<=n;i++) g[i]=f[i].sum(1,m);
for(i=1,fac[0]=1;i<=n;i++) fac[i]=mul(fac[i-1],i);
int ans=0;
for(i=1;i<n;i++) ans=pls(ans, mns(mul(g[i],fac[n-i]),mul(mul(g[i+1],fac[n-i-1]),i+1)) );
ans=pls(ans,g[n]);
printf("%d
",ans);
return 0;
}