题目分析:
先将数组按降序排序,再用三个只向前移动的指针做一次线性扫描,并预处理阶乘及其逆元,以便 O(1) 计算组合数即可。
题目代码:
// Reads n, k and n integers from stdin, then prints one count (mod 998244353)
// per input position, separated by single spaces.
// NOTE(review): the counting looks like "number of ways to double exactly k
// values so that this element's rank is unchanged" — confirm against the
// problem statement; the code itself only shows the two binomial terms below.
// Assumes all values are non-negative — TODO confirm.

#include <algorithm>
#include <cstdio>

const int maxn = 102000;     // table capacity; positions are 1-based, so n must stay below maxn
const int mod = 998244353;   // prime modulus for every reported count

int n, k;
struct node {
    int data, num;           // data: value as read; num: its 1-based position in the input
} a[maxn];

int ans[maxn];               // ans[p]: result for the element originally at position p
int fac[maxn], inv[maxn];    // fac[i] = i! mod p, inv[i] = (i!)^{-1} mod p, valid for 0..n after init()

// Returns now^pw modulo mod for pw >= 0 (a non-positive pw yields 1).
// Iterative square-and-multiply: unlike the recursive form it also terminates
// for pw == 0, and it always reduces the base first.
int fast_pow(int now, int pw) {
    long long base = now % mod;
    if (base < 0) base += mod;
    long long result = 1 % mod;
    while (pw > 0) {
        if (pw & 1) result = result * base % mod;
        base = base * base % mod;
        pw >>= 1;
    }
    return static_cast<int>(result);
}

// Fills fac[0..n] and inv[0..n] using the global n.
void init() {
    fac[0] = 1;
    for (int i = 1; i <= n; i++) fac[i] = static_cast<int>(1LL * fac[i - 1] * i % mod);
    // One modular inverse via Fermat's little theorem (mod is prime), then
    // walk downwards: (i-1)!^{-1} = i!^{-1} * i.
    inv[n] = fast_pow(fac[n], mod - 2);
    for (int i = n; i >= 1; i--) inv[i - 1] = static_cast<int>(1LL * inv[i] * i % mod);
}

// Binomial coefficient C(alpha, beta) modulo mod.
// Returns 0 when beta > alpha or when either argument is negative (the latter
// would otherwise index the tables out of bounds). Requires alpha <= n and a
// prior call to init().
int C(int alpha, int beta) {
    if (alpha < 0 || beta < 0 || beta > alpha) return 0;
    return static_cast<int>(1LL * fac[alpha] * inv[beta] % mod * inv[alpha - beta] % mod);
}

// Ordering predicate for std::sort: larger data first.
bool cmp(const node& alpha, const node& beta) { return alpha.data > beta.data; }

// Reads n, k and the n values into a[1..n], remembers each value's input
// position, and sorts a[1..n] by descending value.
// A malformed header or an n outside [0, maxn) is treated as an empty input so
// the fixed-size tables are never overrun.
void read() {
    if (std::scanf("%d%d", &n, &k) != 2 || n < 0 || n >= maxn) {
        n = 0;
        k = 0;
        return;
    }
    for (int i = 1; i <= n; i++) {
        if (std::scanf("%d", &a[i].data) != 1) a[i].data = 0;  // missing value stays 0, as before
        a[i].num = i;
    }
    std::sort(a + 1, a + n + 1, cmp);
}

// Computes ans[] for every element and prints ans[1..n] in input order.
// The array is scanned once in descending order, one group of equal values at
// a time; pts and ok only ever move forward, so the scan is linear after sorting.
void work() {
    init();
    int pts = 1, ok = 1;
    for (int i = 1; i <= n;) {
        // [i, nxt] is the maximal run of elements equal to a[i].data.
        int nxt = i;
        while (nxt + 1 <= n && a[nxt + 1].data == a[i].data) nxt++;
        const long long cur = a[i].data;

        // pts: first index whose doubled value is still below cur.
        // Doubling is done in 64-bit so large values cannot overflow int.
        while (pts <= n && 2LL * a[pts].data >= cur) pts++;
        // First term: choose all k among the n - (pts - nxt) elements that are
        // neither this one nor one of the pts - nxt - 1 strictly smaller
        // elements whose doubled value reaches cur.
        long long total = C(n - (pts - nxt), k);

        // ok: first index whose value is below 2 * cur, so [ok, nxt] holds the
        // values in [cur, 2 * cur).
        while (ok <= nxt && a[ok].data >= 2LL * cur) ok++;
        const int forced = nxt - ok + 1;
        // Second term: all of [ok, nxt] are taken, the remaining k - forced
        // picks come from the other n - forced elements.
        if (k >= forced) total = (total + C(n - forced, k - forced)) % mod;

        // A value of 0 is unchanged by doubling; such elements get C(n, k).
        if (cur == 0) total = C(n, k);

        for (int j = i; j <= nxt; j++) ans[a[j].num] = static_cast<int>(total);
        i = nxt + 1;
    }
    for (int i = 1; i <= n; i++) std::printf("%d ", ans[i]);
}

int main() {
    read();
    work();
    return 0;
}