zoukankan      html  css  js  c++  java
  • 多项式模板

    这里给出了一个多项式/形式幂级数的类的实现。由于仅考虑形式幂级数的前n项,即$!mod{x^n}$下的等价类,故其形式与多项式相同,因而在类的实现上没有对两者作区分。没有返回值的方法会将结果用于更新自身。对于所有用于形式幂级数的方法,参数n代表$!mod{x^n}$,且保证调用后前n项均可访问。三角函数使用了要求$sqrt{-1}$存在的实现。下表说明了各方法的复杂度。

    • der(求导)/pri(积分):$O(n)$
    • mul/inv/sqrt/log/exp/sin/cos/tan:$O(nlog n)$
    • div/mod:$O(nlog n)$
    • 在$m$个点处求值:$O(nlog n+mlog^2min(m,n))$
    • 以根集构造:$O(nlog^2n)$
    • 以点集构造(插值):$O(nlog^2n)$

    注意,对于某些算式,相比直接调用以上方法进行计算,使用针对性的实现能使效率大幅提高。


    #include<algorithm>
    #include<vector>
    #define RAN(a)a.begin(),a.end()
    using namespace std;
    typedef unsigned long long u64;
    typedef unsigned u32;
    namespace num{
    	const u32 p=998244353;
    	const u32 g=3;
    	inline u32 imod(u32 a){
    		return a<p?a:a-p;
    	}
    	inline u32 ipow(u32 a,u32 n){
    		u32 s=1;
    		for(;n;n>>=1){
    			if(n&1)
    				s=(u64)s*a%p;
    			a=(u64)a*a%p;
    		}
    		return s;
    	}
    	class inv_t{
    	public:
    		inv_t():f(1,1){}
    		u32 operator[](int n){
    			int m=f.size();
    			if(m<n){
    				f.resize(n);
    				for(int i=m+1;i<=n;++i)
    					f[i-1]=(u64)(p-p/i)*f[p%i-1]%p;
    			}
    			return f[n-1];
    		}
    		u32 operator()(u32 a)const{
    			return ipow(a,p-2);
    		}
    	private:
    		vector<u32>f;
    	}inv;
    }
    using namespace num;
    class poly{
    public:
    	vector<u32>a;
    	poly(){}
    	explicit poly(int n):a(n){}
    	u32&operator[](int i){
    		return a[i];
    	}
    	const u32&operator[](int i)const{
    		return a[i];
    	}
    	int size()const{
    		return a.size();
    	}
    	void swap(poly&b){
    		a.swap(b.a);
    	}
    	void der(){
    		fix();
    		if(size()){
    			for(int i=1;i<size();++i)
    				a[i-1]=(u64)i*a[i]%p;
    			a.pop_back();
    		}
    	}
    	void pri(){
    		fix();
    		shl();
    		for(int i=size()-1;i>0;--i)
    			a[i]=(u64)a[i]*num::inv[i]%p;
    	}
    	static int len(int n){
    		while(n^n&-n)
    			n+=n&-n;
    		return n;
    	}
    	void fft(int n,bool f){
    		a.resize(n);
    		if(n<=1)
    			return;
    		for(int i=0,j=0;i<n;++i){
    			if(i<j)
    				std::swap(a[i],a[j]);
    			int k=n>>1;
    			while((j^=k)<k)
    				k>>=1;
    		}
    		vector<u32>w(n/2);
    		w[0]=1;
    		for(int i=1;i<n;i<<=1){
    			for(int j=i/2-1;~j;--j)
    				w[j<<1]=w[j];
    			int m=(p-1)/i/2;
    			u64 s=ipow(g,f?p-1-m:m);
    			for(int j=1;j<i;j+=2)
    				w[j]=s*w[j-1]%p;
    			for(int j=0;j<n;j+=i<<1){
    				u32*b=&a[0]+j,*c=b+i;
    				for(int k=0;k<i;++k){
    					u32 v=(u64)w[k]*c[k]%p;
    					c[k]=imod(b[k]+p-v);
    					b[k]=imod(b[k]+v);
    				}
    			}
    		}
    	}
    	void dft(int n){
    		fft(n,0);
    	}
    	void idft(){
    		int n=size();
    		fft(n,1);
    		u64 f=num::inv(n);
    		for(int i=0;i<n;++i)
    			a[i]=f*a[i]%p;
    	}
    	void fix(){
    		while(size()&&!a.back())
    			a.pop_back();
    	}
    	void mul(poly b){
    		fix();
    		b.fix();
    		int n=len(size()+b.size()-1);
    		dft(n);
    		b.dft(n);
    		for(int i=0;i<n;++i)
    			a[i]=(u64)a[i]*b[i]%p;
    		idft();
    		fix();
    	}
    	void mod(int n){
    		a.resize(n);
    	}
    	void inv(int n){
    		int m=len(n);
    		mod(m);
    		vector<u32>b(1,num::inv(a[0]));
    		a.swap(b);
    		for(int i=2;i<=m;i<<=1){
    			int l=i<<1;
    			poly c(l);
    			for(int j=0;j<i;++j)
    				c[j]=b[j];
    			c.dft(l);
    			dft(l);
    			for(int j=0;j<l;++j)
    				a[j]=a[j]*(2+p-(u64)a[j]*c[j]%p)%p;
    			idft();
    			mod(i);
    		}
    		mod(n);
    	}
    	void sqrt(int n){
    		int m=len(n);
    		mod(m);
    		vector<u32>b(1,1);
    		a.swap(b);
    		u64 w=(p+1)/2;
    		for(int i=2;i<=m;i<<=1){
    			poly c(i);
    			for(int j=0;j<i;++j)
    				c[j]=b[j];
    			vector<u32>t=a;
    			inv(i);
    			mul(c);
    			mod(i);
    			for(int j=0;j<i>>1;++j)
    				a[j]=w*(a[j]+t[j])%p;
    			for(int j=i>>1;j<i;++j)
    				a[j]=w*a[j]%p;
    		}
    		mod(n);
    	}
    	void log(int n){
    		mod(n);
    		poly b=*this;
    		der();
    		b.inv(n-1);
    		mul(b);
    		mod(n-1);
    		pri();
    		mod(n);
    	}
    	void exp(int n){
    		int m=len(n);
    		mod(m);
    		vector<u32>b(1,1);
    		a.swap(b);
    		for(int i=2;i<=m;i<<=1){
    			poly c=*this;
    			log(i);
    			for(int j=0;j<i;++j)
    				a[j]=imod(b[j]+p-a[j]);
    			++a[0]%=p;
    			mul(c);
    			mod(i);
    		}
    		mod(n);
    	}
    	void sin(int n){
    		mod(n);
    		u64 w=ipow(g,(p-1)/4);
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		exp(n);
    		poly b=*this;
    		b.inv(n);
    		w=(p+1)/2*(p-w)%p;
    		for(int i=0;i<n;++i)
    			a[i]=(a[i]+p-b[i])*w%p;
    	}
    	void cos(int n){
    		mod(n);
    		u64 w=ipow(g,(p-1)/4);
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		exp(n);
    		poly b=*this;
    		b.inv(n);
    		w=(p+1)/2;
    		for(int i=0;i<n;++i)
    			a[i]=(a[i]+b[i])*w%p;
    	}
    	void tan(int n){
    		mod(n);
    		u64 w=ipow(g,(p-1)/4)*2;
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		exp(n);
    		++a[0]%=p;
    		inv(n);
    		for(int i=0;i<n;++i)
    			a[i]=a[i]*w%p;
    		(a[0]+=p-w/2)%=p;
    	}
    	poly mod_bf(poly b){
    		fix();
    		b.fix();
    		int n=size();
    		int m=b.size();
    		if(n<m)
    			return poly();
    		poly q(n-m+1);
    		u64 w=num::inv(b[m-1]);
    		for(;n>=m;--n)
    			if(u64 s=a[n-1]*w%p){
    				q[n-m]=s;
    				for(int i=n-1;i>=n-m;--i)
    					a[i]=(a[i]+(p-s)*b[i-n+m])%p;
    			}
    		fix();
    		return q;
    	}
    	void gcd(poly b){
    		fix();
    		b.fix();
    		while(b.size()){
    			mod_bf(b);
    			swap(b);
    		}
    	}
    	void div(poly b){
    		fix();
    		b.fix();
    		int n=size()-b.size()+1;
    		if(n<=0){
    			a.clear();
    			return;
    		}
    		reverse(RAN(a));
    		mod(n);
    		reverse(RAN(b.a));
    		b.inv(n);
    		mul(b);
    		mod(n);
    		reverse(RAN(a));
    		fix();
    	}
    	void mod(poly b){
    		fix();
    		b.fix();
    		int m=b.size();
    		if(size()>=m){
    			poly c=*this;
    			div(b);
    			mul(b);
    			mod(m-1);
    			for(int i=0;i<m-1;++i)
    				a[i]=imod(c[i]+p-a[i]);
    			fix();
    		}
    	}
    	u32 val(u32 x)const{
    		u32 s=0;
    		for(int i=size()-1;~i;--i)
    			s=((u64)s*x+a[i])%p;
    		return s;
    	}
    	u32 operator()(u32 x)const{
    		return val(x);
    	}
    	template<class ite1,class ite2>
    	void val(int n,ite1 x,ite2 y)const;
    	template<class ite>
    	poly(int n,ite x);
    	template<class ite1,class ite2>
    	poly(int n,ite1 x,ite2 y);
    	class seg;
    private:
    	void shl(){
    		a.insert(a.begin(),0);
    	}
    	void mod(poly b,const poly&f);
    	template<class ite>
    	static poly gen(int n,ite a);
    	template<class ite>
    	static poly gen(int x,int y,ite&a);
    	template<class ite>
    	static poly gen_bf(int n,ite&a);
    };
    void poly::mod(poly b,const poly&f){
    	fix();
    	b.fix();
    	int m=b.size();
    	if(size()>=m){
    		poly c=*this;
    		div(b);
    		dft(f.size());
    		for(int i=0;i<f.size();++i)
    			a[i]=(u64)a[i]*f[i]%p;
    		idft();
    		for(int i=0;i<m-1;++i)
    			a[i]=imod(c[i]+p-a[i]);
    		mod(m-1);
    		fix();
    	}
    }
    template<class ite>
    poly::poly(int n,ite x){
    	*this=gen(n,x);
    }
    template<class ite>
    poly poly::gen(int n,ite a){
    	return gen(0,n,a);
    }
    template<class ite>
    poly poly::gen(int x,int y,ite&a){
    	if(y-x<=200){
    		return gen_bf(y-x,a);
    	}else{
    		int m=x+y>>1;
    		poly b=gen(x,m,a);
    		b.mul(gen(m,y,a));
    		return b;
    	}
    }
    template<class ite>
    poly poly::gen_bf(int n,ite&a){
    	poly f(1);
    	f[0]=1;
    	for(int i=0;i<n;++i){
    		f.shl();
    		u64 s=p-*a++;
    		for(int j=0;j<=i;++j)
    			f[j]=(f[j]+s*f[j+1])%p;
    	}
    	return f;
    }
    class poly::seg{
    public:
    	template<class ite>
    	seg(int n,const ite&x):a(n){
    		ite t=x;
    		for(int i=0;i<n;++i)
    			a[i]=*t++;
    		dfs(0,n,1);
    	}
    	poly gen()const{
    		return t[1].a;
    	}
    	template<class ite>
    	poly gen(const ite&b)const{
    		ite t=b;
    		return gen(t);
    	}
    	template<class ite>
    	void val(const ite&b,const poly&f)const{
    		ite t=b;
    		val(t,f);
    	}
    	template<class ite>
    	seg(int n,ite&x):a(n){
    		for(int i=0;i<n;++i)
    			a[i]=*x++;
    		dfs(0,n,1);
    	}
    	template<class ite>
    	poly gen(ite&b)const{
    		poly f=t[1].a;
    		f.der();
    		return gen(0,a.size(),1,b,f);
    	}
    	template<class ite>
    	void val(ite&b,const poly&f)const{
    		val(0,a.size(),1,b,f);
    	}
    private:
    	struct pair{
    		poly a,b;
    	};
    	vector<pair>t;
    	vector<u32>a;
    	void dfs(int x,int y,int k){
    		if(t.size()<=k)
    			t.resize(k+1);
    		if(y-x<=200){
    			t[k].a=poly::gen(y-x,a.begin()+x);
    		}else{
    			int m=x+y>>1,i=k<<1,j=i^1;
    			dfs(x,m,i);
    			dfs(m,y,j);
    			int n=len(y-x+1);
    			t[i].b=t[i].a;
    			t[i].b.dft(n);
    			t[j].b=t[j].a;
    			t[j].b.dft(n);
    			t[k].a=t[i].b;
    			for(int l=0;l<n;++l)
    				t[k].a[l]=(u64)t[k].a[l]*t[j].b[l]%p;
    			t[k].a.idft();
    		}
    	}
    	template<class ite>
    	void val(int x,int y,int k,ite&b,poly f)const{
    		if(k!=1)
    			f.mod(t[k].a,t[k].b);
    		else
    			f.mod(t[k].a);
    		if(y-x<=200){
    			for(int i=0;i<y-x;++i)
    				*b++=f(a[x+i]);
    		}else{
    			int m=x+y>>1,i=k<<1,j=i^1;
    			val(x,m,i,b,f);
    			val(m,y,j,b,f);
    		}
    	}
    	template<class ite>
    	poly gen(int x,int y,int k,ite&b,poly f)const{
    		if(k!=1)
    			f.mod(t[k].a,t[k].b);
    		else
    			f.mod(t[k].a);
    		if(y-x<=200){
    			vector<u64>c(y-x);
    			for(int i=0;i<y-x;++i){
    				u64 n=a[x+i];
    				u64 m=*b++;
    				m=m*num::inv(f(n))%p;
    				u64 s=0;
    				for(int j=y-x;j>0;--j){
    					s=(t[k].a[j]+s*n)%p;
    					c[j-1]+=s*m%p;
    				}
    			}
    			poly d(y-x);
    			for(int i=0;i<y-x;++i)
    				d[i]=c[i]%p;
    			return d;
    		}else{
    			int m=x+y>>1;
    			int i=k<<1;
    			int j=i^1;
    			poly c=gen(x,m,i,b,f);
    			poly d=gen(m,y,j,b,f);
    			int n=len(y-x+1);
    			c.dft(n);
    			d.dft(n);
    			for(int l=0;l<n;++l)
    				c[l]=((u64)c[l]*t[j].b[l]+(u64)d[l]*t[i].b[l])%p;
    			c.idft();
    			return c;
    		}
    	}
    };
    template<class ite1,class ite2>
    poly::poly(int n,ite1 x,ite2 y){
    	*this=seg(n,x).gen(y);
    }
    template<class ite1,class ite2>
    void poly::val(int n,ite1 x,ite2 y)const{
    	int m=size();
    	if(min(m,n)<=100)
    		for(int i=0;i<n;++i)
    			*y++=val(*x++);
    	else
    		for(int l=n;;l/=2)
    			if(l*2<=m){
    				for(int i=0;i<n;i+=l)
    					seg(min(l,n-i),x).val(y,*this);
    				break;
    			}
    }
    

    2018-08-14

    • 计划今后增加$O(nlog^2n)$求多项式gcd的算法。

    2018-08-18

    • 增加了$O(n^2)$的gcd,未保证结果为首一多项式。
    • 分离了基础部分和点值有关部分。

    2018-08-19

    • 为避免影响编译器优化,删去了在运行时确定模数的功能。

    2019-03-17

    • 增加了sin/cos/tan。

    2019-06-28

    • 修复了除法中非预期地降低效率的问题。
  • 相关阅读:
    【数据结构】线性表&&顺序表详解和代码实例
    【智能算法】超详细的遗传算法(Genetic Algorithm)解析和TSP求解代码详解
    【智能算法】用模拟退火(SA, Simulated Annealing)算法解决旅行商问题 (TSP, Traveling Salesman Problem)
    【智能算法】迭代局部搜索(Iterated Local Search, ILS)详解
    10. js时间格式转换
    2. 解决svn working copy locked问题
    1. easyui tree 初始化的两种方式
    10. js截取最后一个斜杠后面的字符串
    2. apache整合tomcat部署集群
    1. apache如何启动
  • 原文地址:https://www.cnblogs.com/f321dd/p/9363493.html
Copyright © 2011-2022 走看看