zoukankan      html  css  js  c++  java
  • 多项式全家桶

    包括NTT模数和非NTT模数。

    如果有锅/可以卡常的地方欢迎评论区指出,会在注释里鸣谢

    UPD on 2020/5/24 11:??:加上了多项式快速幂。

    UPD on 2020/5/24 19:08:加上了多项式除法并修了快速幂的一个锅。

    UPD on 2020/5/30 00:00:修了几个锅。

    UPD on 2020/7/15 20:30:修改了数组大小以免溢出,并同时将inv[0]和inv[1]的初始化移到了prep函数里面。

    UPD on 2020/7/16 23:15:修改了prep函数,这样可以返回lmt的值,并修改了排版。

    UPD on 2021/1/5 20:40:改了若干bug。

    UPD on 2021/1/6 22:40:新增 CZT。

    UPD on 2021/1/8 21:10:新增 FDT(下降幂多项式乘法),同时修改了 prep 函数,从而预计算出阶乘和阶乘逆元。

    UPD on 2021/1/14 13:35:新增 多点求值 并修改了 多项式除法。

    UPD on 2021/1/14 16:09:新增了 普通多项式转下降幂多项式 并对原先的 多点求值 进行卡常。

    UPD on 2021/1/14 23:30:新增了多项式快速插值。

    求评论区提一些有效的建议。

    P.S. 由于一些不可抗力(部分缩进是 4 个空格,部分是 tab),直接食用会造成不适,请复制到 tab 长度为 4 的环境下使用。

    Code(巨长代码警告
    #ifndef __POLY_H__
    #define __POLY_H__
    #include<bits/stdc++.h>
    #define clear(a) memset((a),0,len<<5)
    using namespace std;
    typedef long long ll;
    const ll N=1048576,P=998244353;
    const long double Pi=acos(-1.0);
    ll inv[N],fac[N],invfac[N];
    namespace Poly{//模数为NTT模数 
        const ll G=3,img=86583718;
        ll lmt,rev[N],a[N],b[N],c[N],d[N],e[N],h[N],x[N],y[N],z[N],X[N],Y[N],ff[N],gg[N],iv[N],t[N];//poly1
    	ll A[N],B[N],ee[N],Len[N],*p[N],C[N],v[N],*D[N],E[N];//poly2
        inline ll qpow(ll a,ll k){
            ll ret=1;
            while(k){
                if(k&1)ret=ret*a%P;
                a=(a*a)%P;
                k>>=1;
            }
            return ret%P;
        }
        inline void init(ll n){
            lmt=1;ll t=0;
            while(lmt<n)lmt<<=1,t++;
            for(ll i=1;i<lmt;i++)rev[i]=(rev[i>>1]>>1)|(i&1)<<(t-1);
        }
        inline void NTT(ll *A,ll lmt,ll tp){
            for(ll i=0;i<lmt;i++)if(i<rev[i])swap(A[i],A[rev[i]]);
            for(ll m=1;m<lmt;m<<=1)
                for(ll j=0,Wn=qpow(G,(P-1)/(m<<1));j<lmt;j+=m<<1)
                    for(ll k=0,w=1,x,y;k<m;k++,w=w*Wn%P)
                        x=A[j+k],y=w*A[j+k+m]%P,A[j+k]=(x+y)%P,A[j+k+m]=(x-y+P)%P;
            if(tp==1)return;
            reverse(A+1,A+lmt);
            for(ll i=0,inv=qpow(lmt,P-2);i<=lmt;i++)A[i]=A[i]*inv%P;
        } 
        inline void mul(ll *f,ll *g,ll len){
            init(len);
            NTT(f,lmt,1);NTT(g,lmt,1);
            for(ll i=0;i<lmt;i++)f[i]=(f[i]*g[i])%P;
            NTT(f,lmt,-1);
        } 
        void getinv(ll*f,ll*g,ll len){
            if(len==1){g[0]=qpow(f[0],P-2);return;}
            getinv(f,g,len+1>>1);
            init(len<<1);
            for(ll i=0;i<len;i++)c[i]=f[i];
            for(ll i=len;i<lmt;i++)c[i]=0;
            NTT(c,lmt,1),NTT(g,lmt,1);
            for(ll i=0;i<lmt;i++)g[i]=(2LL-g[i]*c[i]%P+P)%P*g[i]%P;
            NTT(g,lmt,-1);
            for(ll i=len;i<lmt;i++)g[i]=0; 
            clear(c);
        }
        inline void div(ll *f,ll *g,ll *q,ll *r,ll n,ll m){
            for(ll i=0,t=n-1;i<n;i++,t--)ff[i]=f[t];
            for(ll i=0,t=m-1;i<m;i++,t--)gg[i]=g[t];
            ll len=n-m+1;
            for(ll i=len;i<n;i++)ff[i]=gg[i]=0;
            getinv(gg,iv,len);
            mul(ff,iv,len<<1);
            for(ll i=0,t=len-1;i<len;i++)q[i]=ff[t--];
            for(ll i=len;i<n;i++)q[i]=0;
            for(ll i=0;i<n;i++)t[i]=q[i];
            len=n;
            clear(gg);
            for(ll i=0;i<m;i++)gg[i]=g[i];
            mul(t,gg,n<<1);
            for(ll i=0;i<m-1;i++)r[i]=(f[i]-t[i]+P)%P;
            clear(ff),clear(gg),clear(iv),clear(t);
        }
        inline void getdev(ll*f,ll*g,ll len){
            for(ll i=1;i<len;i++)g[i-1]=i*f[i]%P;
            g[len-1]=g[len]=0;
        }
        inline void getinvdev(ll*f,ll*g,ll len){
            for(ll i=1;i<=len;i++)g[i]=f[i-1]*inv[i]%P;
            g[0]=0;
        }
        inline void getln(ll*f,ll*g,ll len){
            getdev(f,a,len);
            getinv(f,b,len);
            mul(a,b,len<<1);
            getinvdev(a,g,len);
            clear(a),clear(b);
        }
        void getexp(ll*f,ll*g,ll len){
            if(len==1){g[0]=1;return;}
            getexp(f,g,len+1>>1);
            init(len<<1);
            for(ll i=0;i<(len<<1);i++)d[i]=e[i]=0;
            getln(g,d,len);
            for(ll i=0;i<len;i++)e[i]=f[i];
            NTT(g,lmt,1),NTT(d,lmt,1),NTT(e,lmt,1);
            for(ll i=0;i<lmt;i++)g[i]=(1-d[i]+e[i]+P)*g[i]%P;
            NTT(g,lmt,-1);
            for(ll i=len;i<lmt;i++)g[i]=0; 
            clear(d),clear(e);
        }
    	void getpow(ll*f,ll*g,ll len,ll k){
            getln(f,h,len);
            for(ll i=0;i<len;i++)h[i]=h[i]*k%P;
            getexp(h,g,len);
            clear(h);
        }
        inline void getsqrt(ll*f,ll*g,ll len){
            getln(f,h,len);
            for(ll i=0;i<len;i++)h[i]=h[i]*inv[2]%P;
            getexp(h,g,len);
            clear(h);
        }
        void sin(ll*f,ll*g,ll len){
            for(ll i=0;i<len;i++)x[i]=img*f[i]%P;
            getexp(x,X,len),getinv(X,Y,len);
            for(ll i=0;i<len;i++)g[i]=(X[i]-Y[i]+P)%P*qpow(img<<1,P-2)%P;
            clear(x),clear(X),clear(Y);
        }
        void cos(ll*f,ll*g,ll len){
            for(ll i=0;i<len;i++)x[i]=img*f[i]%P;
            getexp(x,X,len),getinv(X,Y,len);
            for(ll i=0;i<len;i++)g[i]=(X[i]+Y[i])%P*inv[2]%P;
            clear(x),clear(X),clear(Y);
        } 
        inline void arcsin(ll*f,ll*g,ll len){
            getdev(f,x,len);
            init(len<<1);
            NTT(f,lmt,1);
            for(ll i=0;i<lmt;i++)y[i]=(1+P-f[i]*f[i]%P)%P;
            NTT(y,lmt,-1);
            for(ll i=len;i<lmt;i++)y[i]=0;
            getsqrt(y,z,len);
            memset(y,0,(len+1)<<3);
            getinv(z,y,len);
            NTT(x,lmt,1),NTT(y,lmt,1);
            for(ll i=0;i<lmt;i++)x[i]=x[i]*y[i]%P;
            NTT(x,lmt,-1);
            getinvdev(x,g,len);
            clear(x),clear(y),clear(z);
        }
        inline void arctan(ll*f,ll*g,ll len){
            getdev(f,x,len);
            init(len<<1);
            NTT(f,lmt,1);
            for(ll i=0;i<lmt;i++)y[i]=(1+f[i]*f[i]%P)%P;
            NTT(y,lmt,-1);
            for(ll i=len;i<lmt;i++)y[i]=0;
            getinv(y,z,len);
            NTT(x,lmt,1),NTT(z,lmt,1);
            for(ll i=0;i<lmt;i++)x[i]=x[i]*z[i]%P;
            NTT(x,lmt,-1);
            getinvdev(x,g,len);
            clear(x),clear(y),clear(z);
        }
        inline ll F(ll x){return x*(x-1)/2%(P-1);}
    	inline void CZT(ll *f,ll *g,ll len,ll c,ll m){
        	for(ll i=0;i<len;i++)A[i]=qpow(c,P-1-F(i))*f[i]%P;
    		for(ll i=0;i<len+m;i++)B[i]=qpow(c,F(i));
        	reverse(A,A+len);
        	mul(A,B,len*2+m);
            for(ll i=0;i<m;i++)g[i]=qpow(c,P-1-F(i))*A[i+len-1]%P;
        	clear(A),clear(B);
        }
        void FDT(ll *A,ll len,ll tp){
        	init(len<<1);
        	if(tp==-1)for(ll i=0;i<lmt;i++)A[i]=A[i]*invfac[i]%P;
        	for(ll i=0;i<len;i++){
        		if(tp==-1&&i&1)ee[i]=P-invfac[i];
        		else ee[i]=invfac[i];
    		}
    		for(ll i=len;i<lmt;i++)ee[i]=A[i]=0;
    		NTT(A,lmt,1);NTT(ee,lmt,1);
    		for(ll i=0;i<lmt;i++)A[i]=A[i]*ee[i]%P;
    		NTT(A,lmt,-1);
    		if(tp==1)for(ll i=0;i<lmt;i++)A[i]=A[i]*fac[i]%P;
    		for(ll i=0;i<lmt;i++)ee[i]=0;
    	}
        inline void mulDown(ll *f,ll *g,ll len){
        	FDT(f,len,1);FDT(g,len,1);
        	for(ll i=0;i<len;i++)f[i]=f[i]*g[i]%P;
        	FDT(f,len,-1);
    	}
    	void getP(const ll *a,ll k,ll l,ll r){
        	if(l==r){
        		Len[k]=1;
        		p[k]=new ll[2];
        		p[k][0]=P-a[l];
        		p[k][1]=1;
        		return;
    		}
    		ll mid=l+r>>1;
    		getP(a,k<<1,l,mid);
    		getP(a,k<<1|1,mid+1,r);
    		Len[k]=Len[k<<1]+Len[k<<1|1];
    		p[k]=new ll[Len[k]+1];
    		init(Len[k]+1<<1);
    		static ll A[N],B[N];
    		for(ll i=0;i<=Len[k<<1];i++)A[i]=p[k<<1][i];
    		for(ll i=Len[k<<1]+1;i<lmt;i++)A[i]=0;
    		for(ll i=0;i<=Len[k<<1|1];i++)B[i]=p[k<<1|1][i];
    		for(ll i=Len[k<<1|1]+1;i<lmt;i++)B[i]=0;
    		NTT(A,lmt,1);NTT(B,lmt,1);
    		for(ll i=0;i<lmt;i++)A[i]=A[i]*B[i]%P;
    		NTT(A,lmt,-1);
    		for(ll i=0;i<=Len[k];i++)p[k][i]=A[i];
    	}
    	void solve(ll k,ll l,ll r,const ll *a,ll *A,ll *ans){
    		if(Len[k]<=500){
    			ll m=Len[k]-1;
    			for(ll i=l;i<=r;i++)
    				for(ll j=m;j>=0;j--)
    					ans[i]=(ans[i]*a[i]+A[j])%P;
    			return;
    		}
    		if(l==r){ans[l]=*A;return;}
    		ll mid=l+r>>1,R[Len[k]+2>>1];
    		static ll t[N];
    		div(A,p[k<<1],t,R,Len[k],Len[k<<1]+1);
    		solve(k<<1,l,mid,a,R,ans);
    		div(A,p[k<<1|1],t,R,Len[k],Len[k<<1|1]+1); 
    		solve(k<<1|1,mid+1,r,a,R,ans);
    	}
    	void evaluation(ll *f,ll *a,ll *ans,ll n,ll m){
    		getP(a,1,1,m); 
    		if(n>m){
    			static ll t[N];
    			div(f,p[1],t,f,n,m+1);
    		}
    		solve(1,1,m,a,f,ans);
    	}
    	void solve(ll k,ll l,ll r,const ll *x){
    		if(l==r){
    			D[k]=new ll[1];
    			D[k][0]=v[l];
    			return;
    		}
    		ll mid=l+r>>1;
    		solve(k<<1,l,mid,x);
    		solve(k<<1|1,mid+1,r,x);
    		D[k]=new ll[Len[k]];
    		init(Len[k]);
    		static ll f1[N],f2[N],p1[N],p2[N];
    		for(ll i=0;i<Len[k<<1];i++)f1[i]=D[k<<1][i];
    		for(ll i=Len[k<<1];i<lmt;i++)f1[i]=0;
    		for(ll i=0;i<Len[k<<1|1];i++)f2[i]=D[k<<1|1][i];
    		for(ll i=Len[k<<1|1];i<lmt;i++)f2[i]=0;
    		for(ll i=0;i<=Len[k<<1];i++)p1[i]=p[k<<1][i];
    		for(ll i=Len[k<<1]+1;i<lmt;i++)p1[i]=0;
    		for(ll i=0;i<=Len[k<<1|1];i++)p2[i]=p[k<<1|1][i];
    		for(ll i=Len[k<<1|1]+1;i<lmt;i++)p2[i]=0;
    		mul(f1,p2,Len[k]);
    		mul(f2,p1,Len[k]);
    		for(ll i=0;i<Len[k];i++)D[k][i]=(f1[i]+f2[i])%P;
    	}
    	void interpolation(ll *x,ll *y,ll *f,ll n){
    		ll len=n;
    		getP(x,1,1,n);
    		getdev(p[1],C,n+1);
    		solve(1,1,n,x,C,v);
    		for(ll i=1;i<=n;i++)v[i]=y[i]*qpow(v[i],P-2)%P;
    		solve(1,1,n,x);
    		for(ll i=0;i<n;i++)f[i]=D[1][i];
    		clear(v);
    	}
    	void polytoffp(ll *f,ll *g,ll len){
    		for(ll i=1;i<=len;i++)E[i]=i-1;
    		clear(g);
    		evaluation(f,E,g,len,len);
    		for(ll i=0;i<len;i++)g[i]=g[i+1]*invfac[i],E[i]=(i&1?P-invfac[i]:invfac[i]);
    		E[len]=g[len]=0;
    		mul(g,E,len<<1);
    		clear(E);
    	}
    }
    ll prep(ll n){
        ll lmt=1;
        inv[0]=inv[1]=1;
        while(lmt<n)lmt<<=1;
        for(ll i=2;i<lmt;i++)inv[i]=(P-P/i)*inv[P%i]%P;
        fac[0]=invfac[0]=1;
        for(ll i=1;i<lmt;i++)fac[i]=fac[i-1]*i%P,invfac[i]=invfac[i-1]*inv[i]%P;
        return lmt;
    }
    namespace Poly2{//模数不是NTT模数 
        int lmt,rev[N];
        struct comp{
            long double x,y;
            comp(long double a=0,long double b=0){x=a,y=b;}
        }a[N],b[N],c[N],d[N];
        comp operator+(comp a,comp b){return comp(a.x+b.x,a.y+b.y);}
        comp operator-(comp a,comp b){return comp(a.x-b.x,a.y-b.y);}
        comp operator*(comp a,comp b){return comp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
        comp operator/(comp a,int t){return comp(a.x/t,a.y/t);}
        inline void init(int n){
            lmt=1;int t=0;
            while(lmt<n)lmt<<=1,t++;
            for(int i=1;i<lmt;i++)rev[i]=(rev[i>>1]>>1)|(i&1)<<(t-1);
        }
        inline void FFT(comp*A,int lmt,int tp){
            for(int i=0;i<lmt;i++)if(i<rev[i])swap(A[i],A[rev[i]]);
            for(int mid=1;mid<lmt;mid<<=1){
                comp Wn(cos(Pi/mid),tp*sin(Pi/mid));
                for(int R=mid<<1,j=0;j<lmt;j+=R){
                    comp w(1,0);
                    for(int k=0;k<mid;k++,w=w*Wn){
                        comp x=A[j+k],y=w*A[j+k+mid];
                        A[j+k]=x+y,A[j+k+mid]=x-y;
                    }
                }
            }
        }
        void MTT(int*f,int*g,int*ans,int n,int m){
            init(n+m);
            const int lim=(1<<15)-1;
            for(int i=0;i<n;i++)a[i]=comp(f[i]&lim,f[i]>>15);
            for(int i=n;i<lmt;i++)a[i]=comp();
            for(int i=0;i<m;i++)b[i]=comp(g[i]&lim,g[i]>>15);
            for(int i=m;i<lmt;i++)b[i]=comp();
            FFT(a,lmt,1),FFT(b,lmt,1);
            for(int i=0;i<lmt;i++){
                int t=(lmt-i)&(lmt-1);
                c[i]=comp((a[i].x+a[t].x)*0.5,(a[i].y-a[t].y)*0.5)*b[i];
                d[i]=comp((a[i].y+a[t].y)*0.5,(a[t].x-a[i].x)*0.5)*b[i];
            }
            FFT(c,lmt,-1),FFT(d,lmt,-1);
            for(int i=0;i<lmt;i++)c[i]=c[i]/lmt,d[i]=d[i]/lmt;
            for(int i=0;i<lmt;i++){
                ll p=c[i].x+0.5,o=c[i].y+0.5,x=d[i].x+0.5,u=d[i].y+0.5;
                ans[i]=(p%P+((o+x)%P<<15)+(u%P<<30))%P;
            }
        }
    }
    #endif
    
  • 相关阅读:
    H5 _拖放使用
    CSS _text-align:justify;实现两端对齐
    Tips_钉钉免登前端实现
    快速组建的开发团队要怎么活下来?
    程序员,你的安全感呢?
    从自我驱动到带领10人团队
    你会给别人提反馈吗?
    简单几步成为微信公众平台开发者
    你了解javascript中的function吗?(1)
    容器之路 HashMap、HashSet解析(一)
  • 原文地址:https://www.cnblogs.com/happydef/p/14277856.html
Copyright © 2011-2022 走看看