zoukankan      html  css  js  c++  java
  • 题解 P4512 【【模板】多项式除法】

    题目地址


    前言

    原理有大佬写了 

    所以蒟蒻只讲下本题的代码细节

    我看懂的大佬博客:博客地址

    因为可能知道了大致的步骤还有很多细的地方不理解
    导致写的时候要花很久
    并且看到大佬们好像都是用递归写的
    希望能有帮助吧


    背景

    由于我太菜了实在看不懂其他大佬的代码只能自己写
    于是因为很多的细节原因和并一些大佬的奇异写法误导调了N+个小时
    # 详细的地方还是看代码里面说明吧
    因为没怎么优化常数比较大吧
    有写代码是可以简化的

    #include<bits/stdc++.h>
    using namespace std;
    #define ll long long
    #define C getchar()-48
    inline ll read()
    {
        ll s=0,r=1;
        char c=C;
        for(;c<0||c>9;c=C) if(c==-3) r=-1;
        for(;c>=0&&c<=9;c=C) s=(s<<3)+(s<<1)+c;
        return s*r;
    }
    const ll p=998244353,G=3,N=400010;
    ll n,m;
    ll f[N],g[N],q[N],r[N],inv[N],rev[N],c[N];
    ll tmp1[N],tmp2[N];
    inline ll ksm(ll a,ll b)//..快速幂 
    {
        ll ans=1;
        while(b)
        {
            if(b&1) ans=(ans*a)%p;
            a=(a*a)%p;
            b>>=1;
        }
        return ans;
    }
    inline void ntt(ll *a,ll n,ll kd)//ntt日常操作 
    {
        for(int i=0;i<n;i++)
        if(i<rev[i])
          swap(a[i],a[rev[i]]);
        for(int i=1;i<n;i<<=1)
        {
            ll gn=ksm(G,(p-1)/(i<<1));
            for(int j=0;j<n;j+=(i<<1))
            {
                ll t1,t2,g=1;
                for(int k=0;k<i;k++,g=g*gn%p)
                {
                    t1=a[j+k],t2=g*a[j+k+i]%p;
                    a[j+k]=(t1+t2)%p,a[j+k+i]=(t1-t2+p)%p; 
                }
            }
        }
        if(kd==1) return;
        ll ny=ksm(n,p-2);
        reverse(a+1,a+n);
        for(int i=0;i<n;i++) a[i]=a[i]*ny%p;
    }
    inline void cl(ll *a,ll *b,ll n,ll m,ll len,ll w)//处理 
    {
        for(int i=0;i<len;i++) tmp1[i]=i<n?a[i]:0;//复制 清空多余//因为tmp被使用多遍 而在做ntt时 用的是长度为len的
        for(int i=0;i<len;i++) tmp2[i]=i<m?b[i]:0;//而有效的值只有它的得出的长度 后面其它的在模意义下都被清掉了 但之前在写的时候有的地方并没有清
                                                  //为了避免出错所以一定要清空 (在这个代码里)//..不要打成 i<n?tmp1[i]=a[i]:0;...只有像我这种蒟蒻才会犯这种错误吧 
        for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(w-1));//蝴蝶 
    }
    inline void polyinv(ll *a,ll *b,ll ed)//递推版  
    {
        b[0]=ksm(a[0],p-2);//设初始值  a*b=1(mod=x)b的值 
        for(int k=1,j=0;k<=(ed<<1);k<<=1,j++)//从两个长度为k的多项式a,b递推  
        {//!!因为这份代码的递推算的是 两个长度为a的多项式在模(m^k)次下的逆元
         //所以如果直接用ed为条件只会推出小于ed的逆元 所以ed要再乘一倍 
         //所以多项式总共需要的范围要为4倍所以N要4倍 
            ll len=k<<1;             //len 两式子计算后大小 
            cl(a,b,k,k,len,j+1);//处理//j+1 要看len调整 因为len乘上了一倍 所以j在处理时也要加上1 之前没有加被坑了好久 
            ntt(tmp1,len,1);ntt(tmp2,len,1);//注意不要直接用a,b算 因为ntt后原多项式的值会变 为了方便所以先复制一遍在用复制的多项式算 
            for(int i=0;i<len;i++) b[i]=tmp2[i]*(2ll-tmp1[i]*tmp2[i]%p+p)%p;//求逆
            ntt(b,len,-1);
            for(int i=k;i<len;i++) b[i]=0;//清空会被模的 //!!!不能删 因为下次递推是直接把0--len都作为有用的做下次运算了  
        }
    }
    inline void polymul(ll *a,ll *b,ll *c,ll n,ll m)//计算多项式相乘 
    {
        ll len=1,w=0;
        while(len<=(n+m)) len<<=1,w++;
        cl(a,b,n,m,len,w);//这里的次数(w)不用加1因为两者都是同不改变的 
        ntt(tmp1,len,1);ntt(tmp2,len,1);
        for(int i=0;i<len;i++) c[i]=tmp1[i]*tmp2[i]%p; 
        ntt(c,len,-1);
    }
    inline void work()  //f=q*g+r  ask q,r     f,g下标从0--n,0--m 
    {
    
        reverse(f,f+1+n);//对应各式的反转操作 
        reverse(g,g+1+m);
    
        polyinv(g,inv,n-m+1);//求逆  因为反转后使r能够被忽略所以是在模x^(n-m+1)意义下的, 所以只要算出g在模x^(n-m+1)下的逆元 
        polymul(f,inv,q,n+1,n-m+1);//计算q 
    
        reverse(q,q+n-m+1);//将原式反转回来 
        reverse(f,f+n+1);
        reverse(g,g+m+1);
    
        polymul(g,q,r,m+1,n-m+1);//计算q*g的值 
        for(int i=0;i<m;i++) r[i]=(f[i]-r[i]+p)%p;//相减算出r 
    }
    int main()//注意输入的多项式是 0--n 和0--m 不是长度为n,m; 
    {
        n=read(),m=read();    //读入 
        for(int i=0;i<=n;i++) f[i]=read();
        for(int i=0;i<=m;i++) g[i]=read();
        work();
        for(int i=0;i<=n-m;i++) printf("%lld ",q[i]);printf("
    ");//输出 
        for(int i=0;i<m;i++)    printf("%lld ",r[i]);
        return 0;
    }

    # 干净的代码

    #include<bits/stdc++.h>
    using namespace std;
    #define ll long long
    #define C getchar()-48
    inline ll read()
    {
        ll s=0,r=1;
        char c=C;
        for(;c<0||c>9;c=C) if(c==-3) r=-1;
        for(;c>=0&&c<=9;c=C) s=(s<<3)+(s<<1)+c;
        return s*r;
    }
    const ll p=998244353,G=3,N=400010;
    ll n,m;
    ll f[N],g[N],q[N],r[N],inv[N],rev[N],c[N];
    ll tmp1[N],tmp2[N];
    inline ll ksm(ll a,ll b)
    {
        ll ans=1;
        while(b)
        {
            if(b&1) ans=(ans*a)%p;
            a=(a*a)%p;
            b>>=1;
        }
        return ans;
    }
    inline void ntt(ll *a,ll n,ll kd)
    {
        for(int i=0;i<n;i++)
        if(i<rev[i])
          swap(a[i],a[rev[i]]);
        for(int i=1;i<n;i<<=1)
        {
            ll gn=ksm(G,(p-1)/(i<<1));
            for(int j=0;j<n;j+=(i<<1))
            {
                ll t1,t2,g=1;
                for(int k=0;k<i;k++,g=g*gn%p)
                {
                    t1=a[j+k],t2=g*a[j+k+i]%p;
                    a[j+k]=(t1+t2)%p,a[j+k+i]=(t1-t2+p)%p; 
                }
            }
        }
        if(kd==1) return;
        ll ny=ksm(n,p-2);
        reverse(a+1,a+n);
        for(int i=0;i<n;i++) a[i]=a[i]*ny%p;
    }
    inline void cl(ll *a,ll *b,ll n,ll m,ll len,ll w)
    {
        for(int i=0;i<len;i++) tmp1[i]=i<n?a[i]:0;
        for(int i=0;i<len;i++) tmp2[i]=i<m?b[i]:0;
        for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(w-1));
    }
    inline void polyinv(ll *a,ll *b,ll ed)
    {
        b[0]=ksm(a[0],p-2);
        for(int k=1,j=0;k<=(ed<<1);k<<=1,j++)
        {
            ll len=k<<1;          
            cl(a,b,k,k,len,j+1);
            ntt(tmp1,len,1);ntt(tmp2,len,1);
            for(int i=0;i<len;i++) b[i]=tmp2[i]*(2ll-tmp1[i]*tmp2[i]%p+p)%p;
            ntt(b,len,-1);
            for(int i=k;i<len;i++) b[i]=0;
        }
    }
    inline void polymul(ll *a,ll *b,ll *c,ll n,ll m) 
    {
        ll len=1,w=0;
        while(len<=(n+m)) len<<=1,w++;
        cl(a,b,n,m,len,w);
        ntt(tmp1,len,1);ntt(tmp2,len,1);
        for(int i=0;i<len;i++) c[i]=tmp1[i]*tmp2[i]%p; 
        ntt(c,len,-1);
    }
    inline void work() 
    {
    
        reverse(f,f+1+n);
        reverse(g,g+1+m);
    
        polyinv(g,inv,n-m+1);
        polymul(f,inv,q,n+1,n-m+1);
    
        reverse(q,q+n-m+1);
        reverse(f,f+n+1);
        reverse(g,g+m+1);
    
        polymul(g,q,r,m+1,n-m+1);
        for(int i=0;i<m;i++) r[i]=(f[i]-r[i]+p)%p;
    }
    int main()
    {
        n=read(),m=read();    
        for(int i=0;i<=n;i++) f[i]=read();
        for(int i=0;i<=m;i++) g[i]=read();
        work();
        for(int i=0;i<=n-m;i++) printf("%lld ",q[i]);printf("
    ");
        for(int i=0;i<m;i++)    printf("%lld ",r[i]);
        return 0;
    }
    代码
  • 相关阅读:
    波段是金牢记六大诀窍
    zk kafka mariadb scala flink integration
    Oracle 体系结构详解
    图解 Database Buffer Cache 内部原理(二)
    SQL Server 字符集介绍及修改方法演示
    SQL Server 2012 备份与还原详解
    SQL Server 2012 查询数据库中所有表的名称和行数
    SQL Server 2012 查询数据库中表格主键信息
    SQL Server 2012 查询数据库中所有表的索引信息
    图解 Database Buffer Cache 内部原理(一)
  • 原文地址:https://www.cnblogs.com/1436177712qqcom/p/10328663.html
Copyright © 2011-2022 走看看