zoukankan      html  css  js  c++  java
  • 【洛谷P5158】 【模板】多项式快速插值

    卡常严重,可有采用如下优化方案: 

    1.预处理单位根 

    2.少取几次模 

    3.复制数组时用 memcpy     

    4.进行多项式乘法项数少的时候直接暴力乘 

    5.进行多项式多点求值时如果项数小于500的话直接秦九昭展开

    code: 

    #include <bits/stdc++.h>     
    #define ll long long 
    #define ull unsigned long long 
    #define setIO(s) freopen(s".in","r",stdin) // , freopen(s".out","w",stdout)        
    using namespace std;  
    char buf[100000],*p1,*p2;
    #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
    int rd() 
    {
        int x=0; char s=nc();
        while(s<'0') s=nc();
        while(s>='0') x=(((x<<2)+x)<<1)+s-'0',s=nc();
        return x;
    }         
    void print(int x) {if(x>=10) print(x/10);putchar(x%10+'0');}
    const int G=3;  
    const int N=2000005;   
    const int mod=998244353;                   
    int A[N],B[N],w[2][N],mem[N*100],*ptr=mem;      
    inline int qpow(int x,int y) 
    {
        int tmp=1;     
        for(;y;y>>=1,x=(ll)x*x%mod)     if(y&1) tmp=(ll)tmp*x%mod;  
        return tmp;    
    }      
    inline int INV(int a) { return qpow(a,mod-2); }        
    inline void ntt_init(int len) 
    {
        int i,j,k,mid,x,y;      
        w[1][0]=w[0][0]=1,x=qpow(3,(mod-1)/len),y=qpow(x,mod-2);
        for (i=1;i<len;++i) w[0][i]=(ll)w[0][i-1]*x%mod,w[1][i]=(ll)w[1][i-1]*y%mod;         
    }
    void NTT(int *a,int len,int flag) 
    {
        int i,j,k,mid,x,y;                
        for(i=k=0;i<len;++i) 
        {
            if(i>k)    swap(a[i],a[k]);  
            for(j=len>>1;(k^=j)<j;j>>=1);  
        }   
        for(mid=1;mid<len;mid<<=1)            
            for(i=0;i<len;i+=mid<<1) 
                for(j=0;j<mid;++j)          
                {
                    x=a[i+j], y=(ll)w[flag==-1][len/(mid<<1)*j]*a[i+j+mid]%mod;  
                    a[i+j]=(x+y)%mod;  
                    a[i+j+mid]=(x-y+mod)%mod;   
                }   
        if(flag==-1)  
        {
            int rev=INV(len);   
            for(i=0;i<len;++i)    a[i]=(ll)a[i]*rev%mod;   
        }
    }              
    inline void getinv(int *a,int *b,int len,int la) 
    {
        if(len==1) { b[0]=INV(a[0]);   return; }
        getinv(a,b,len>>1,la);    
        int l=len<<1,i;   
        memset(A,0,l*sizeof(A[0]));              
        memset(B,0,l*sizeof(A[0]));    
        memcpy(A,a,min(la,len)*sizeof(a[0]));                                                      
        memcpy(B,b,len*sizeof(b[0]));             
        ntt_init(l);   
        NTT(A,l,1),NTT(B,l,1);      
        for(i=0;i<l;++i)  A[i]=((ll)2-(ll)A[i]*B[i]%mod+mod)*B[i]%mod;
        NTT(A,l,-1);                                 
        memcpy(b,A,len<<2);          
    }  
    struct poly 
    {
        int len,*a;    
        poly(){}       
        poly(int l) {len=l,a=ptr,ptr+=l; }            
        inline void rev() { reverse(a,a+len); }       
        inline void fix(int l) {len=l,a=ptr,ptr+=l;}   
        inline void get_mod(int l) { for(int i=l;i<len;++i) a[i]=0;  len=l;  }
        inline poly dao() 
        {        
            poly re(len-1);   
            for(int i=1;i<len;++i)  re.a[i-1]=(ll)i*a[i]%mod;         
            return re;    
        }    
        inline poly Inv(int l) 
        {  
            poly b(l);              
            getinv(a,b.a,l,len);                                  
            return b;                        
        }                                                                    
        inline poly operator * (const poly &b) const 
        {
            poly c(len+b.len-1);   
            if(c.len<=500) 
            {         
                for(int i=0;i<len;++i)   
                    if(a[i])   for(int j=0;j<b.len;++j)  c.a[i+j]=(c.a[i+j]+(ll)(a[i])*b.a[j])%mod;      
                return c; 
            }
            int n=1;    
            while(n<(len+b.len)) n<<=1; 
            memset(A,0,n<<2);  
            memset(B,0,n<<2);   
            memcpy(A,a,len<<2);                             
            memcpy(B,b.a,b.len<<2);            
            ntt_init(n);        
            NTT(A,n,1), NTT(B,n,1);     
            for(int i=0;i<n;++i) A[i]=(ll)A[i]*B[i]%mod;   
            NTT(A,n,-1);   
            memcpy(c.a,A,c.len<<2);  
            return c;       
        }    
        poly operator + (const poly &b) const 
        {
            poly c(max(len,b.len));    
            for(int i=0;i<c.len;++i)  c.a[i]=((i<len?a[i]:0)+(i<b.len?b.a[i]:0))%mod;   
            return c;    
        }
        poly operator - (const poly &b) const 
        {    
            poly c(len);       
            for(int i=0;i<len;++i)   
            {
                if(i>=b.len)   c.a[i]=a[i];  
                else c.a[i]=(a[i]-b.a[i]+mod)%mod;    
            } 
            return c;  
        }
        poly operator / (poly u) 
        {  
            int n=len,m=u.len,l=1;  
            while(l<(n-m+1)) l<<=1;                           
            rev(),u.rev();            
            poly v=u.Inv(l);    
            v.get_mod(n-m+1);        
            poly re=(*this)*v;   
            rev(),u.rev();    
            re.get_mod(n-m+1);         
            re.rev();  
            return re;   
        }      
        poly operator % (poly u) 
        {      
            poly re=(*this)-u*(*this/u);        
            re.get_mod(u.len-1);       
            return re;    
        }                     
    }p[N<<2],pr;    
    int xx[N],yy[N];              
    #define lson now<<1  
    #define rson now<<1|1           
    inline void pushup(int l,int r,int now)
    {
        int mid=(l+r)>>1;      
        if(r>mid)   p[now]=p[lson]*p[rson]; 
        else p[now]=p[lson];   
    }
    void build(int l,int r,int now,int *pp) 
    {
        if(l==r) 
        {     
            p[now].fix(2);  
            p[now].a[0]=mod-pp[l];  
            p[now].a[1]=1;   
            return; 
        }  
        int mid=(l+r)>>1;   
        if(l<=mid)  build(l,mid,lson,pp);     
        if(r>mid)   build(mid+1,r,rson,pp);          
        p[now]=p[lson]*p[rson];   
    }    
    void get_val(int l,int r,int now,poly b,int *pp,int *t) 
    {
        if(b.len<=500)     
        {   
            for(int i=l;i<=r;++i) 
            {
                ull s=0;             
                for(int j=b.len-1;j>=0;--j)     
                {
                    s=((ull)s*pp[i]+b.a[j])%mod;  
                    if(!(j&7))   s%=mod;       
                }
                t[i]=s%mod;   
            }
            return;  
        } 
        int mid=(l+r)>>1;     
        if(l<=mid)   get_val(l,mid,lson,b%p[lson],pp,t);  
        if(r>mid)    get_val(mid+1,r,rson,b%p[rson],pp,t);     
    }   
    poly solve_polate(int l,int r,int now,int *t) 
    {
        if(l==r) 
        {
            poly re(1);   
            re.a[0]=t[l];   
            return re;   
        } 
        int mid=(l+r)>>1;    
        poly L,R;  
        L=solve_polate(l,mid,lson,t);   
        R=solve_polate(mid+1,r,rson,t);   
        return L*p[rson]+R*p[lson];           
    }       
    void check_Interpolate();  
    poly Interpolate(int *a,int *b,int n);       
    void check_Evaluation();  
    void check_Inv();   
    void check_mult();   
    void check_divide();       
    poly Interpolate(int *a,int *b,int n) 
    { 
        int i,j;   
        build(1,n,1,a);        
        static int t[N];  
        poly tmp=p[1].dao();          
        get_val(1,n,1,tmp,a,t);                            
        for(i=1;i<=n;++i)    t[i]=(ll)INV(t[i])*b[i]%mod;                       
        return solve_polate(1,n,1,t);    
    }
    void check_Interpolate() 
    {
        // setIO("input");        
        int i,j,n; 
        n=rd(); 
        for(i=1;i<=n;++i)      xx[i]=rd(),yy[i]=rd(); 
        poly re=Interpolate(xx,yy,n);                 
        for(i=0;i<re.len;++i)       print(re.a[i]), printf(" "); 
        for(;i<n;++i)    print(re.a[i]), printf(" ");      
    }
    void check_Evaluation() 
    {   
        int i,j,n,m,l; 
        n=rd(),m=rd();              
        pr.fix(n+1);         
        static int pp[N];    
        for(i=0;i<=n;++i)   pr.a[i]=rd();   
        for(i=1;i<=m;++i)   pp[i]=rd();   
        build(1,m,1,pp);   
        get_val(1,m,1,pr,pp,pp);                                         
        for(i=1;i<=m;++i)   printf("%d
    ",pp[i]);                    
    }
    void check_Inv() 
    {
        int i,j,n; 
        scanf("%d",&n);    
        pr.fix(n);   
        for(i=0;i<n;++i)   scanf("%d",&pr.a[i]);      
        int l=1; 
        while(l<n)  l<<=1;  
        pr=pr.Inv(l);   
        for(i=0;i<n;++i)   printf("%d ",pr.a[i]);      
    }
    void check_mult() 
    {
        int i,j,n,m; 
        scanf("%d%d",&n,&m);  
        poly a(n+1),b(m+1);  
        for(i=0;i<=n;++i)   scanf("%d",&a.a[i]); 
        for(i=0;i<=m;++i)   scanf("%d",&b.a[i]); 
        a=a*b;  
        for(i=0;i<a.len;++i)   printf("%d ",a.a[i]); 
    }
    void check_divide() 
    {
        int i,j,n,m;   
        scanf("%d%d",&n,&m);    
        poly F(n+1), G(m+1);   
        for(i=0;i<=n;++i)    scanf("%d",&F.a[i]);  
        for(i=0;i<=m;++i)    scanf("%d",&G.a[i]);  
        poly Q=F/G;  
        poly R=F%G;  
        for(i=0;i<Q.len;++i)    printf("%d ",Q.a[i]);    
        printf("
    ");   
        for(i=0;i<R.len;++i)    printf("%d ",R.a[i]);   
    }    
    

      

  • 相关阅读:
    MySQL数据表类型 = 存储引擎类型
    删除链表节点
    链表逆序(反转)
    腾讯2012笔试题
    MysqL数据表类型
    进程间的通信方式
    网络套接字编程学习笔记一
    HTTP报头
    C语言排序算法
    交换排序经典的冒泡排序算法总结
  • 原文地址:https://www.cnblogs.com/guangheli/p/11928600.html
Copyright © 2011-2022 走看看