zoukankan      html  css  js  c++  java
  • [CF960G] Bandit Blues

    题意

    给你三个正整数 (n,a,b),定义 (A) 为一个排列中是前缀最大值的数的个数,定义 (B) 为一个排列中是后缀最大值的数的个数,求长度为 (n) 的排列中满足 (A = a)(B = b) 的排列个数。(n le 10^5),答案对 (998244353) 取模。

    Sol

    首先可以设一个 (DP) 状态 (f(i,j)) 表示,长度为 (i) 的排列,有 (j) 个前缀最大值的方案数。

    那么转移就是枚举新放一个最小值,只有放在序列开头才有 (1) 的贡献:

    [f(i,j)=f(i-1,j-1)+(i-1) imes f(i-1,j) ]

    最后的答案就是枚举最大值 (n) 放在位置 (i),然后左边长度为 (i-1) 且有 (a-1) 个前缀最大值,右边长度为 (n-1-i) 且有 (b-1) 个后缀最大值,可以发现这个后缀最大值和前缀最大值的方案是相等的,那么最终的答案就是:

    [ans=sum_{i=1}^n C(n-1,i-1)cdot f(i-1,a-1)cdot f(n-i-1,b-1) ]

    稍微熟练一点就可以看出,这个 (f(i,j)) 本质上就是第一类斯特林数,即 (i) 个数放 (j) 个圆排列的方案数。

    所以这个式子就可以化简了,从组合意义上理解就是,从 (n-1) 个数拿出来形成 (a+b-2) 个圆排列,其中把 (a-1) 个放在 (n) 前面的方案数。最后答案就变成了:

    [ans=s(n-1,a+b-2) imes C(a+b-2,a-1) ]

    只需要预处理第一类斯特林数一行就行。分治(mathrm{NTT})复杂度(O(nlog^2 n)),倍增复杂度(O(nlog n))。具体见这里

    Code

    两种都写了下

    // 分治NTT
    #pragma GCC optimize(2)
    #include<bits/stdc++.h>
    using std::min;
    using std::max;
    using std::swap;
    using std::vector;
    typedef double db;
    typedef long long ll;
    #define pb(A) push_back(A)
    #define vec std::vector<int>
    #define pii std::pair<int,int>
    #define all(A) A.begin(),A.end()
    #define mp(A,B) std::make_pair(A,B)
    const int N=4e5+5;
    const int mod=998244353;
    
    int n,A,B,lim;
    int a[N],b[N],rev[N];
    
    int ksm(int a,int b=mod-2,int ans=1){
        while(b){
            if(b&1) ans=1ll*ans*a%mod;
            a=1ll*a*a%mod;b>>=1;
        } return ans;
    }
    
    void ntt(int *f,int g){
        for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
        for(int mid=1;mid<lim;mid<<=1){
            int tmp=ksm(g,(mod-1)/(mid<<1));
            for(int R=mid<<1,j=0;j<lim;j+=R){
                for(int w=1,k=0;k<mid;k++,w=1ll*w*tmp%mod){
                    int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                    f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
                }
            }
        } if(g>3)
            for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
    }
    
    vec mul(vec A,vec B,int n){
        lim=1;while(lim<=n) lim<<=1;
        for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
        for(int i=0;i<lim;i++) a[i]=(i<A.size()?A[i]:0);
        for(int i=0;i<lim;i++) b[i]=(i<B.size()?B[i]:0);
        ntt(a,3),ntt(b,3);
        for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
        ntt(a,(mod+1)/3); vec now;
        for(int i=0;i<=n;i++) now.pb(a[i]);
        return now;
    }
    
    vec solve(int l,int r){
        if(l==r){vec now;now.pb(l),now.pb(1);return now;}
        int mid=l+r>>1; return mul(solve(l,mid),solve(mid+1,r),r-l+1);
    }
    
    int S(int n,int m){
        if(!n) return 1;
        if(n<m) return 0;
        vec now=solve(0,n-1);
        return now[m];
    }
    
    int getint(){
        int X=0,w=0;char ch=getchar();
        while(!isdigit(ch))w|=ch=='-',ch=getchar();
        while( isdigit(ch))X=X*10+ch-48,ch=getchar();
        if(w) return -X;return X;
    }
    
    int C(int n,int m){
        if(n<m or n<0 or m<0) return 0;
        int now=1;
        for(int i=1;i<=n;i++) now=1ll*now*i%mod;
        for(int i=1;i<=m;i++) now=1ll*now*ksm(i)%mod;
        for(int i=1;i<=n-m;i++) now=1ll*now*ksm(i)%mod;
        return now;
    }
    
    signed main(){
        n=getint(),A=getint(),B=getint();
        printf("%lld
    ",1ll*S(n-1,A+B-2)*C(A+B-2,A-1)%mod);
        return 0;
    }
    
    
    // 倍增
    #pragma GCC optimize(2)
    #include<bits/stdc++.h>
    using std::min;
    using std::max;
    using std::swap;
    using std::vector;
    typedef double db;
    typedef long long ll;
    #define pb(A) push_back(A)
    #define pii std::pair<int,int>
    #define all(A) A.begin(),A.end()
    #define mp(A,B) std::make_pair(A,B)
    const int N=4e5+5;
    const int mod=998244353;
    
    int lim,rev[N],pw[N];
    int a[N],b[N],c[N],d[N];
    int n,A,B,fac[N],ifac[N];
    
    int getint(){
        int X=0,w=0;char ch=getchar();
        while(!isdigit(ch))w|=ch=='-',ch=getchar();
        while( isdigit(ch))X=X*10+ch-48,ch=getchar();
        if(w) return -X;return X;
    }
    
    int ksm(int a,int b=mod-2,int ans=1){
        while(b){
            if(b&1) ans=1ll*ans*a%mod;
            a=1ll*a*a%mod;b>>=1;
        } return ans;
    }
    
    void ntt(int *f,int g){
        for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
        for(int mid=1;mid<lim;mid<<=1){
            int tmp=ksm(g,(mod-1)/(mid<<1));
            for(int R=mid<<1,j=0;j<lim;j+=R){
                for(int w=1,k=0;k<mid;k++,w=1ll*w*tmp%mod){
                    int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                    f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
                }
            }
        } if(g>3)
            for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
    }
    
    void solve(int *a,int len){
        if(len==1) return a[1]=1,void();
        if(len&1){
            solve(a,len-1);
            for(int i=len;i;i--) a[i]=(1ll*a[i]*(len-1)%mod+a[i-1])%mod;
        } else{
            solve(a,len>>1); int mid=len>>1;
            pw[0]=1;for(int i=1;i<=mid;i++) pw[i]=1ll*pw[i-1]*mid%mod;
            lim=1;while(lim<=len) lim<<=1;
            for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
            for(int i=0;i<=mid;i++) c[i]=1ll*a[i]*fac[i]%mod,d[mid-i]=1ll*pw[i]*ifac[i]%mod;
            ntt(c,3),ntt(d,3);
            for(int i=0;i<lim;i++) c[i]=1ll*c[i]*d[i]%mod;
            ntt(c,(mod+1)/3);
            for(int i=0;i<=mid;i++) c[i]=1ll*c[mid+i]*ifac[i]%mod;
            for(int i=mid+1;i<lim;i++) c[i]=0;
            ntt(c,3),ntt(a,3);
            for(int i=0;i<lim;i++) a[i]=1ll*a[i]*c[i]%mod;
            ntt(a,(mod+1)/3);
            for(int i=len+1;i<lim;i++) a[i]=0;
            for(int i=0;i<lim;i++) d[i]=c[i]=0;
        }
    }
    
    int C(int n,int m){
        return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
    }
    
    void init(int n){
        fac[0]=ifac[0]=1;
        for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
        ifac[n]=ksm(fac[n]);
        for(int i=n-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
    }
    
    signed main(){
        init(N-5); n=getint(),A=getint(),B=getint();
        if(!A or !B or n-1<A+B-2 or A+B-2<A-1) return puts("0"),0;
        if(n==1) return printf("1"),0;
        solve(a,n-1);
        printf("%lld
    ",1ll*a[A+B-2]*C(A+B-2,A-1)%mod);
        return 0;
    }
    
    
  • 相关阅读:
    手写简易SpringMVC框架,包含@PathVariable
    高并发下,如何保证接口的幂等性?
    JAVA判断奇偶数
    多线程ForkJoin-分治思想
    websocket简单使用
    Git使用教程:最详细、最傻瓜、最浅显、真正手把手教!(转载学习)
    linux配置java环境变量(详细)
    java缓存技术的介绍(转载)
    java 多态性详解及常见面试题
    oracle数据库基础知识总结(一)
  • 原文地址:https://www.cnblogs.com/YoungNeal/p/10366970.html
Copyright © 2011-2022 走看看