zoukankan      html  css  js  c++  java
  • 付公主的背包——生成函数

    题面

        洛谷P4389

    解析

      每一个物品可以看作一个多项式$f(x)=\sum_{i=0}^{\infty}[i\%V==0]\,x^i=\frac{1}{1-x^V}$,暴力把$n$个多项式乘起来是$O(NM\log M)$的复杂度,显然不能接受。

      于是看了下题解,学了一手奇技淫巧。

      设$h(x)=\ln(x)$, $g(x)=\ln(f(x))=h(f(x))$,有:$$\begin{align*}g'(x)&=h'(f(x))f'(x)\\ &=\frac{f'(x)}{f(x)}\\ &=\frac{\sum_{i=1}^{\infty}[i\%V==0]\,i \cdot x^{i-1}}{\frac{1}{1-x^V}}\\ &=(1-x^V)\sum_{i=1}^{\infty}Vi \cdot x^{Vi-1}\\ &= \sum_{i=1}^{\infty}Vi \cdot x^{Vi-1}-\sum_{i=1}^{\infty}Vi \cdot x^{V(i+1)-1}\\ &= \sum_{i=1}^{\infty}Vi \cdot x^{Vi-1}-\sum_{i=2}^{\infty}V(i-1) \cdot x^{Vi-1}\\ &= \sum_{i=1}^{\infty}Vi \cdot x^{Vi-1}-\sum_{i=2}^{\infty}Vi \cdot x^{Vi-1} + \sum_{i=2}^{\infty}V\cdot x^{Vi-1}\\ &= \sum_{i=1}^{\infty}V\cdot x^{Vi-1}\end{align*}$$

      两边逐项积分,并注意到$g(0)=\ln f(0)=\ln 1=0$(积分常数为$0$),则:$$g(x)=\sum_{i=1}^{\infty}\frac{1}{i}x^{Vi}$$

      而$f(x)=e^{g(x)}$,最终答案为:$$ans=\prod_{i=1}^n f_i(x)\\ ans= \prod_{i=1}^n e^{g_i(x)}\\ ans = e^{\sum_{i=1}^n g_i(x)}$$

      对每个不同体积求$g(x)$之和的时间是$O(M\ln M)$(体积为$V$的$g(x)$在$x^M$以内只有$\lfloor M/V\rfloor$个非零项,按调和级数求和即得),最后$\exp$的时间是$O(M\log M)$,因此总时间为$O(M\ln M+M\log M)$。

     代码:

    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    #include<cstring>
    using namespace std;
    // 64-bit type used for products of two residues (each < mod, so the product fits).
    typedef long long ll;
    // maxn: capacity of the per-volume / per-degree arrays (NOTE(review): presumably sized for n, m <= 1e5 — confirm).
    // mod = 998244353 = 119 * 2^23 + 1, an NTT-friendly prime; g = 3 is a primitive root of mod,
    // used by NTT() to derive the roots of unity.
    const int maxn = 200005, mod = 998244353, g = 3;
    
    // Reads one decimal integer (optionally preceded by '-') from stdin.
    // Non-digit characters before the number are skipped; the character that
    // terminates the digits is consumed. Returns 0 if input ends before any digit.
    // getchar() returns int so EOF can be distinguished from a real byte. The
    // previous char-typed version treated EOF as "truthy non-digit" and looped
    // forever at end of input.
    inline int read()
    {
        int ret = 0, f = 1;
        int c = getchar();
        while(c != EOF && (c < '0' || c > '9'))
        {
            if(c == '-')
                f = -1;
            c = getchar();
        }
        while(c >= '0' && c <= '9')
        {
            ret = ret * 10 + (c - '0');
            c = getchar();
        }
        return ret * f;
    }
    
    // Sum of two residues in [0, mod), reduced back into [0, mod).
    // The raw sum is below 2 * mod (< 2^31), so one conditional subtraction suffices.
    int add(int x, int y)
    {
        const int s = x + y;
        if(s >= mod)
            return s - mod;
        return s;
    }
    
    // Difference x - y of two residues in [0, mod), reduced back into [0, mod).
    int rdc(int x, int y)
    {
        const int d = x - y;
        return d >= 0 ? d : d + mod;
    }
    
    // Binary exponentiation: returns x^y modulo `mod` for y >= 0 (x^0 == 1).
    // The base is reduced into [0, mod) first. Without this, x * x overflows
    // 64-bit ll for any base >= ~3.04e9, and a negative base gives negative results.
    // Results for bases already in [0, mod) are unchanged.
    ll qpow(ll x, int y)
    {
        x %= mod;
        if(x < 0)
            x += mod;
        ll ret = 1;
        while(y > 0)
        {
            if(y & 1)
                ret = ret * x % mod;
            x = x * x % mod;
            y >>= 1;
        }
        return ret;
    }
    
    // Global state shared by the routines below:
    //   n, m     : item count and maximum volume, read in main().
    //   num[v]   : number of items with volume v (filled in main()).
    //   lim, bit : current transform length (a power of two) and its log2, set by NTT_init().
    //   rev      : bit-reversal permutation of [0, lim), set by NTT_init().
    int n, m, num[maxn], lim, bit, rev[maxn<<1];
    //   ginv     : modular inverse of the root g, used by the inverse NTT.
    //   ln, iv, c: scratch buffers for get_exp / get_ln / get_inv; the routines zero them again after use.
    //   inv[i]   : modular inverse of i for 1 <= i <= m, filled by init().
    //   G, F     : G accumulates the per-item log series in main(); F = exp(G) is printed.
    ll ginv, ln[maxn<<1], iv[maxn<<1], inv[maxn], F[maxn<<1], G[maxn<<1], c[maxn<<1];
    
    void init()
    {
        ginv = qpow(g, mod - 2);
        inv[0] = inv[1] = 1;
        for(int i = 2; i <= m; ++i)
            inv[i] = (mod - mod / i) * inv[mod%i] % mod;
    }
    
    // Prepares a transform of length lim = 2^bit, the smallest power of two
    // strictly greater than x, and fills rev[1, lim) with the bit-reversal
    // permutation for that length (rev[0] is left as 0).
    void NTT_init(int x)
    {
        for(lim = 1, bit = 0; lim <= x; lim <<= 1)
            ++bit;
        for(int idx = 1; idx < lim; ++idx)
        {
            const int low = idx & 1;
            rev[idx] = (rev[idx >> 1] >> 1) | (low << (bit - 1));
        }
    }
    
    // In-place iterative number-theoretic transform of x[0, lim) modulo `mod`.
    //   y == 1 : forward transform, roots derived from g.
    //   y == -1: inverse transform, roots derived from ginv, then every entry is
    //            multiplied by lim^{-1}.
    // (Any other y uses ginv without the final scaling.)
    // Requires NTT_init() to have set lim and rev[] for the current length;
    // entries at index >= lim are not touched.
    void NTT(ll *x, int y)
    {
        // Reorder into bit-reversed index order so the butterflies can run in place.
        for(int i = 1; i < lim; ++i)
            if(i < rev[i])
                swap(x[i], x[rev[i]]);
        ll wn, w, u, v;
        // i is half the current butterfly span; spans double each pass.
        for(int i = 1; i < lim; i <<= 1)
        {
            // wn is a primitive (2i)-th root of unity modulo `mod`.
            wn = qpow((y == 1)? g: ginv, (mod - 1) / (i << 1));
            for(int j = 0; j < lim; j += (i << 1))
            {
                w = 1;
                for(int k = 0; k < i; ++k)
                {
                    // Butterfly: (u, v*w) -> (u + v*w, u - v*w).
                    u = x[j+k];
                    v = x[j+k+i] * w % mod;
                    x[j+k] = add(u, v);
                    x[j+k+i] = rdc(u, v);
                    w = w * wn % mod;
                }
            }
        }
        if(y == -1)
        {
            // Inverse transform: divide by the length.
            ll linv = qpow(lim, mod - 2);
            for(int i = 0; i < lim; ++i)
                x[i] = x[i] * linv % mod;
        }
    }
    
    // Power-series inverse by Newton iteration. On return, x[0, len) holds
    // y^{-1} mod x^len and x[len, lim) is zeroed.
    // Requires y[0] != 0, since its modular inverse seeds the recursion.
    // Side effects: overwrites lim/bit/rev via NTT_init() and uses the global
    // buffer c as scratch, re-zeroing it over [0, lim) before returning.
    // NOTE(review): relies on x being zero beyond the prefix produced by the
    // recursive call. Current callers pass zeroed global buffers; confirm this
    // for any new caller.
    void get_inv(ll *x, ll *y, int len)
    {
        if(len == 1)
        {
            x[0] = qpow(y[0], mod - 2);
            return ;
        }
        // After this call x holds the inverse modulo x^ceil(len/2).
        get_inv(x, y, (len + 1) >> 1);
        for(int i = 0; i < len; ++i)
            c[i] = y[i];
        // lim > 2 * len, so the cyclic products below equal the linear ones (no wrap-around).
        NTT_init(len << 1);
        NTT(x, 1);
        NTT(c, 1);
        for(int i = 0; i < lim; ++i)
        {
            // Newton step x <- 2x - y * x^2, evaluated pointwise in transform space.
            x[i] = rdc(add(x[i], x[i]), (c[i] * x[i] % mod) * x[i] % mod);
            c[i] = 0;
        }
        NTT(x, -1);
        // Truncate the result to degree < len.
        for(int i = len; i < lim; ++i)
            x[i] = 0;
    }
    
    // Power-series logarithm: x[0, len) = ln(y) mod x^len, computed as the
    // integral of y' * y^{-1}. The constant term is forced to 0, so the result
    // is only correct when y[0] == 1.
    // Side effects: overwrites lim/bit/rev. Uses the global buffer iv (re-zeroed
    // over [0, lim) before returning) and, through get_inv(), the buffer c.
    // NOTE(review): assumes x[len - 1, lim) is zero on entry, because the
    // derivative loop only writes x[0, len - 1). The caller here passes the
    // zeroed global ln buffer.
    void get_ln(ll *x, ll *y, int len)
    {
        // x <- y': coefficient i of the derivative is (i + 1) * y[i + 1].
        for(int i = 0; i < len - 1; ++i)
            x[i] = y[i+1] * (i + 1) % mod;
        // iv <- y^{-1} mod x^len.
        get_inv(iv, y, len);
        // Both operands have < len terms and lim > 2 * len, so no wrap-around.
        NTT_init(len << 1);
        NTT(x, 1);
        NTT(iv, 1);
        for(int i = 0; i < lim; ++i)
        {
            x[i] = x[i] * iv[i] % mod;
            iv[i] = 0;
        }
        NTT(x, -1);
        // Integrate: shift every coefficient up by one and divide by its new exponent.
        // Walking i downward means x[i-1] is still the derivative coefficient when read.
        // NOTE(review): inv[i] from init() could replace qpow here when len - 1 <= m.
        for(int i = len - 1; i >= 1; --i)
            x[i] = x[i-1] * qpow(i, mod - 2) % mod;
        x[0] = 0;
        // Truncate the result to degree < len.
        for(int i = len; i < lim; ++i)
            x[i] = 0;
    }
    
    // Power-series exponential by Newton iteration: x[0, len) = exp(y) mod x^len,
    // using x <- x * (1 - ln(x) + y) to go from precision ceil(len/2) to len.
    // The base case x[0] = 1 is exp(0), so y[0] must be 0.
    // Side effects: overwrites lim/bit/rev. Uses the global buffers ln and c
    // (and, through get_ln / get_inv, iv) as scratch and zeroes them again.
    // NOTE(review): like get_inv, relies on x being zero beyond the prefix written
    // by the recursive call. main() passes the zero-initialized global F.
    void get_exp(ll *x, ll *y, int len)
    {
        if(len == 1)
        {
            x[0] = 1;
            return ;
        }
        // After this call x holds exp(y) modulo x^ceil(len/2).
        get_exp(x, y, (len + 1) >> 1);
        get_ln(ln, x, len);
        // c <- 1 + y - ln(x)  (the constant 1 only contributes at i == 0).
        for(int i = 0; i < len; ++i)
        {
            c[i] = rdc(add(i == 0, y[i]), ln[i]);
            ln[i] = 0;
        }
        // lim > 2 * len, so the cyclic product equals the linear one.
        NTT_init(len << 1);
        NTT(x, 1);
        NTT(c, 1);
        for(int i = 0; i < lim; ++i)
        {
            x[i] = x[i] * c[i] % mod;
            c[i] = 0;
        }
        NTT(x, -1);
        // Truncate the result to degree < len.
        for(int i = len; i < lim; ++i)
            x[i] = 0;
    }
    
    // Reads n item volumes and m, then prints, for every total volume 1..m,
    // the coefficient [x^i] of prod_v (1 - x^v)^{-num[v]} modulo `mod`.
    int main()
    {
        n = read(); m = read();
        init();
        // num[v] = number of items with volume v. A volume outside [1, m] cannot
        // contribute to any total in 1..m. Skipping it also keeps the index
        // inside num[] (sized maxn).
        for(int i = 1; i <= n; ++i)
        {
            const int v = read();
            if(v >= 1 && v <= m)
                ++ num[v];
        }
        // G = sum over volumes v of num[v] * ln(1 / (1 - x^v))
        //   = sum over v of num[v] * sum_{j >= 1} x^{v j} / j,  truncated at degree m.
        // Only floor(m / v) terms per distinct v, i.e. O(m log m) in total.
        for(int v = 1; v <= m; ++v)
            if(num[v])
            {
                for(int j = 1; v * j <= m; ++j)
                    G[v*j] = add(G[v*j], num[v] * inv[j] % mod);
            }
        // F = exp(G) = prod_v (1 - x^v)^{-num[v]}  (mod x^{m+1}).
        get_exp(F, G, m + 1);
        // The format string was previously split across two source lines
        // (an unterminated string literal), which did not compile.
        for(int i = 1; i <= m; ++i)
            printf("%lld\n", F[i]);
        return 0;
    }
    View Code
  • 相关阅读:
    Python基础
    熟悉常见的Linux操作
    大数据概述
    实验报告(3)-语法分析
    LL(1)文法
    简化版C语言文法
    实验报告(1)-词法分析
    中文词频统计
    综合练习:英文词频统计
    字符串练习
  • 原文地址:https://www.cnblogs.com/Joker-Yza/p/12623480.html
Copyright © 2011-2022 走看看