zoukankan      html  css  js  c++  java
  • 2017 多校3 hdu 6061 RXD and functions

    2017 多校3 hdu 6061 RXD and functions(FFT)

    题意:

    给一个函数(f(x)=sum_{i=0}^{n}c_i cdot x^{i})
    (g(x) = f(x - sum a_i))后每一项(x^{i})的系数mod998244353
    (n <= 10^{5},m <= 10^{5})
    (0 <= c_i < 998244353)
    (0 <= a_i < 998244353)

    思路:

    (d = -sum a_i),把(g(x))展开得:

    [g(x) = c_0 (x + d)^{0} + c_1 (x + d)^{1} + ... + c_n (x+d)^{n} ]

    (a_i = d^{i}),再用二项式定理化简一下可以得到

    [g(x) = sum_{i = 0}^{n}x^{i}sum_{k=i}^{n}C(k,i)c_ka_{k-i} ]

    (fft)只是入了门,想了半天,看不出来这是个卷积式子,组合数会变化啊,赛后终于开窍组合数是个阶乘啊,把(c_k 和 a^{k-i}变换一下)
    令$$b_k = c_k cdot k!, a_i = frac{a_i}{i!}$$
    (g(x))就可以写成

    [g(x) = sum_{i = 0}^{n}x^{i} cdot frac{1}{i!}sum_{k=i}^{n}b_ka_{k-i} ]

    (ans(i) = frac{1}{i!} sum_{k=i}^{n}b_ka_{k-i})
    把b数组逆序一下
    (ans(i) = frac{1}{i!}sum_{k=0}^{n-i}b_{k}a_{n-k-i})

    类比fft多项式乘法下面(c_j)的形式 (sum_{k=0}^{n-i}b_{k}a_{n-k-i})这一项其实就是(fft)之后得到的数组(c_{n-i}),最后答案(ans(i) = frac{1}{i!} c_{n-i})
    (A(x) = sum_{i=0}^{n}a_ix^{i})
    (B(x) = sum_{i=0}^{n}b_ix^{i})
    (C(x) = A(x)B(x) = sum_{i=0}^{2n}c_ix^{i})
    (c_j = sum_{i=0}^{j}a_ib_{j-i})

    然后就上板子了,由于是在模意义下的运算,要拿ntt,去找了个板子
    不太会用啊,板子上的费马素数是P=(1LL<<55) * 5+1,原根g=6的,
    开始交了几发,TLE,原来是数组开小了,改完再交RE了,也不知道改了哪里就没RE了,然后WA了,暴力对拍数据,发现是费马素数的锅,乱试了其他的一些费马素数,又想了半天觉得这样不行,本来就是在mod下取的逆元,又在P下做运算,ntt原理也不懂,一脸懵逼,最后我直接把P改成mod试了一下,居然A了,好像给的这个mod本来是就是一个费马素数(1<<23) * 119 + 1,g = 3,而且运气好前面试的费马素数原根刚好是3。
    还有疑问就是运算时费马素数应该取多大呢,==再深入学习一下

    #include<bits/stdc++.h>
    #define LL long long
    using namespace std;
    const int N = 2e5 + 1000;
    const int mod = 998244353;
    const LL P =  mod;
    const LL G = 3;
    const int NUM = 23;
    int read(){
        int x = 0;
        char c;
        while((c=getchar())<'0'||c>'9');
        while(c>='0'&&c<='9')
            x=x*10+(c-'0'), c=getchar();
        return x;
    }
    int fac[N],facinv[N];
    int n, m;
    LL mul(LL x,LL y){
     //return (x * y - (LL)(x / (long double)P * y + 1e-3) * P + P) % P;
     return x * y % P;
    }
    LL q_pow(LL a,LL b){
      LL res = 1,tmp = a;
      while(b){
        if(b &1) res = res * tmp % P;
        tmp = tmp * tmp % P;
        b >>= 1;
      }
      return res;
    }
    void init(){
        fac[0] = facinv[0] = 1;
        for(int i = 1;i < N;i++){
            fac[i] = 1LL * i * fac[i-1] % mod;
            facinv[i] = 1LL * q_pow(i, mod - 2) * facinv[i - 1] % mod;
        }
    }
    LL  wn[NUM];
    LL  a[2 * N], b[2 * N],c[N];
    void GetWn()
    {
        for(int i = 0; i< NUM; i++)
        {
            int t = 1 << i;
            wn[i] = q_pow(G, (P - 1) / t);
        }
    }
    void Rader(LL a[], int len)
    {
        int j = len >> 1;
        for(int i=1; i<len-1; i++)
        {
            if(i < j) swap(a[i], a[j]);
            int k = len >> 1;
            while(j >= k)
            {
                j -= k;
                k >>= 1;
            }
            if(j < k) j += k;
        }
    }
    void NTT(LL a[], int len, int on)
    {
        Rader(a, len);
        int id = 0;
        for(int h = 2; h <= len; h <<= 1)
        {
            id++;
            for(int j = 0; j < len; j += h)
            {
                LL w = 1;
                for(int k = j; k < j + h / 2; k++)
                {
                    LL u = a[k];
                    LL t = mul(w,a[k + h / 2]);
                    a[k] = (u + t) % P;
                    a[k + h / 2] = ((u - t) % P + P) % P;
                    w = mul(w,wn[id]);
                }
            }
        }
        if(on == -1)
        {
            for(int i = 1; i < len / 2; i++)
                swap(a[i], a[len - i]);
            LL Inv = q_pow(len, P - 2);
            for(int i = 0; i < len; i++)
                a[i] = mul(a[i],Inv);
        }
    }
    void Conv(LL a[], LL b[], int n)
    {
        NTT(a, n, 1);
        NTT(b, n, 1);
        for(int i = 0; i < n; i++) a[i] = mul(a[i],b[i]);
        NTT(a, n, -1);
    }
    int main()
    {
        GetWn();
        init();
        while(scanf("%d",&n) == 1){
            for(int i = 0;i <= n;i++) c[i] = read();
            int  sum = 0;
            m = read();
            for(int i = 1;i <= m;i++){
                int x;
                x = read();
                sum = (sum - x + mod) % mod;
            }
            int len = 1;
            while(len < 2 * (n + 1)) len <<= 1;
            int res = 1;
            for(int i = 0;i <= n;i++) {
                a[i] = 1LL * res * facinv[i] % mod, res = 1LL * res * sum % mod;
                b[i] = c[n - i] * fac[n - i] % mod;
            }
            for(int i = n + 1;i < len;i++) a[i] = b[i] = 0;
            Conv(a,b,len);
            for(int i = 0;i <= n;i++) printf("%lld ",a[n - i] * facinv[i] % mod);
            printf("
    ");
        }
        return 0;
    }
    
    
  • 相关阅读:
    MySQL数据库常见面试题
    抽象类与接口
    HashMap与Hashtable的区别
    IDEA破解
    重写equals方法
    MFC编程入门之十七(对话框:文件对话框)
    MFC编程入门之十六(对话框:消息对话框)
    MFC编程入门之十五(对话框:一般属性页对话框的创建及显示)
    MFC编程入门之十三(对话框:属性页对话框及相关类的介绍)
    MFC编程入门之十二(对话框:非模态对话框的创建及显示)
  • 原文地址:https://www.cnblogs.com/jiachinzhao/p/7273500.html
Copyright © 2011-2022 走看看