zoukankan      html  css  js  c++  java
  • 【洛谷P5050】 【模板】多项式多点求值

    code: 

    #include <bits/stdc++.h>     
    #define ll long long 
    #define ull unsigned long long 
    #define setIO(s) freopen(s".in","r",stdin) // , freopen(s".out","w",stdout)        
    using namespace std;  
    char buf[100000],*p1,*p2;
    #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
    int rd() 
    {
        int x=0; char s=nc();
        while(s<'0') s=nc();
        while(s>='0') x=(((x<<2)+x)<<1)+s-'0',s=nc();
        return x;
    }         
    void print(int x) {if(x>=10) print(x/10);putchar(x%10+'0');}
    const int G=3;  
    const int N=2000005;   
    const int mod=998244353;                   
    int A[N],B[N],w[2][N],mem[N*100],*ptr=mem;      
    inline int qpow(int x,int y) 
    {
        int tmp=1;     
        for(;y;y>>=1,x=(ll)x*x%mod)     if(y&1) tmp=(ll)tmp*x%mod;  
        return tmp;    
    }      
    inline int INV(int a) { return qpow(a,mod-2); }        
    inline void ntt_init(int len) 
    {
        int i,j,k,mid,x,y;      
        w[1][0]=w[0][0]=1,x=qpow(3,(mod-1)/len),y=qpow(x,mod-2);
        for (i=1;i<len;++i) w[0][i]=(ll)w[0][i-1]*x%mod,w[1][i]=(ll)w[1][i-1]*y%mod;         
    }
    void NTT(int *a,int len,int flag) 
    {
        int i,j,k,mid,x,y;                
        for(i=k=0;i<len;++i) 
        {
            if(i>k)    swap(a[i],a[k]);  
            for(j=len>>1;(k^=j)<j;j>>=1);  
        }   
        for(mid=1;mid<len;mid<<=1)            
            for(i=0;i<len;i+=mid<<1) 
                for(j=0;j<mid;++j)          
                {
                    x=a[i+j], y=(ll)w[flag==-1][len/(mid<<1)*j]*a[i+j+mid]%mod;  
                    a[i+j]=(x+y)%mod;  
                    a[i+j+mid]=(x-y+mod)%mod;   
                }   
        if(flag==-1)  
        {
            int rev=INV(len);   
            for(i=0;i<len;++i)    a[i]=(ll)a[i]*rev%mod;   
        }
    }              
    inline void getinv(int *a,int *b,int len,int la) 
    {
        if(len==1) { b[0]=INV(a[0]);   return; }
        getinv(a,b,len>>1,la);    
        int l=len<<1,i;   
        memset(A,0,l*sizeof(A[0]));              
        memset(B,0,l*sizeof(A[0]));    
        memcpy(A,a,min(la,len)*sizeof(a[0]));                                                      
        memcpy(B,b,len*sizeof(b[0]));             
        ntt_init(l);   
        NTT(A,l,1),NTT(B,l,1);      
        for(i=0;i<l;++i)  A[i]=((ll)2-(ll)A[i]*B[i]%mod+mod)*B[i]%mod;
        NTT(A,l,-1);                                 
        memcpy(b,A,len<<2);          
    }  
    struct poly 
    {
        int len,*a;    
        poly(){}       
        poly(int l) {len=l,a=ptr,ptr+=l; }            
        inline void rev() { reverse(a,a+len); }       
        inline void fix(int l) {len=l,a=ptr,ptr+=l;}   
        inline void get_mod(int l) { for(int i=l;i<len;++i) a[i]=0;  len=l;  }
        inline poly dao() 
        {        
            poly re(len-1);   
            for(int i=1;i<len;++i)  re.a[i-1]=(ll)i*a[i]%mod;         
            return re;    
        }    
        inline poly Inv(int l) 
        {  
            poly b(l);              
            getinv(a,b.a,l,len);                                  
            return b;                        
        }                                                                    
        inline poly operator * (const poly &b) const 
        {
            poly c(len+b.len-1);   
            if(c.len<=500) 
            {         
                for(int i=0;i<len;++i)   
                    if(a[i])   for(int j=0;j<b.len;++j)  c.a[i+j]=(c.a[i+j]+(ll)(a[i])*b.a[j])%mod;      
                return c; 
            }
            int n=1;    
            while(n<(len+b.len)) n<<=1; 
            memset(A,0,n<<2);  
            memset(B,0,n<<2);   
            memcpy(A,a,len<<2);                             
            memcpy(B,b.a,b.len<<2);            
            ntt_init(n);        
            NTT(A,n,1), NTT(B,n,1);     
            for(int i=0;i<n;++i) A[i]=(ll)A[i]*B[i]%mod;   
            NTT(A,n,-1);   
            memcpy(c.a,A,c.len<<2);  
            return c;       
        }    
        poly operator + (const poly &b) const 
        {
            poly c(max(len,b.len));    
            for(int i=0;i<c.len;++i)  c.a[i]=((i<len?a[i]:0)+(i<b.len?b.a[i]:0))%mod;   
            return c;    
        }
        poly operator - (const poly &b) const 
        {    
            poly c(len);       
            for(int i=0;i<len;++i)   
            {
                if(i>=b.len)   c.a[i]=a[i];  
                else c.a[i]=(a[i]-b.a[i]+mod)%mod;    
            } 
            return c;  
        }
        poly operator / (poly u) 
        {  
            int n=len,m=u.len,l=1;  
            while(l<(n-m+1)) l<<=1;                           
            rev(),u.rev();            
            poly v=u.Inv(l);    
            v.get_mod(n-m+1);        
            poly re=(*this)*v;   
            rev(),u.rev();    
            re.get_mod(n-m+1);         
            re.rev();  
            return re;   
        }      
        poly operator % (poly u) 
        {      
            poly re=(*this)-u*(*this/u);        
            re.get_mod(u.len-1);       
            return re;    
        }                     
    }p[N<<2],pr;    
    int xx[N],yy[N];              
    #define lson now<<1  
    #define rson now<<1|1           
    inline void pushup(int l,int r,int now)
    {
        int mid=(l+r)>>1;      
        if(r>mid)   p[now]=p[lson]*p[rson]; 
        else p[now]=p[lson];   
    }
    void build(int l,int r,int now,int *pp) 
    {
        if(l==r) 
        {     
            p[now].fix(2);  
            p[now].a[0]=mod-pp[l];  
            p[now].a[1]=1;   
            return; 
        }  
        int mid=(l+r)>>1;   
        if(l<=mid)  build(l,mid,lson,pp);     
        if(r>mid)   build(mid+1,r,rson,pp);          
        p[now]=p[lson]*p[rson];   
    }    
    void get_val(int l,int r,int now,poly b,int *pp,int *t) 
    {
        if(b.len<=500)     
        {   
            for(int i=l;i<=r;++i) 
            {
                ull s=0;             
                for(int j=b.len-1;j>=0;--j)     
                {
                    s=((ull)s*pp[i]+b.a[j])%mod;  
                    if(!(j&7))   s%=mod;       
                }
                t[i]=s%mod;   
            }
            return;  
        } 
        int mid=(l+r)>>1;     
        if(l<=mid)   get_val(l,mid,lson,b%p[lson],pp,t);  
        if(r>mid)    get_val(mid+1,r,rson,b%p[rson],pp,t);     
    }   
    poly solve_polate(int l,int r,int now,int *t) 
    {
        if(l==r) 
        {
            poly re(1);   
            re.a[0]=t[l];   
            return re;   
        } 
        int mid=(l+r)>>1;    
        poly L,R;  
        L=solve_polate(l,mid,lson,t);   
        R=solve_polate(mid+1,r,rson,t);   
        return L*p[rson]+R*p[lson];           
    }       
    int main() 
    {   
        int i,j,n,m,l; 
        n=rd(),m=rd();              
        pr.fix(n+1);         
        static int pp[N];    
        for(i=0;i<=n;++i)   pr.a[i]=rd();   
        for(i=1;i<=m;++i)   pp[i]=rd();   
        build(1,m,1,pp);   
        get_val(1,m,1,pr,pp,pp);                                         
        for(i=1;i<=m;++i)   printf("%d
    ",pp[i]);                
        return 0;       
    }    
    

      

  • 相关阅读:
    BZOJ 1191 HNOI2006 超级英雄hero
    BZOJ 2442 Usaco2011 Open 修建草坪
    BZOJ 1812 IOI 2005 riv
    OJ 1159 holiday
    BZOJ 1491 NOI 2007 社交网络
    NOIP2014 D1 T3
    BZOJ 2423 HAOI 2010 最长公共子序列
    LCA模板
    NOIP 2015 D1T2信息传递
    数据结构
  • 原文地址:https://www.cnblogs.com/guangheli/p/11928630.html
Copyright © 2011-2022 走看看