zoukankan      html  css  js  c++  java
  • 多项式模板

    这里给出了一个多项式/形式幂级数的类的实现。由于仅考虑形式幂级数的前n项,即$!mod{x^n}$下的等价类,故其形式与多项式相同,因而在类的实现上没有对两者作区分。没有返回值的方法会将结果用于更新自身。对于所有用于形式幂级数的方法,参数n代表$!mod{x^n}$,且保证调用后前n项均可访问。三角函数使用了要求$sqrt{-1}$存在的实现。下表说明了各方法的复杂度。

    • der(求导)/pri(积分):$O(n)$
    • mul/inv/sqrt/log/exp/sin/cos/tan:$O(nlog n)$
    • div/mod:$O(nlog n)$
    • 在$m$个点处求值:$O(nlog n+mlog^2min(m,n))$
    • 以根集构造:$O(nlog^2n)$
    • 以点集构造(插值):$O(nlog^2n)$

    注意,对于某些算式,相比直接调用以上方法进行计算,使用针对性的实现能使效率大幅提高。


    #include<algorithm>
    #include<vector>
    #define RAN(a)a.begin(),a.end()
    using namespace std;
    typedef unsigned long long u64;
    typedef unsigned u32;
    namespace num{
    	const u32 p=998244353;
    	const u32 g=3;
    	inline u32 imod(u32 a){
    		return a<p?a:a-p;
    	}
    	inline u32 ipow(u32 a,u32 n){
    		u32 s=1;
    		for(;n;n>>=1){
    			if(n&1)
    				s=(u64)s*a%p;
    			a=(u64)a*a%p;
    		}
    		return s;
    	}
    	class inv_t{
    	public:
    		inv_t():f(1,1){}
    		u32 operator[](int n){
    			int m=f.size();
    			if(m<n){
    				f.resize(n);
    				for(int i=m+1;i<=n;++i)
    					f[i-1]=(u64)(p-p/i)*f[p%i-1]%p;
    			}
    			return f[n-1];
    		}
    		u32 operator()(u32 a)const{
    			return ipow(a,p-2);
    		}
    	private:
    		vector<u32>f;
    	}inv;
    }
    using namespace num;
    class poly{
    public:
    	vector<u32>a;
    	poly(){}
    	explicit poly(int n):a(n){}
    	u32&operator[](int i){
    		return a[i];
    	}
    	const u32&operator[](int i)const{
    		return a[i];
    	}
    	int size()const{
    		return a.size();
    	}
    	void swap(poly&b){
    		a.swap(b.a);
    	}
    	void der(){
    		fix();
    		if(size()){
    			for(int i=1;i<size();++i)
    				a[i-1]=(u64)i*a[i]%p;
    			a.pop_back();
    		}
    	}
    	void pri(){
    		fix();
    		shl();
    		for(int i=size()-1;i>0;--i)
    			a[i]=(u64)a[i]*num::inv[i]%p;
    	}
    	static int len(int n){
    		while(n^n&-n)
    			n+=n&-n;
    		return n;
    	}
    	void fft(int n,bool f){
    		a.resize(n);
    		if(n<=1)
    			return;
    		for(int i=0,j=0;i<n;++i){
    			if(i<j)
    				std::swap(a[i],a[j]);
    			int k=n>>1;
    			while((j^=k)<k)
    				k>>=1;
    		}
    		vector<u32>w(n/2);
    		w[0]=1;
    		for(int i=1;i<n;i<<=1){
    			for(int j=i/2-1;~j;--j)
    				w[j<<1]=w[j];
    			int m=(p-1)/i/2;
    			u64 s=ipow(g,f?p-1-m:m);
    			for(int j=1;j<i;j+=2)
    				w[j]=s*w[j-1]%p;
    			for(int j=0;j<n;j+=i<<1){
    				u32*b=&a[0]+j,*c=b+i;
    				for(int k=0;k<i;++k){
    					u32 v=(u64)w[k]*c[k]%p;
    					c[k]=imod(b[k]+p-v);
    					b[k]=imod(b[k]+v);
    				}
    			}
    		}
    	}
    	void dft(int n){
    		fft(n,0);
    	}
    	void idft(){
    		int n=size();
    		fft(n,1);
    		u64 f=num::inv(n);
    		for(int i=0;i<n;++i)
    			a[i]=f*a[i]%p;
    	}
    	void fix(){
    		while(size()&&!a.back())
    			a.pop_back();
    	}
    	void mul(poly b){
    		fix();
    		b.fix();
    		int n=len(size()+b.size()-1);
    		dft(n);
    		b.dft(n);
    		for(int i=0;i<n;++i)
    			a[i]=(u64)a[i]*b[i]%p;
    		idft();
    		fix();
    	}
    	void mod(int n){
    		a.resize(n);
    	}
    	void inv(int n){
    		int m=len(n);
    		mod(m);
    		vector<u32>b(1,num::inv(a[0]));
    		a.swap(b);
    		for(int i=2;i<=m;i<<=1){
    			int l=i<<1;
    			poly c(l);
    			for(int j=0;j<i;++j)
    				c[j]=b[j];
    			c.dft(l);
    			dft(l);
    			for(int j=0;j<l;++j)
    				a[j]=a[j]*(2+p-(u64)a[j]*c[j]%p)%p;
    			idft();
    			mod(i);
    		}
    		mod(n);
    	}
    	void sqrt(int n){
    		int m=len(n);
    		mod(m);
    		vector<u32>b(1,1);
    		a.swap(b);
    		u64 w=(p+1)/2;
    		for(int i=2;i<=m;i<<=1){
    			poly c(i);
    			for(int j=0;j<i;++j)
    				c[j]=b[j];
    			vector<u32>t=a;
    			inv(i);
    			mul(c);
    			mod(i);
    			for(int j=0;j<i>>1;++j)
    				a[j]=w*(a[j]+t[j])%p;
    			for(int j=i>>1;j<i;++j)
    				a[j]=w*a[j]%p;
    		}
    		mod(n);
    	}
    	void log(int n){
    		mod(n);
    		poly b=*this;
    		der();
    		b.inv(n-1);
    		mul(b);
    		mod(n-1);
    		pri();
    		mod(n);
    	}
    	void exp(int n){
    		int m=len(n);
    		mod(m);
    		vector<u32>b(1,1);
    		a.swap(b);
    		for(int i=2;i<=m;i<<=1){
    			poly c=*this;
    			log(i);
    			for(int j=0;j<i;++j)
    				a[j]=imod(b[j]+p-a[j]);
    			++a[0]%=p;
    			mul(c);
    			mod(i);
    		}
    		mod(n);
    	}
    	void sin(int n){
    		mod(n);
    		u64 w=ipow(g,(p-1)/4);
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		exp(n);
    		poly b=*this;
    		b.inv(n);
    		w=(p+1)/2*(p-w)%p;
    		for(int i=0;i<n;++i)
    			a[i]=(a[i]+p-b[i])*w%p;
    	}
    	void cos(int n){
    		mod(n);
    		u64 w=ipow(g,(p-1)/4);
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		exp(n);
    		poly b=*this;
    		b.inv(n);
    		w=(p+1)/2;
    		for(int i=0;i<n;++i)
    			a[i]=(a[i]+b[i])*w%p;
    	}
    	void tan(int n){
    		mod(n);
    		u64 w=ipow(g,(p-1)/4)*2;
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		exp(n);
    		++a[0]%=p;
    		inv(n);
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		(a[0]+=p-w/2)%=p;
    	}
    	poly mod_bf(poly b){
    		fix();
    		b.fix();
    		int n=size();
    		int m=b.size();
    		if(n<m)
    			return poly();
    		poly q(n-m+1);
    		u64 w=num::inv(b[m-1]);
    		for(;n>=m;--n)
    			if(u64 s=a[n-1]*w%p){
    				q[n-m]=s;
    				for(int i=n-1;i>=n-m;--i)
    					a[i]=(a[i]+(p-s)*b[i-n+m])%p;
    			}
    		fix();
    		return q;
    	}
    	void gcd(poly b){
    		fix();
    		b.fix();
    		while(b.size()){
    			mod_bf(b);
    			swap(b);
    		}
    	}
    	void div(poly b){
    		fix();
    		b.fix();
    		int n=size()-b.size()+1;
    		if(n<=0){
    			a.clear();
    			return;
    		}
    		reverse(RAN(a));
    		mod(n);
    		reverse(RAN(b.a));
    		b.inv(n);
    		mul(b);
    		mod(n);
    		reverse(RAN(a));
    		fix();
    	}
    	void mod(poly b){
    		fix();
    		b.fix();
    		int m=b.size();
    		if(size()>=m){
    			poly c=*this;
    			div(b);
    			mul(b);
    			mod(m-1);
    			for(int i=0;i<m-1;++i)
    				a[i]=imod(c[i]+p-a[i]);
    			fix();
    		}
    	}
    	u32 val(u32 x)const{
    		u32 s=0;
    		for(int i=size()-1;~i;--i)
    			s=((u64)s*x+a[i])%p;
    		return s;
    	}
    	u32 operator()(u32 x)const{
    		return val(x);
    	}
    	template<class ite1,class ite2>
    	void val(int n,ite1 x,ite2 y)const;
    	template<class ite>
    	poly(int n,ite x);
    	template<class ite1,class ite2>
    	poly(int n,ite1 x,ite2 y);
    	class seg;
    private:
    	void shl(){
    		a.insert(a.begin(),0);
    	}
    	void mod(poly b,const poly&f);
    	template<class ite>
    	static poly gen(int n,ite a);
    	template<class ite>
    	static poly gen(int x,int y,ite&a);
    	template<class ite>
    	static poly gen_bf(int n,ite&a);
    };
    void poly::mod(poly b,const poly&f){
    	fix();
    	b.fix();
    	int m=b.size();
    	if(size()>=m){
    		poly c=*this;
    		div(b);
    		dft(f.size());
    		for(int i=0;i<f.size();++i)
    			a[i]=(u64)a[i]*f[i]%p;
    		idft();
    		for(int i=0;i<m-1;++i)
    			a[i]=imod(c[i]+p-a[i]);
    		mod(m-1);
    		fix();
    	}
    }
    template<class ite>
    poly::poly(int n,ite x){
    	*this=gen(n,x);
    }
    template<class ite>
    poly poly::gen(int n,ite a){
    	return gen(0,n,a);
    }
    template<class ite>
    poly poly::gen(int x,int y,ite&a){
    	if(y-x<=200){
    		return gen_bf(y-x,a);
    	}else{
    		int m=x+y>>1;
    		poly b=gen(x,m,a);
    		b.mul(gen(m,y,a));
    		return b;
    	}
    }
    template<class ite>
    poly poly::gen_bf(int n,ite&a){
    	poly f(1);
    	f[0]=1;
    	for(int i=0;i<n;++i){
    		f.shl();
    		u64 s=p-*a++;
    		for(int j=0;j<=i;++j)
    			f[j]=(f[j]+s*f[j+1])%p;
    	}
    	return f;
    }
    class poly::seg{
    public:
    	template<class ite>
    	seg(int n,const ite&x):a(n){
    		ite t=x;
    		for(int i=0;i<n;++i)
    			a[i]=*t++;
    		dfs(0,n,1);
    	}
    	poly gen()const{
    		return t[1].a;
    	}
    	template<class ite>
    	poly gen(const ite&b)const{
    		ite t=b;
    		return gen(t);
    	}
    	template<class ite>
    	void val(const ite&b,const poly&f)const{
    		ite t=b;
    		val(t,f);
    	}
    	template<class ite>
    	seg(int n,ite&x):a(n){
    		for(int i=0;i<n;++i)
    			a[i]=*x++;
    		dfs(0,n,1);
    	}
    	template<class ite>
    	poly gen(ite&b)const{
    		poly f=t[1].a;
    		f.der();
    		return gen(0,a.size(),1,b,f);
    	}
    	template<class ite>
    	void val(ite&b,const poly&f)const{
    		val(0,a.size(),1,b,f);
    	}
    private:
    	struct pair{
    		poly a,b;
    	};
    	vector<pair>t;
    	vector<u32>a;
    	void dfs(int x,int y,int k){
    		if(t.size()<=k)
    			t.resize(k+1);
    		if(y-x<=200){
    			t[k].a=poly::gen(y-x,a.begin()+x);
    		}else{
    			int m=x+y>>1,i=k<<1,j=i^1;
    			dfs(x,m,i);
    			dfs(m,y,j);
    			int n=len(y-x+1);
    			t[i].b=t[i].a;
    			t[i].b.dft(n);
    			t[j].b=t[j].a;
    			t[j].b.dft(n);
    			t[k].a=t[i].b;
    			for(int l=0;l<n;++l)
    				t[k].a[l]=(u64)t[k].a[l]*t[j].b[l]%p;
    			t[k].a.idft();
    		}
    	}
    	template<class ite>
    	void val(int x,int y,int k,ite&b,poly f)const{
    		if(k!=1)
    			f.mod(t[k].a,t[k].b);
    		else
    			f.mod(t[k].a);
    		if(y-x<=200){
    			for(int i=0;i<y-x;++i)
    				*b++=f(a[x+i]);
    		}else{
    			int m=x+y>>1,i=k<<1,j=i^1;
    			val(x,m,i,b,f);
    			val(m,y,j,b,f);
    		}
    	}
    	template<class ite>
    	poly gen(int x,int y,int k,ite&b,poly f)const{
    		if(k!=1)
    			f.mod(t[k].a,t[k].b);
    		else
    			f.mod(t[k].a);
    		if(y-x<=200){
    			vector<u64>c(y-x);
    			for(int i=0;i<y-x;++i){
    				u64 n=a[x+i];
    				u64 m=*b++;
    				m=m*num::inv(f(n))%p;
    				u64 s=0;
    				for(int j=y-x;j>0;--j){
    					s=(t[k].a[j]+s*n)%p;
    					c[j-1]+=s*m%p;
    				}
    			}
    			poly d(y-x);
    			for(int i=0;i<y-x;++i)
    				d[i]=c[i]%p;
    			return d;
    		}else{
    			int m=x+y>>1;
    			int i=k<<1;
    			int j=i^1;
    			poly c=gen(x,m,i,b,f);
    			poly d=gen(m,y,j,b,f);
    			int n=len(y-x+1);
    			c.dft(n);
    			d.dft(n);
    			for(int l=0;l<n;++l)
    				c[l]=((u64)c[l]*t[j].b[l]+(u64)d[l]*t[i].b[l])%p;
    			c.idft();
    			return c;
    		}
    	}
    };
    template<class ite1,class ite2>
    poly::poly(int n,ite1 x,ite2 y){
    	*this=seg(n,x).gen(y);
    }
    template<class ite1,class ite2>
    void poly::val(int n,ite1 x,ite2 y)const{
    	int m=size();
    	if(min(m,n)<=100)
    		for(int i=0;i<n;++i)
    			*y++=val(*x++);
    	else
    		for(int l=n;;l/=2)
    			if(l*2<=m){
    				for(int i=0;i<n;i+=l)
    					seg(min(l,n-i),x).val(y,*this);
    				break;
    			}
    }
    

    2018-08-14

    • 计划今后增加$O(nlog^2n)$求多项式gcd的算法。

    2018-08-18

    • 增加了$O(n^2)$的gcd,未保证结果为首一多项式。
    • 分离了基础部分和点值有关部分。

    2018-08-19

    • 为避免影响编译器优化,删去了在运行时确定模数的功能。

    2019-03-17

    • 增加了sin/cos/tan。

    2019-06-28

    • 修复了除法中非预期地降低效率的问题。
  • 相关阅读:
    错排问题
    用GDAL/OGR去读shapefile
    聊聊MyBatis缓存机制
    一份平民化的应用性能优化检查列表(完整篇)--转
    微服务实战(七):从单体式架构迁移到微服务架构
    微服务实战(六):选择微服务部署策略
    微服务实战(五):微服务的事件驱动数据管理
    微服务实战(四):服务发现的可行方案以及实践案例
    微服务实战(三):深入微服务架构的进程间通信
    微服务实战(一):微服务架构的优势与不足
  • 原文地址:https://www.cnblogs.com/f321dd/p/9363493.html
Copyright © 2011-2022 走看看