zoukankan      html  css  js  c++  java
  • HDU 6061 RXD and functions NTT

    RXD and functions

    Problem Description
    RXD has a polynomial function $f(x) = \sum_{i=0}^{n} c_i x^i$.
    RXD has a transformation of functions $Tr(f, a)$: it returns another function $g$ with the property $g(x) = f(x - a)$.
    Given $a_1, a_2, a_3, \dots, a_m$, RXD generates a sequence of polynomial functions $g_i$, in which $g_0 = f$ and $g_i = Tr(g_{i-1}, a_i)$.
    RXD wants you to find $g_m$ in the form $\sum_{i=0}^{n} b_i x^i$.
    You need to output each $b_i$ modulo 998244353.
    $n \le 10^5$
     
    Input
    There are several test cases, please keep reading until EOF.
    For each test case, the first line contains one integer $n$, the degree of $f$.
    The next line contains $n+1$ integers $c_0, c_1, \dots, c_n$ ($0 \le c_i < 998244353$), the coefficients of the polynomial, constant term first.
    The next line contains an integer $m$, the length of the sequence $a$.
    The next line contains $m$ integers; the $i$-th integer is $a_i$.
    There are 11 test cases.
    $0 \le a_i < 998244353$
    $m \le 10^5$
     
    Output
    For each test case, output the answer polynomial of degree $n$: print its coefficients $b_0, b_1, \dots, b_n$ on one line, separated by spaces.
     
    Sample Input
    2
    0 0 1
    1
    1
     
    Sample Output
    1 998244351 1
    Hint
    $(x - 1)^2 = x^2 - 2x + 1$, and $-2 \equiv 998244351 \pmod{998244353}$.
     

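    题解 (Solution): each $Tr(\cdot, a_i)$ substitutes $x - a_i$, so $g_m(x) = f(x - S)$ with $S = \sum_{i=1}^{m} a_i$. By the binomial theorem, $b_i = \frac{1}{i!} \sum_{j=i}^{n} c_j \, j! \, \frac{(-S)^{j-i}}{(j-i)!}$. This is a convolution of the sequences $c_j \, j!$ and $(-S)^k / k!$, so it can be computed with an NTT modulo 998244353 in $O(n \log n)$.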

      

    代码:

      

    #include<bits/stdc++.h>
    using namespace std;
    #pragma comment(linker, "/STACK:102400000,102400000")
    // Presumably segment-tree helpers (child indices, midpoint of [ll, rr]) and pair
    // shorthands from a contest template; none of these macros is used in this file.
    #define ls i<<1
    #define rs ls | 1
    #define mid ((ll+rr)>>1)
    #define pii pair<int,int>
    #define MP make_pair
    // LL is the working integer type for all modular arithmetic below.
    typedef long long LL;
    typedef unsigned long long ULL;
    // Template constants; of these only N (capacity of the global arrays) is used below.
    const long long INF = 1e18+1LL;
    const double pi = acos(-1.0);
    const int N = 5e5+10, M = 1e3+20,inf = 2e9;
    
    // NTT-friendly prime 998244353 = 119 * 2^23 + 1 (P and mod are the same value),
    // and G = 3, the primitive root used to build the roots of unity.
    const long long P=998244353LL,mod = 998244353LL;
    const LL G=3LL;
    
    // Returns (x * y) mod P as a value in [0, P).
    // Both operands are first reduced into [0, P). Because P < 2^30, the product of two
    // reduced values is below 2^60 and fits exactly in a signed 64-bit LL. Plain integer
    // arithmetic is therefore exact, and no floating-point quotient estimate is needed.
    LL mul(LL x,LL y){
        x %= P;
        if (x < 0) x += P;
        y %= P;
        if (y < 0) y += P;
        return x * y % P;
    }
    // Binary exponentiation: returns x^k mod P for k >= 0 (returns 1 when k == 0).
    // Walks the bits of k from least to most significant, squaring the base each step.
    LL qpow(LL x,LL k){
        LL result = 1;
        LL base = x;
        for (; k != 0; k >>= 1) {
            if (k & 1)
                result = mul(result, base);
            base = mul(base, base);
        }
        return result;
    }
    // wn[i] = G^((P-1) / 2^i) mod P: a primitive 2^i-th root of unity, since G = 3 is a
    // primitive root of P. NTT() uses wn[log2(h)] for butterfly span h.
    LL wn[50];
    void getwn(){
        // P - 1 = 119 * 2^23, so 2^i divides P - 1 only for i <= 23. Larger i have no such
        // root. Computing 1 << i on int for i >= 32 would also be undefined behaviour.
        const int kMaxLog = 23;
        for (int i = 1; i <= kMaxLog; ++i)
            wn[i] = qpow(G, (P - 1) >> i);
    }
    
    // Transform length used by NTT(); must be a power of two. solve() sets it before transforming.
    int len;
    // In-place iterative number-theoretic transform of y[0..len-1] modulo P.
    //   op ==  1 : forward transform.
    //   op == -1 : inverse transform (includes the division by len).
    // Any other op value performs only the forward transform.
    // Requires getwn() to have filled wn[]; assumes every y[i] is already in [0, P).
    void NTT(LL y[],int op){
        // Permute y into bit-reversed index order. j is maintained incrementally as the
        // bit-reversal of i (adding 1 to j from its most significant bit downward).
        for(int i=1,j=len>>1,k; i<len-1; ++i){
            if(i<j) swap(y[i],y[j]);
            k=len>>1;
            while(j>=k){
                j-=k;
                k>>=1;
            }
            if(j<k) j+=k;
        }
        // Butterfly passes for span h = 2, 4, ..., len; id = log2(h), so wn[id] is a
        // primitive h-th root of unity.
        int id=0;
        for(int h=2; h<=len; h<<=1) {
            ++id;
            for(int i=0; i<len; i+=h){
                LL w=1;
                for(int j=i; j<i+(h>>1); ++j){
                    LL u=y[j],t=mul(y[j+h/2],w);
                    // u + t and u - t, each brought back into [0, P) by one conditional subtraction.
                    y[j]=u+t;
                    if(y[j]>=P) y[j]-=P;
                    y[j+h/2]=u-t+P;
                    if(y[j+h/2]>=P) y[j+h/2]-=P;
                    w=mul(w,wn[id]);
                }
            }
        }
        if(op==-1){
            // Inverse from the forward result: reversing y[1..len-1] corresponds to using the
            // inverse root; then scale every entry by len^{-1} mod P (Fermat inverse).
            for(int i=1; i<len/2; ++i) swap(y[i],y[len-i]);
            LL inv=qpow(len,P-2);
            for(int i=0; i<len; ++i) y[i]=mul(y[i],inv);
        }
    }
    LL c[N],fac[N],ans[N],inv[N],id[N],s[N],t[N];
    int n;
    // Fills ans[0..n] with the coefficients of f(x - mo) mod P, where f has coefficients
    // c[0..n] (constant term first).
    // Expects fac[i] = i! and inv[i] = (i!)^{-1} mod P for i <= n, and 0 <= mo <= mod.
    //
    // Binomial (Taylor) shift:
    //   ans[i] = (1/i!) * sum_j (c[j] * j!) * (-mo)^(j-i) / (j-i)!
    // The second factor is stored reversed (t[n-k] = (-mo)^k / k!). The sum then becomes an
    // ordinary convolution whose (n+i)-th entry equals ans[i] * i!.
    // Uses s, t and id as scratch buffers and sets the global transform length len.
    void solve(LL mo) {
        // A zero shift leaves the polynomial unchanged.
        if (mo == 0) {
            std::copy(c, c + n + 1, ans);
            return;
        }
        const LL shift = (mod - mo) % mod;  // -mo reduced into [0, mod)

        // Smallest power of two strictly greater than 2n + 5 (room for the full product).
        for (len = 1; len <= 2 * n + 5; len <<= 1) {
        }

        // id[k] = shift^k mod P.
        id[0] = 1;
        for (int k = 1; k <= n; ++k)
            id[k] = id[k - 1] * shift % mod;

        std::fill(s, s + len, 0LL);
        std::fill(t, t + len, 0LL);
        for (int j = 0; j <= n; ++j) {
            s[j] = c[j] * fac[j] % mod;
            t[n - j] = id[j] * inv[j] % mod;
        }

        NTT(s, 1);
        NTT(t, 1);
        for (int k = 0; k < len; ++k)
            s[k] = s[k] * t[k] % mod;
        NTT(s, -1);

        // Undo the i! factor to obtain the shifted coefficients.
        for (int i = 0; i <= n; ++i)
            ans[i] = s[n + i] * inv[i] % mod;
    }
    int m;
    // Entry point: processes test cases until input is exhausted.
    // Each case supplies:
    //   - n;
    //   - the coefficients c[0..n], constant term first;
    //   - m, followed by a_1..a_m.
    // Every Tr(g, a) substitutes x - a, so g_m(x) = f(x - S) with S = a_1 + ... + a_m.
    // Only S mod P matters, and solve() performs the shift. Prints b_0..b_n on one line.
    int main() {
        getwn();
        // scanf returns the number of items converted. Testing "!= EOF" would loop forever
        // on a malformed token (return value 0), so require a successful read.
        while (scanf("%d", &n) == 1) {
            for (int i = 0; i <= n; ++i)
                scanf("%lld", &c[i]);
            // fac[i] = i! and inv[i] = (i!)^{-1} mod P: a Fermat inverse of n!, then a
            // downward recurrence inv[i] = inv[i+1] * (i+1).
            fac[0] = 1;
            for (int i = 1; i <= n; ++i)
                fac[i] = fac[i - 1] * i % mod;
            inv[n] = qpow(fac[n], mod - 2);
            for (int i = n - 1; i >= 0; --i)
                inv[i] = inv[i + 1] * (i + 1) % mod;
            if (scanf("%d", &m) != 1)
                break;
            // Accumulate in 64 bits so the running sum cannot overflow.
            long long sum = 0;
            for (int i = 1; i <= m; ++i) {
                long long x = 0;
                scanf("%lld", &x);
                sum = (sum + x) % mod;
            }
            solve(sum);
            // Space-separated coefficients, newline-terminated, no trailing space.
            for (int i = 0; i < n; ++i)
                printf("%lld ", ans[i]);
            printf("%lld\n", ans[n]);
        }
        return 0;
    }
  • 相关阅读:
    Python中replace 不起作用的问题
    java 获取视频时长、大小
    MySQL 自定义排序
    加 synchronized 关键字进行同步
    SQL 查询当前周的开始、结束日期
    Java 按照一定的规则生成递增的编号
    Java中BigDecimal的8种舍入模式
    Lamada 表达式之 sort 排序
    搭建Java环境
    初识JAVA(学习记录)
  • 原文地址:https://www.cnblogs.com/zxhl/p/7273363.html
Copyright © 2011-2022 走看看