zoukankan      html  css  js  c++  java
  • 【洛谷P5050】 【模板】多项式多点求值

    code: 

    #include <bits/stdc++.h>     
    #define ll long long 
    #define ull unsigned long long 
    #define setIO(s) freopen(s".in","r",stdin) // , freopen(s".out","w",stdout)        
    using namespace std;  
    char buf[100000],*p1,*p2;
    #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
    int rd() 
    {
        int x=0; char s=nc();
        while(s<'0') s=nc();
        while(s>='0') x=(((x<<2)+x)<<1)+s-'0',s=nc();
        return x;
    }         
    void print(int x) {if(x>=10) print(x/10);putchar(x%10+'0');}
    const int G=3;  
    const int N=2000005;   
    const int mod=998244353;                   
    int A[N],B[N],w[2][N],mem[N*100],*ptr=mem;      
    inline int qpow(int x,int y) 
    {
        int tmp=1;     
        for(;y;y>>=1,x=(ll)x*x%mod)     if(y&1) tmp=(ll)tmp*x%mod;  
        return tmp;    
    }      
    inline int INV(int a) { return qpow(a,mod-2); }        
    inline void ntt_init(int len) 
    {
        int i,j,k,mid,x,y;      
        w[1][0]=w[0][0]=1,x=qpow(3,(mod-1)/len),y=qpow(x,mod-2);
        for (i=1;i<len;++i) w[0][i]=(ll)w[0][i-1]*x%mod,w[1][i]=(ll)w[1][i-1]*y%mod;         
    }
    void NTT(int *a,int len,int flag) 
    {
        int i,j,k,mid,x,y;                
        for(i=k=0;i<len;++i) 
        {
            if(i>k)    swap(a[i],a[k]);  
            for(j=len>>1;(k^=j)<j;j>>=1);  
        }   
        for(mid=1;mid<len;mid<<=1)            
            for(i=0;i<len;i+=mid<<1) 
                for(j=0;j<mid;++j)          
                {
                    x=a[i+j], y=(ll)w[flag==-1][len/(mid<<1)*j]*a[i+j+mid]%mod;  
                    a[i+j]=(x+y)%mod;  
                    a[i+j+mid]=(x-y+mod)%mod;   
                }   
        if(flag==-1)  
        {
            int rev=INV(len);   
            for(i=0;i<len;++i)    a[i]=(ll)a[i]*rev%mod;   
        }
    }              
    inline void getinv(int *a,int *b,int len,int la) 
    {
        if(len==1) { b[0]=INV(a[0]);   return; }
        getinv(a,b,len>>1,la);    
        int l=len<<1,i;   
        memset(A,0,l*sizeof(A[0]));              
        memset(B,0,l*sizeof(A[0]));    
        memcpy(A,a,min(la,len)*sizeof(a[0]));                                                      
        memcpy(B,b,len*sizeof(b[0]));             
        ntt_init(l);   
        NTT(A,l,1),NTT(B,l,1);      
        for(i=0;i<l;++i)  A[i]=((ll)2-(ll)A[i]*B[i]%mod+mod)*B[i]%mod;
        NTT(A,l,-1);                                 
        memcpy(b,A,len<<2);          
    }  
    struct poly 
    {
        int len,*a;    
        poly(){}       
        poly(int l) {len=l,a=ptr,ptr+=l; }            
        inline void rev() { reverse(a,a+len); }       
        inline void fix(int l) {len=l,a=ptr,ptr+=l;}   
        inline void get_mod(int l) { for(int i=l;i<len;++i) a[i]=0;  len=l;  }
        inline poly dao() 
        {        
            poly re(len-1);   
            for(int i=1;i<len;++i)  re.a[i-1]=(ll)i*a[i]%mod;         
            return re;    
        }    
        inline poly Inv(int l) 
        {  
            poly b(l);              
            getinv(a,b.a,l,len);                                  
            return b;                        
        }                                                                    
        inline poly operator * (const poly &b) const 
        {
            poly c(len+b.len-1);   
            if(c.len<=500) 
            {         
                for(int i=0;i<len;++i)   
                    if(a[i])   for(int j=0;j<b.len;++j)  c.a[i+j]=(c.a[i+j]+(ll)(a[i])*b.a[j])%mod;      
                return c; 
            }
            int n=1;    
            while(n<(len+b.len)) n<<=1; 
            memset(A,0,n<<2);  
            memset(B,0,n<<2);   
            memcpy(A,a,len<<2);                             
            memcpy(B,b.a,b.len<<2);            
            ntt_init(n);        
            NTT(A,n,1), NTT(B,n,1);     
            for(int i=0;i<n;++i) A[i]=(ll)A[i]*B[i]%mod;   
            NTT(A,n,-1);   
            memcpy(c.a,A,c.len<<2);  
            return c;       
        }    
        poly operator + (const poly &b) const 
        {
            poly c(max(len,b.len));    
            for(int i=0;i<c.len;++i)  c.a[i]=((i<len?a[i]:0)+(i<b.len?b.a[i]:0))%mod;   
            return c;    
        }
        poly operator - (const poly &b) const 
        {    
            poly c(len);       
            for(int i=0;i<len;++i)   
            {
                if(i>=b.len)   c.a[i]=a[i];  
                else c.a[i]=(a[i]-b.a[i]+mod)%mod;    
            } 
            return c;  
        }
        poly operator / (poly u) 
        {  
            int n=len,m=u.len,l=1;  
            while(l<(n-m+1)) l<<=1;                           
            rev(),u.rev();            
            poly v=u.Inv(l);    
            v.get_mod(n-m+1);        
            poly re=(*this)*v;   
            rev(),u.rev();    
            re.get_mod(n-m+1);         
            re.rev();  
            return re;   
        }      
        poly operator % (poly u) 
        {      
            poly re=(*this)-u*(*this/u);        
            re.get_mod(u.len-1);       
            return re;    
        }                     
    }p[N<<2],pr;    
    int xx[N],yy[N];              
    #define lson now<<1  
    #define rson now<<1|1           
    inline void pushup(int l,int r,int now)
    {
        int mid=(l+r)>>1;      
        if(r>mid)   p[now]=p[lson]*p[rson]; 
        else p[now]=p[lson];   
    }
    void build(int l,int r,int now,int *pp) 
    {
        if(l==r) 
        {     
            p[now].fix(2);  
            p[now].a[0]=mod-pp[l];  
            p[now].a[1]=1;   
            return; 
        }  
        int mid=(l+r)>>1;   
        if(l<=mid)  build(l,mid,lson,pp);     
        if(r>mid)   build(mid+1,r,rson,pp);          
        p[now]=p[lson]*p[rson];   
    }    
    void get_val(int l,int r,int now,poly b,int *pp,int *t) 
    {
        if(b.len<=500)     
        {   
            for(int i=l;i<=r;++i) 
            {
                ull s=0;             
                for(int j=b.len-1;j>=0;--j)     
                {
                    s=((ull)s*pp[i]+b.a[j])%mod;  
                    if(!(j&7))   s%=mod;       
                }
                t[i]=s%mod;   
            }
            return;  
        } 
        int mid=(l+r)>>1;     
        if(l<=mid)   get_val(l,mid,lson,b%p[lson],pp,t);  
        if(r>mid)    get_val(mid+1,r,rson,b%p[rson],pp,t);     
    }   
    poly solve_polate(int l,int r,int now,int *t) 
    {
        if(l==r) 
        {
            poly re(1);   
            re.a[0]=t[l];   
            return re;   
        } 
        int mid=(l+r)>>1;    
        poly L,R;  
        L=solve_polate(l,mid,lson,t);   
        R=solve_polate(mid+1,r,rson,t);   
        return L*p[rson]+R*p[lson];           
    }       
    int main() 
    {   
        int i,j,n,m,l; 
        n=rd(),m=rd();              
        pr.fix(n+1);         
        static int pp[N];    
        for(i=0;i<=n;++i)   pr.a[i]=rd();   
        for(i=1;i<=m;++i)   pp[i]=rd();   
        build(1,m,1,pp);   
        get_val(1,m,1,pr,pp,pp);                                         
        for(i=1;i<=m;++i)   printf("%d
    ",pp[i]);                
        return 0;       
    }    
    

      

  • 相关阅读:
    Java线程专题 3:java内存模型
    Java线程专题 2:synchronized理解
    Java线程专题 1:线程创建
    设计模式七大原则
    JVM 运行时数据区
    css_selector定位,比xpath速度快,语法简洁
    xpath绝对定位和相对定位
    selenium多种定位
    操作浏览器基本元素(不定时更新)
    爬取网页图片并且下载(1)
  • 原文地址:https://www.cnblogs.com/guangheli/p/11928630.html
Copyright © 2011-2022 走看看