设dp[i][j]表示选到了第i张牌,牌号在j之前包括j的概率,cnt[i]表示有i张牌,inv[i]表示i在mod下的逆元,那我们可以考虑转移,dp[i][j]=dp[i-1][j-1]*cnt[j]*inv[n-i+1],这个只是表示当前成功转移到i j的状态,如果要考虑胜利的条件,显然是选在选一次j即可赢取胜率,那么对于答案ans只需要加上dp[i-1][j-1]*cnt[j]*inv[n-i+1]*(cnt[j]-1)*inv[n-i]即可,因为我们这个dp[i][j]是记录j之前所有的概率和,需要开一个sum记录之前的和再去更新当前的dp[i][j]即可,记得初始化,所有dp[0][j]都是1,没有选那么概率显然为1,复杂度O(n^2),可以不需要开二维数组。
1 // ——By DD_BOND 2 3 //#include<bits/stdc++.h> 4 #include<functional> 5 #include<algorithm> 6 #include<iostream> 7 #include<sstream> 8 #include<iomanip> 9 #include<climits> 10 #include<cstring> 11 #include<cstdlib> 12 #include<cstddef> 13 #include<cstdio> 14 #include<memory> 15 #include<vector> 16 #include<cctype> 17 #include<string> 18 #include<cmath> 19 #include<queue> 20 #include<deque> 21 #include<ctime> 22 #include<stack> 23 #include<map> 24 #include<set> 25 26 #define fi first 27 #define se second 28 #define MP make_pair 29 #define pb push_back 30 #define INF 0x3f3f3f3f 31 #define pi 3.1415926535898 32 #define lowbit(a) (a&(-a)) 33 #define lson l,(l+r)/2,rt<<1 34 #define rson (l+r)/2+1,r,rt<<1|1 35 #define Min(a,b,c) min(a,min(b,c)) 36 #define Max(a,b,c) max(a,max(b,c)) 37 #define debug(x) cerr<<#x<<"="<<x<<" "; 38 39 using namespace std; 40 41 typedef long long ll; 42 typedef pair<int,int> P; 43 typedef pair<ll,ll> Pll; 44 typedef unsigned long long ull; 45 46 const ll LLMAX=2e18; 47 const int MOD=998244353; 48 const double eps=1e-8; 49 const int MAXN=1e6+10; 50 51 inline ll sqr(ll x){ return x*x; } 52 inline int sqr(int x){ return x*x; } 53 inline double sqr(double x){ return x*x; } 54 ll __gcd(ll a,ll b){ return b==0? a: __gcd(b,a%b); } 55 ll qpow(ll a,ll n){ll sum=1;while(n){if(n&1)sum=sum*a%MOD;a=a*a%MOD;n>>=1;}return sum;} 56 inline int dcmp(double x){ if(fabs(x)<eps) return 0; return (x>0? 1: -1); } 57 58 ll dp[5010][5010],inv[5010],cnt[5010]; 59 60 int main(void) 61 { 62 ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); 63 inv[1]=dp[0][0]=1; 64 for(int i=2;i<=5000;i++) inv[i]=(MOD-MOD/i)*inv[MOD%i]%MOD; 65 ll n,ans=0; cin>>n; 66 for(int i=1;i<=n;i++){ 67 int x; cin>>x; 68 cnt[x]++; 69 dp[0][i]=1; 70 } 71 for(int i=1;i<=n;i++){ 72 ll sum=0; 73 for(int j=1;j<=n;j++){ 74 ll p=dp[i-1][j-1]*cnt[j]%MOD*inv[n-i+1]%MOD; 75 sum=(sum+p)%MOD; 76 dp[i][j]=sum; 77 if(cnt[j]>=2) ans=(ans+p*(cnt[j]-1)%MOD*inv[n-i]%MOD)%MOD; 78 } 79 } 80 cout<<ans<<endl; 81 return 0; 82 }