zoukankan      html  css  js  c++  java
  • 【学习笔记】二项式反演

    考场上的我猜测到了某个题是二项式反演 可惜我不会啊

    分治NTT给了60的暴力 然而他们的分治NTT都能跑1e6/dk

    回归正题

    二项式反演大概就是关于二项式的反演(废话)

    一般来说 初始状态都是状压的0/1状态 然后它们有很好的性质 与具体哪一位的状态无关 所以可以直接记录1的个数

    又因为反演≈容斥 所以它的两个基本形式就很好看了

    $f(n) = sum_{i=0} ^ n  (-1)^i inom{n}{i} g(i) Leftrightarrow g(n) = sum_{i=0} ^n (-1)^i inom{n}{i} f(i)$

    一般这个柿子用不到(吧 因为很少上来就是个容斥形式

    $f(n) = sum _{i=0} ^n inom{n}{i} g(i) Leftrightarrow g(n)=sum_{i=0}^n (-1)^{n-i} inom{n}{i} f(i)$

    一般是这个柿子比较有用qwq

    具体证明自行百度吧(~~其实就是我懒~~

    例题

    [HAOI2016] 染色

    其实这个题完全可以不用二项式反演直接容斥 也间接性证明了反演的本质就是容斥

    考虑恰好为S有k个的限制 按照套路转化为至少有k个为S

    设所求函数为G(x) 容斥函数为F(x)

    容斥函数可直接计算$F(x) = inom{m}{x} (m-i)^{n-ix} frac{(sx)!}{(x!)^s (n-ix)!}$

    然后直接套上上边的容斥就可以了

    最后NTT优化即可

    以前的代码(反正最后柿子推出来是一样的)

    //Love and Freedom.
    #include<cstdio>
    #include<cmath>
    #include<algorithm>
    #include<cstring>
    #define ll long long
    #define inf 20021225
    #define mdn 1004535809
    #define N 300010
    #define G 3
    using namespace std;
    
    int rev[N<<2],inv,jc[N*100],iv[N*100],n,m,s,w[N];
    int ksm(int bs,int mi)
    {
        int ans=1;
        while(mi)
        {
            if(mi&1)    ans=(ll)ans*bs%mdn;
            bs=(ll)bs*bs%mdn; mi>>=1;
        }
        return ans;
    }
    int pre(int n)
    {
        int lim=1,l=0;
        while(lim<n)    lim<<=1,l++;
        for(int i=1;i<lim;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
        inv=ksm(lim,mdn-2); return lim;
    }
    void ntt(int *a,int lim,int opt)
    {
        for(int i=1;i<lim;i++)    if(i<rev[i])
            swap(a[i],a[rev[i]]);
        for(int k=2,mid=1;k<=lim;k<<=1,mid<<=1)
        {
            int Wn=ksm(G,(mdn-1)/k);    if(opt)    Wn=ksm(Wn,mdn-2);
            for(int w=1,i=0;i<lim;w=1,i+=k)    for(int j=0;j<mid;j++,w=(ll)w*Wn%mdn)
            {
                int x=a[i+j],y=(ll)w*a[i+mid+j]%mdn;
                a[i+j]=(x+y)%mdn; a[i+mid+j]=(x-y+mdn)%mdn;
            }
        }
        if(opt)    for(int i=0;i<lim;i++)    a[i]=(ll)a[i]*inv%mdn;
    }
    int h[N],f[N],ans[N]; int top;
    int main()
    {
        scanf("%d%d%d",&n,&m,&s);
        for(int i=0;i<=m;i++)    scanf("%d",&w[i]);
        top=min(n/s,m); jc[0]=iv[0]=1; int tt=max(n,m);
        for(int i=1;i<=tt;i++)    jc[i]=(ll)jc[i-1]*i%mdn;
        iv[tt]=ksm(jc[tt],mdn-2); int qwq=1;
        for(int i=tt;i;i--)    iv[i-1]=(ll)iv[i]*i%mdn;
        for(int i=0;i<=top;i++)
            f[i]=(ll)jc[m]*iv[m-i]%mdn*jc[n]%mdn*iv[n-i*s]%mdn*qwq%mdn*ksm(m-i,n-i*s)%mdn,qwq=(ll)qwq*iv[s]%mdn;
        for(int i=0;i<=top;i++)    h[i]=(ll)(mdn+((i&1)?-1:1)*iv[i])%mdn;
        for(int i=0;i<=(top>>1);i++)    swap(h[i],h[top-i]);
        int lim=pre((top+1)<<1);// printf("%d
    ",lim);
        ntt(h,lim,0); ntt(f,lim,0);
        for(int i=0;i<lim;i++)    ans[i]=(ll)h[i]*f[i]%mdn;
        ntt(ans,lim,1); int fin=0;
        for(int i=0;i<=top;i++)
            fin=(ll)(fin+(ll)ans[top+i]*w[i]%mdn*iv[i]%mdn)%mdn;
        printf("%d
    ",fin);
        return 0;
    }
    View Code

    [CTS2019] 珍珠

    论考试前一天晚上刚好看到原题的几率 可惜没看懂

    考虑有多少颜色是奇数

    直接做的话显然是

    $G(i) = (frac{e^x + e^{-x}}{2})^i (frac{e^x - e^{-x}}{2})^{n-i}$

    而且还要求n!

    发现n1e9/px

    于是就是二项式反演了

    $F(i) = n![x^n](frac{e^x + e^{-x}}{2})^i e^{(D-i)x}$

    我们进一步拆开括号画柿子

    $F(i)= frac{1}{2^i} sum_{j=0}^i inom{i}{j} n![x^n]e^{jx-(i-j)x} e^{(D-i)x}$

    $F(i)=frac{1}{2^i}sum_{j=0}^i inom{i}{j} n![x^n]e^{(D-2i+2j)x}$

    $F(i)=frac{1}{2^i}sum_{j=0}^i inom{i}{j} (D-2i+2j)^n$

    典型的卷积形式 NTT优化掉就可以了

    反演直接带进柿子里继续NTT一下就好了

    (吐槽一句 毒瘤出题人把n开到了[0,1e18] 于是还要特判0/px)

    啊 这个题懒得再回去写了 直接丢考场的加强代码吧x

    //Love and Freedom.
    #include<algorithm>
    #include<cmath>
    #include<cstring>
    #include<cstdio>
    #define inf 20021225
    #define ll long long
    #define N 600010
    #define mdn 998244353
    #define G 3
    using namespace std;
    ll read()
    {
        ll f=1,s=0; char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();}
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return f*s;
    }
    int fac[N],inv[N];
    void upd(int &x,int y){x+=x+y>=mdn?y-mdn:y;}
    int C(int n,int m){return n<m?0:1ll*fac[n]*inv[m]%mdn*inv[n-m]%mdn;}
    int ksm(int bs,ll mi)
    {
        int ans=1;
        if(!bs)    return !mi?1:0; mi%=(mdn-1);
        while(mi)
        {
            if(mi&1)    ans=1ll*ans*bs%mdn;
            bs=1ll*bs*bs%mdn; mi>>=1;
        }
        return ans;
    }
    int n,m; ll q,k;
    int r[N];
    int init(int n)
    {
        int l=0,lim=1;
        while(lim<n)    lim<<=1,l++;
        for(int i=0;i<lim;i++)
            r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
        return lim;
    }
    void ntt(int *a,int lim,int f)
    {
        for(int i=0;i<lim;i++)    if(r[i]>i)
            swap(a[r[i]],a[i]);
        for(int k=2,mid=1;k<=lim;k<<=1,mid<<=1)
        {
            int Wn=ksm(G,(mdn-1)/k); if(f)    Wn=ksm(Wn,mdn-2);
            for(int w=1,i=0;i<lim;i+=k,w=1)    for(int j=0;j<mid;j++,w=1ll*w*Wn%mdn)
            {
                int x=a[i+j],y=1ll*w*a[i+mid+j]%mdn;
                a[i+j]=(x+y)%mdn; a[i+mid+j]=(mdn+x-y)%mdn;
            }
        }
        if(f)    for(int kinv=ksm(lim,mdn-2),i=0;i<lim;i++)
            a[i]=1ll*a[i]*kinv%mdn;
    }
    int tmp[N],g[N],f[N];
    void solve(int *a,int n)
    {
        memset(g,0,sizeof(g)); memset(f,0,sizeof(f)); 
        for(int i=0;i<=n;i++)    g[i]=inv[i],f[i]=1ll*(i&1?mdn-inv[i]:inv[i])*ksm((n-2*i+mdn)%mdn,q)%mdn;
        int lim=init(n+1<<1); ntt(g,lim,0); ntt(f,lim,0);
        for(int i=0;i<lim;i++)    g[i]=1ll*g[i]*f[i]%mdn;
        ntt(g,lim,1);
        for(int i=0;i<=n;i++)    g[i]=1ll*ksm(2,(mdn-1-i)%mdn)*fac[i]%mdn*g[i]%mdn;
        for(int i=0;i<=n;i++)    g[i]=1ll*g[i]*fac[i]%mdn*C(n,i)%mdn,f[i]=i&1?mdn-inv[i]:inv[i];
        for(int i=n+1;i<lim;i++)    f[i]=g[i]=0; reverse(f,f+n+1);
        ntt(g,lim,0); ntt(f,lim,0);
        for(int i=0;i<lim;i++)    a[i]=1ll*g[i]*f[i]%mdn;
        ntt(a,lim,1);
        for(int i=0;i<=n;i++)    a[i]=1ll*a[i+n]*inv[i]%mdn;
    }
    int pre[N],h[N],l[N];
    int qry(int l,int r){return l>r?0:l?(pre[r]-pre[l-1]+mdn)%mdn:pre[r];}
    int main()
    {
        //freopen("qwq.txt","w",stdout);
        n=read(),m=read(),q=read(),k=read(); //printf("%lld %lld
    ",q,k);
        int top=max(n,m);
        fac[0]=1; for(int i=1;i<=top;i++)    fac[i]=1ll*fac[i-1]*i%mdn;
        inv[top]=ksm(fac[top],mdn-2); for(int i=top;i;i--)    inv[i-1]=1ll*inv[i]*i%mdn;
        solve(h,n); solve(l,m); int ans=0;
        pre[0]=l[0]; for(int i=1;i<=m;i++)    pre[i]=pre[i-1],upd(pre[i],l[i]);
        for(int i=0;i<=n;i++)
        {
            int fm=n-2*i; ll fz=k-1ll*m*i;
            if(!fm)
            {
                if(1ll*m*i<=k)    upd(ans,1ll*h[i]*qry(0,m)%mdn);
            }
            else if(fm<0)
            {
                ll top=ceil((long double)fz/fm); top=max(top,0ll);
                if(top<=m)    upd(ans,1ll*h[i]*qry(top,m)%mdn);
            }
            else
            {
                ll top=floor((long double)fz/fm); top=min(top,(ll)m);
                if(top>=0)    upd(ans,1ll*h[i]*qry(0,top)%mdn);
            }
        }
        printf("%d
    ",ans);
        return 0;
    }
    View Code

    [CTS2019] 随机立方体

    8会 先鸽着(不咕不咕 今天就写x)

    写了可是题解懒得写了

    差不多看看吧 还挺好理解的(雾)

    //Love and Freedom.
    #include<algorithm>
    #include<cstring>
    #include<cmath>
    #include<cstdio>
    #define inf 20021225
    #define ll long long
    #define mdn 998244353
    #define N 5000100
    using namespace std;
    int read()
    {
        int f=1,s=0; char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();}
        while(ch>='0'&&ch<='9')    s=s*10+ch-'0',ch=getchar();
        return f*s;
    }
    int ksm(int bs,int mi)
    {
        int ans=1;
        while(mi)
        {
            if(mi&1)    ans=1ll*ans*bs%mdn;
            bs=1ll*bs*bs%mdn; mi>>=1;
        }
        return ans;
    }
    int f[N],fac[N],inv[N],n=N-1,m,l;
    int F(int k){return (1ll*n*m%mdn*l%mdn-1ll*(n-k)*(m-k)%mdn*(l-k)%mdn+mdn)%mdn;}
    int C(int n,int m){return n<m?0:1ll*fac[n]*inv[m]%mdn*inv[n-m]%mdn;}
    int W(int k){return 1ll*C(n,k)*C(m,k)%mdn*C(l,k)%mdn*fac[k]%mdn*fac[k]%mdn*fac[k]%mdn;}
    int min(int x,int y,int z){return min(x,min(y,z));}
    int max(int x,int y,int z){return max(x,max(y,z));}
    void calcf()
    {
        int top=min(n,m,l); f[0]=1; int all=1;
        for(int i=1;i<=top;i++)    all=1ll*F(i)*all%mdn;
        all=ksm(all,mdn-2); f[top]=all; //printf("%d
    ",all);
        for(int i=top;i;i--)
        {
            f[i-1]=1ll*F(i)*f[i]%mdn;//puts("GG");
        }
        for(int i=0;i<=top;i++)    f[i]=1ll*f[i]*W(i)%mdn;
    }
    void init()
    {
        int top=max(n,m,l);    fac[0]=1;
        for(int i=1;i<=top;i++)    fac[i]=1ll*fac[i-1]*i%mdn;
        inv[top]=ksm(fac[top],mdn-2);
        for(int i=top;i;i--)    inv[i-1]=1ll*inv[i]*i%mdn;
    }
    void upd(int &x,int y){x+=x+y>=mdn?y-mdn:y;}
    int main()
    {
        int T=read(); init();
        while(T--)
        {
            n=read(),m=read(),l=read(); int ans=0,k=read();
            calcf(); int top=min(n,m,l);
            for(int i=k;i<=top;i++)
                upd(ans,1ll*((i-k)&1?mdn-1:1)*C(i,k)%mdn*f[i]%mdn);
            printf("%d
    ",ans);
        }
        return 0;
    }
    View Code

    发现有两道sb题没放 刚好再推下柿子

    集合计数

    定义$f(k)$为交集元素恰好$k$个,$g(k)$为交集元素至少$k$个

    显然有$g(n)=sum_{k=n}^N inom{k}{n}f(k)$

    二项式反演得$f(n) = sum_{i=n}^N (-1)^{k-n}g(k)inom{k}{n}$

    考虑求$g(n)=2^{2^{N-n}}$也就是钦定$n$个位置,剩下的随便选。

    于是就做完了

    //Love and Freedom.
    #include<algorithm>
    #include<cmath>
    #include<cstring>
    #include<cstdio>
    #define inf 20021225
    #define ll long long
    #define N 2000100
    #define mdn 1000000007
    using namespace std;
    int read()
    {
        int f=1,s=0; char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();}
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return f*s;
    }
    int fac[N],inv[N];
    int ksm(int bs,int mi,int md)
    {
        int ans=1;
        while(mi)
        {
            if(mi&1)    ans=1ll*ans*bs%md;
            bs=1ll*bs*bs%md; mi>>=1;
        }
        return ans;
    }
    int C(int n,int m){return 1ll*fac[n]*inv[m]%mdn*inv[n-m]%mdn;}
    int main()
    {
        int n=read(),k=read(); fac[0]=1;
        for(int i=1;i<=n;i++)    fac[i]=1ll*fac[i-1]*i%mdn;
        inv[n]=ksm(fac[n],mdn-2,mdn);
        for(int i=n;i;i--)    inv[i-1]=1ll*inv[i]*i%mdn;
        int ans=0;
        for(int i=k;i<=n;i++)
            ans=(ans+1ll*((i-k)&1?mdn-1:1)*C(n,i)%mdn*C(i,k)%mdn*(ksm(2,ksm(2,n-i,mdn-1),mdn)-1+mdn)%mdn)%mdn;
        printf("%d
    ",ans);
        return 0;
    }
    View Code

    已经没有什么好害怕的了

    还是类似定义$f(k)$为恰好有$k$个糖果比药片大$g(k)$为至少有$k$个糖果比药片大

    然后我们把药片和糖果一起排序 然后直接dp就好了

    //Love and Freedom.
    #include<algorithm>
    #include<cmath>
    #include<cstring>
    #include<cstdio>
    #define inf 20021225
    #define ll long long
    #define N 2010
    #define mdn 1000000009
    using namespace std;
    int read()
    {
        int f=1,s=0; char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();}
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return f*s;
    }
    int f[N][N],ida[N],idb[N],a[N],b[N],val[N<<1];
    int fac[N],inv[N],cnt[N],pre[N<<1];
    void upd(int &x,int y){x+=x+y>=mdn?y-mdn:y;}
    int ksm(int bs,int mi)
    {
        int ans=1;
        while(mi)
        {
            if(mi&1)    ans=1ll*ans*bs%mdn;
            bs=1ll*bs*bs%mdn; mi>>=1;
        }
        return ans;
    }
    int C(int n,int m){return n<m?0:1ll*fac[n]*inv[m]%mdn*inv[n-m]%mdn;}
    int main()
    {
        int n=read(),k=read(); fac[0]=1;
        if((n+k)&1)    return puts("0"),0; k=(n+k)>>1;
        for(int i=1;i<=n;i++)    a[i]=read(),val[i]=a[i];
        for(int i=1;i<=n;i++)    fac[i]=1ll*fac[i-1]*i%mdn;
        inv[n]=ksm(fac[n],mdn-2);
        for(int i=n;i;i--)    inv[i-1]=1ll*inv[i]*i%mdn;
        for(int i=1;i<=n;i++)    b[i]=read(),val[n+i]=b[i];
        //puts("GG");
        sort(val+1,val+n+n+1); int top=2*n;//unique(val+1,val+n+n+1)-val-1;
        for(int i=1;i<=n;i++)    ida[i]=lower_bound(val+1,val+top+1,a[i])-val,idb[i]=lower_bound(val+1,val+top+1,b[i])-val;
        for(int i=1;i<=n;i++)    pre[idb[i]]++;
        for(int i=1;i<=top;i++)    pre[i]+=pre[i-1];
        sort(ida+1,ida+n+1);
        for(int i=0;i<=n;i++)    f[i][0]=1; 
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
                upd(f[i][j],1ll*f[i-1][j-1]*max(pre[ida[i]]+1-j,0)%mdn),upd(f[i][j],f[i-1][j]);
        int ans=0;// puts("GG");
        for(int i=k;i<=n;i++)
            upd(ans,1ll*((i-k)&1?mdn-1:1)*f[n][i]%mdn*fac[n-i]%mdn*C(i,k)%mdn);
        printf("%d
    ",ans);
        return 0;
    }
    View Code
  • 相关阅读:
    mysql数据库安装与配置
    redis主从配置+sentinel哨兵模式
    Oracle 本地验证和密码文件
    Oracle 12c hub和leaf的转换
    oracle 12c CPU资源隔离
    oracle12 listagg 与 wm_concat行列转换
    Oracle 12c rac搭建
    ClassLoader.loadClass()与Class.forName()的区别《 转》
    docker 安装mysql8.0
    spring boot @EnableWebMvc禁用springMvc自动配置原理。
  • 原文地址:https://www.cnblogs.com/hanyuweining/p/11950267.html
Copyright © 2011-2022 走看看