zoukankan      html  css  js  c++  java
  • 多项式全家桶学习笔记(How EI's poly works)

    这里都是一些论文级别的玩意,基本不是给正常人类看的

    注意:这里仅对模数是 (998244353) 的部分进行介绍。

    零.让我们开始

    这里是一些基础的东西,不怎么需要想,这里就略过了。

    Code
    #include<bits/stdc++.h>
    #define endl '
    ' 
    #define rep(i,a,b) for(int i=(a);i<=(b);++i)
    #define Rep(i,a,b) for(int i=(a);i>=(b);--i)
    using namespace std;
    const int P=998244353,G=3,LIMIT=50;
    typedef vector<int> vec;
    struct IO_Tp {
        const static int _I_Buffer_Size = 2 << 22;
        char _I_Buffer[_I_Buffer_Size], *_I_pos = _I_Buffer;
    
        const static int _O_Buffer_Size = 2 << 22;
        char _O_Buffer[_O_Buffer_Size], *_O_pos = _O_Buffer;
    
        IO_Tp() { fread(_I_Buffer, 1, _I_Buffer_Size, stdin); }
        ~IO_Tp() { fwrite(_O_Buffer, 1, _O_pos - _O_Buffer, stdout); }
    
        IO_Tp &operator>>(int &res) {
        	int f=1;
            while (!isdigit(*_I_pos)&&(*_I_pos)!='-') ++_I_pos;
            if(*_I_pos=='-')f=-1,++_I_pos;
            res = *_I_pos++ - '0';
            while (isdigit(*_I_pos)) res = res * 10 + (*_I_pos++ - '0');
            res*=f;
            return *this;
        }
    
        IO_Tp &operator<<(int n) {
        	if(n<0)*_O_pos++='-',n=-n;
            static char _buf[10];
            char *_pos = _buf;
            do
                *_pos++ = '0' + n % 10;
            while (n /= 10);
            while (_pos != _buf) *_O_pos++ = *--_pos;
            return *this;
        }
    
        IO_Tp &operator<<(char ch) {
            *_O_pos++ = ch;
            return *this;
        }
    } IO;//快读
    void chkmax(int &x,int y){if(x<y)x=y;}
    void chkmin(int &x,int y){if(x>y)x=y;}
    int qpow(int a,int k,int p=P){//快速幂
    	int ret=1;
    	while(k){
    		if(k&1)ret=1ll*ret*a%p;
    		a=1ll*a*a%p;
    		k>>=1;
    	}
    	return ret;
    }
    int norm(int x){return x>=P?x-P:x;}
    int reduce(int x){return x<0?x+P:x;}
    void add(int&x,int y){if((x+=y)>=P)x-=P;}//取模
    struct Maths{
    	int n;
    	vec fac,invfac,inv;
    	void build(int n){
    		this->n=n;
    		fac.resize(n+1);
    		invfac.resize(n+1);
    		inv.resize(n+1);
    		fac[0]=1;
    		rep(k,1,n)fac[k]=1ll*fac[k-1]*k%P;
    		inv[1]=inv[0]=1;
    		rep(k,2,n)inv[k]=P-1ll*(P/k)*inv[P%k]%P;
    		invfac[0]=1;
    		rep(k,1,n)invfac[k]=1ll*invfac[k-1]*inv[k]%P;
    	}
    	Maths(){build(1);}
    	void chk(int k){
    		int lmt=n;
    		if(k>lmt){while(k>lmt)lmt<<=1;build(lmt);}
    	}
    	int cfac(int k){return chk(k),fac[k];}
    	int cifac(int k){return chk(k),invfac[k];}
    	int cinv(int k){return chk(k),inv[k];}
    	int binom(int n,int m){
    		if(m<0||m>n)return 0;
    		return 1ll*cfac(n)*cifac(m)%P*cifac(n-m)%P;
    	}
    }math;//普通数论部分
    struct poly{
    	vec a;
    	poly(int v=0):a(1){
    		if((v%=P)<0)v+=P;
    		a[0]=v;
    	}
    	poly(const vec&a):a(a){}
    	poly(initializer_list<int>init):a(init){}
    	int operator[](int k)const{return k<a.size()?a[k]:0;}
    	int&operator[](int k){
    		if(k>=a.size())a.resize(k+1);
    		return a[k];
    	}
    	int deg()const{return a.size()-1;}
    	void redeg(int d){a.resize(d+1);}
    	poly slice(int d)const{
    		if(d<a.size())return vec(a.begin(),a.begin()+d+1);
    		vec res(a);
    		res.resize(d+1);
    		return res;
    	}
    	int*base(){return a.data();}
    	const int*base()const{return a.data();}
    	poly println(FILE* fp)const{
    		fprintf(fp,"%d",a[0]);
    		rep(i,1,a.size()-1)fprintf(fp," %d",a[i]);
    		fputc('
    ',fp);
    		return *this;
    	}
    	poly operator+(const poly&rhs)const{
    		vec res(max(a.size(),rhs.a.size()));
    		rep(i,0,res.size()-1)if((res[i]=operator[](i)+rhs[i])>=P)res[i]-=P;
    		return res;
    	}
    	poly operator-()const{
    		poly ret(a);
    		rep(i,0,a.size()-1)if(ret[i])ret[i]=P-ret[i];
    		return ret;
    	}
    	poly operator-(const poly&rhs)const{return operator+(-rhs);}
            /*
            这里应该有一堆屎山声明,可是我懒得罗列了所以就没写
            */
        poly shift(int k)const;
    };//声明+部分简单函数
    poly zeroes(int deg){return vec(deg+1);}//0函数
    poly operator "" _z(unsigned long long a){return {0,(int)a};}
    poly operator+(int v,const poly&rhs){return poly(v)+rhs;}//多项式加整数
    poly operator*(int v,const poly&rhs){//多项式乘整数
    	poly ret=zeroes(rhs.deg());
    	rep(i,0,rhs.deg())ret[i]=1ll*rhs[i]*v%P;
    	return ret;
    }
    poly operator*(const poly&lhs,int v){return v*lhs;}
    poly poly::shift(int k)const{//多项式乘 x^k
    	poly g=zeroes(deg()+k);
    	rep(i,0,k-1)g[i]=0;
    	rep(i,min(0,-k),deg()-1)g[i+k]=a[i];
    	return g;
    }
    template<class T>
    IO_Tp& operator>>(IO_Tp& IO,vector<T>&v){//输入 vector
    	for(T&x:v)IO>>x;
    	return IO;
    }
    template<class T>
    IO_Tp& operator<<(IO_Tp& IO,vector<T>&v){//输出 vector
    	for(T&x:v)IO<<x;
    	return IO;
    }
    

    一.多项式乘法

    原理和普通的多项式乘法一致,没啥好说的,我们看看哪些地方可以优化。

    我们注意到可以在初始化的时候做一些预处理,这样大概可以减少一定的常数,在 P3803 这道题上面总时间 (1.6s o 1s),快了 0.6s。

    补充:EI 认为 “经过测试,某些形式特殊的数组的 NTT 改良版本,看似省略了部分计算,实则缓存不友好,还不如直接做”。(虽然经过测试,某一版本的 NTT 比这一版块 0.2s,但是码量差不多翻了一倍(在 NTT 部分),因此不予使用。)

    Code
    struct NTT{
    	int L,brev[1<<11];
    	vec root;
    	NTT():L(-1){
    		rep(i,1,(1<<11)-1)brev[i]=brev[i>>1]>>1|((i&1)<<10);
    	}
    	void preproot(int l){
    		L=l;
    		root.resize(2<<L);
    		rep(i,0,L){
    			int *w=root.data()+(1<<i);
    			w[0]=1;
    			int omega=qpow(G,(P-1)>>i);
    			rep(j,1,(1<<i)-1)w[j]=1ll*w[j-1]*omega%P;
    		}
    	}
    	void dft(int*a,int lgn,int d=1){
    		if(L<lgn)preproot(lgn);
    		int n=1<<lgn;
    		rep(i,0,n-1){
    			int rev=(brev[i>>11]|(brev[i&((1<<11)-1)]<<11))>>((11<<1)-lgn);
    			if(i<rev)swap(a[i],a[rev]);
    		}
    		for(int i=1;i<n;i<<=1){
    			int *w=root.data()+(i<<1);
    			for(int j=0;j<n;j+=i<<1)rep(k,0,i-1){
    				int aa=1ll*w[k]*a[i+j+k]%P;
    				a[i+j+k]=norm(a[j+k]+P-aa);
    				add(a[j+k],aa);
    			}
    		}
    		if(d==-1){
    			reverse(a+1,a+n);
    			int inv=nt.inv(n);
    			rep(i,0,n-1)a[i]=1ll*a[i]*inv%P;
    		}
    	}
    }ntt;
    

    二.多项式乘法逆

    阅读参考资料

    首先列一下牛迭的式子:假设 (fin mathbb R[[x]],Ain mathbb R[[x,y]]),满足 (A(x,f)=0),令 (f_0=fmod x^n),则

    [fmod x^{2n}=f_0-frac{A(x,f_0)}{frac{delta A}{delta y}(x,f_0)}mod x^{2n} ]

    如果定义 (operatorname{ord}(f)=min{n|[x^n]f eq 0}),那么我们可以观察到 (operatorname{ord}(A(f_0))ge n),因此计算 (A'(f_0)) 的精度只需达到 (n) 即可。

    下面讨论操作的优化。在以下讨论中,内容分为三部分:

    1. 直接按式子计算的时间。下面定义 (E(n)) 为一次长度为 (n) 的 DFT 所需的时间,(M(n)) 为一次两个精度为 (n) 的形式幂级数的乘法所需要的时间,因此 (M(n)=(3+o(1))E(2n)=(6+o(1))E(n))。在下文中,一切 (o(1)) 会被省略。
    2. 利用循环卷积优化。注意到在大部分情况下,我们已经得到了结果中的一部分系数,而长度为 (n) 的 DFT 解决了循环卷积问题 (fgmod (x^n-1)),仅用于计算卷积很浪费,所以可以考虑先计算循环卷积,必要时进行一些处理,最后得到所需的系数。
    3. 减少 DFT 次数。注意到很多时候计算了相同的 DFT,或者可以用线性变换的性质合并几次 IDFT,考虑减少这些额外的开销。

    在这一部分,我们研究的问题是倒数。

    1

    (fin mathbb R[[x]]),令 (g=1/f),求 (g)

    (A(g)=fg-1),代入牛迭式子可得:

    [gmod x^{2n}=2g_0-fg_0^2mod x^{2n} ]

    这就是我们一般使用的方法。

    这里一共使用了 (1) 次长度 (2n) 的乘法,(1) 次长度 (4n) 的乘法,用时 (M(2n)+M(4n)=18E(n))。(其实我觉得做三次长度为 (2n) 的 DFT 就行了,这样是 (12E(n)) 的,不知道为什么没有这么写)

    2

    (fin mathbb R[[x]]),令 (g=1/f),求 (g)

    考虑 (gmod x^{2n}=g_0-(fg_0-1)g_0mod x^{2n}),显然 (deg((fmod x^2n)g_0-1)<3n,operatorname{ord}((fmod x^2n)g_0-1)ge n),因此只需计算 ((fmod x^{2n})g_0mod (x^{2n}-1)) 即可,同理 ((fg_0-1)g_0mod x^{2n}) 也只需要长为 (2n) 的循环卷积,用时 (12E(n))。所以计算 (gmod x^n) 的时间是 (12E(n))

    3

    观察上述过程,有两次和 (g_0) 有关的长为 (2n) 的循环卷积,可以记录下来而不是重新算,用时 (10E(n))

    EI 的代码应该是按照这个实现的,不过把普通的递归换成了迭代,因此总时间 (0.6s o 0.2s),优化了 (0.4s)

    Code
    struct Newton{
    	void inv(const poly&f,const poly&nttf,poly&g,const poly&nttg,int t){//given f,g,nttf,nttg
    		int n=1<<t;
    		poly prod=nttf;
    		rep(i,0,(n<<1)-1)prod[i]=1ll*prod[i]*nttg[i]%P;
    		ntt.dft(prod.base(),t+1,-1);//calculate fg-1
    		rep(i,0,n-1)prod[i]=0;//prod=
    		ntt.dft(prod.base(),t+1,1);
    		rep(i,0,(n<<1)-1)prod[i]=1ll*prod[i]*nttg[i]%P;
    		ntt.dft(prod.base(),t+1,-1);//calculate (fg-1)g
    		rep(i,0,n-1)prod[i]=0;
    		g=g-prod;//calculate g-(fg-1)g
    	}
    	void inv(const poly&f,const poly&nttf,poly&g,int t){//given f,nttf,g
    		poly nttg=g;
    		nttg.redeg((2<<t)-1);
    		ntt.dft(nttg.base(),t+1,1);//calc nttg
    		inv(f,nttf,g,nttg,t);
    	}
    	void inv(const poly&f,poly&g,int t){//given f,g
    		poly nttg=g;
    		nttg.redeg((2<<t)-1);
    		ntt.dft(nttg.base(),t+1,1);//calc nttg
    		poly nttf=f;
    		nttf.redeg((2<<t)-1);
    		ntt.dft(nttf.base(),t+1,1);//calc nttf
    		inv(f,nttf,g,nttg,t);
    	}
    }nit;
    poly poly::inv()const{
    	poly g=nt.inv(a[0]);
    	for(int t=0;(1<<t)<=deg();++t)nit.inv(slice((2<<t)-1),g,t);
    	g.redeg(deg());
    	return g;
    }
    

    还可以继续优化吗?

    如果允许长度为 (3n) 的 DFT,那么考虑 (gmod x^{2n}=g_0-(fg_0^2-g_0)mod x^{2n}),用长度为 (3n) 的循环卷积计算 (fg_0^2) 即可达到 (9E(n)) 的用时,可惜大部分时候(比如 (998244353))都不能做。

    注意到我们不是必须要循环卷积才能解决问题,对 (ain mathbb R,a^{2n} eq 1),考虑在 (mathbb R[x]/(x^{2n}-1)(x^n-a^n)) 中计算卷积,即在 (1,zeta_{2n},zeta_{2n}^2,dots,zeta_{2n}^{2n-1},a,azeta_n,azeta_n^2,dots,azeta_n^{n-1}) 上多点求值和插值。对 (fin mathbb R[x]/(x^{2n}-1)(x^n-a^n)) 进行多点求值只需用 FFT 计算 (mathcal F_{2n}(f))(mathcal F_{2n}(f(ax))),而插值只需分别还原 CRT 合并。

    容易发现,如果在 (mathbb R[x]/(x^{2n}-1)(x^n-a^n)) 中进行卷积,仍然可以处理超出长度部分的影响,且不需要长度为 (3n) 的 DFT,同时也计算了 (mathcal F_{2n}(fmod x^{2n}))(mathcal F_{2n}(g_0)),所需时间仍是改进前的 (9E(n)=frac{3}{2}M(n)),所以可以几乎完全代替前一种做法。

    简单描述一下思路:为了算出所需结果 (f),先算出 (fmod (x^{2n}−1)(x^n−a^n)),考虑超出部分对前 (n) 项(本应全是 (0))的贡献,利用这些信息还原出这一部分,然后即可把这一部分对所需部分的影响消除掉。

    算出结果需在 (1, zeta_{2n}, zeta_{2n}^2, dots, zeta_{2n}^{2n-1}, a, azeta_n, azeta_n^2, dots, azeta_n^{n-1}) 上多点求值和插值,多点求值即计算 (mathcal F_{2n}(f))(mathcal F_n(f(ax))),插值即分别还原并 CRT 合并。

    实际实现并不需要考虑这些,最后的推荐实现也可以这样描述。考虑将所需结果 (f) 表示为 (ax^n+bx^{2n}+cx^{3n}),其中 (a,b,cin mathbb R[x])(deg(a),deg(b),deg(c)<n),那么可以用循环卷积计算出 (fmod (x^{2n}-1))(fmod (x^n-i)=f(zeta_{4n}x)mod (x^n-1)),也就相当于算出了 (b,a+c,ia-b-ic),还原出 (a) 即可。

    三.多项式 ln

    哈哈这个我熟,直接两边求导再积分,就可以得到 (ln f=int f'/f),然后利用多项式求逆就可以做到 (9E(n)+6E(n)=15E(n)) 了!

    一翻 EI 的代码:wtf 怎么这么长,怎么还有个 quo 函数???

    EI:You are too naive.

    翻了翻博客,发现求商数居然还有科技,学废了。

    1

    商数:对于 (f,hin mathbb R[[x]]),令 (g=1/f,q=hg=h/f),求 (q)

    显然先求 (g=1/f),再求 (q=hg) 即可,总时间是 (18E(n)+6E(n)=24E(n))

    对数:略。

    2

    对于 (f,hin mathbb R[[x]]),令 (g=1/f,q=hg=h/f),求 (q)

    直接求 (g) 然后卷积的总时间是 (18E(n))

    如果不求 (g),注意到 (A(q)=fq-h=0),令 (g_0=gmod x^n,h_0=hmod x^n,q_0=qmod x^n=g_0h_0mod x^n),有

    [qmod x^{2n}=q_0-(fq_0-h)g_0mod x^{2n} ]

    计算 (g_0)(12E(n)),计算 (q_0)(6E(n)),计算 ((fq_0-h)) 和倒数类似需要 (12E(n)) 的时间,因此计算 (qmod x^{2n})(30E(n)),计算 (qmod x^n) 就是 (15E(n))

    3

    对于 (f,hin mathbb R[[x]]),令 (g=1/f,q=hg=h/f),求 (q)

    (g_0=gmod x^n,g_1=(gmod x^{2n}-g_0)/x^n,h_0=hmod x^n,h_1=(hmod x^{2n}-h_0)/x^n)

    如果需要求 (g),考虑计算

    [qmod x^{2n}=(gmod x^{2n})(hmod x^{2n})mod x^{2n}=g_0h_0+(g_0h_1+g_1h_0)x^nmod x^{2n} ]

    • 计算 (mathcal F_{2n}(g_0),g_0,g_1) 需要 (18E(n)) 时间。
    • 计算 (mathcal F_{2n}(g_1),mathcal F_{2n}(h_0),mathcal F_{2n}(h_1)) 需要 (6E(n)) 时间。
    • 计算 (g_0h_0,g_0h_1+g_1h_0) 需要 (4E(n)) 时间。

    总共需要 (28E(n)) 时间计算 (qmod x^{2n}),因此计算 (qmod x^n) 就需要 (14E(n)) 时间。

    这里用到的技巧可以表述为:对于 (f,gin mathbb R[[x]]),已知 (fmod x^n,gmod x^n,mathcal F_n(fmod x^{n/2})),则需要 (5E(n)) 时间计算出 (fgmod x^n)

    如果不求 (g),仍然考虑

    [qmod x^{2n}=q_0-(fq_0-h)g_0mod x^{2n} ]

    与第二部分的做法相比,可以用更快计算倒数的方法,计算 (q_0) 时用的 (mathcal F_{2n}(g_0)) 可以用于计算 ((fq_0-1)g_0),所以计算 (qmod x^{2n}) 所用时间为 (24E(n)),计算 (qmod x^n) 总时间为 (12E(n)=2M(n))

    • 计算 (g_0) ( o) (9E(n))
    • 计算 (mathcal F_{2n}(g_0),q_0) ( o) (6E(n))
    • 计算 (fq_0) ( o) (M(n)=6E(n))
    • 计算 ((fq_0-h)g_0),已知 (mathcal F_{2n}(g_0)) ( o) (4E(n))

    因此计算 (qmod x^{2n}) 总时间是 (25E(n)),计算 (qmod x^n) 总时间是 (12.5E(n))

    观察 (q=g_0h_0mod x^n,(fq_0-h)g_0),可以发现符合上文描述的技巧的使用条件,其中 ((fq_0-h)g_0) 可以视为计算 (((fq_0-h)/x^n)g_0),再考虑到相同的 DFT 只用计算一次,需要 (9E(n)) 时间计算,总时间就是 (12E(n)=2M(n))

    Code
    poly poly::quo(const poly&rhs)const{
    	if(rhs.deg()==0)return 1ll*a[0]*nt.inv(rhs[0])%P;
    	poly g=nt.inv(rhs[0]);
    	int t=0,n;
    	for(n=1;(n<<1)<=rhs.deg();++t,n<<=1)nit.inv(rhs.slice((n<<1)-1),g,t);
    	poly nttg=g;
    	nttg.redeg((n<<1)-1);
    	ntt.dft(nttg.base(),t+1,1);
    	poly eps1=rhs.slice((n<<1)-1);
    	ntt.dft(eps1.base(),t+1,1);
    	rep(i,0,(n<<1)-1)eps1[i]=1ll*eps1[i]*nttg[i]%P;
    	ntt.dft(eps1.base(),t+1,-1);
    	memcpy(eps1.base(),eps1.base()+n,sizeof(int)<<t);
    	memset(eps1.base()+n,0,sizeof(int)<<t);
    	ntt.dft(eps1.base(),t+1,1);
    	poly h0=slice(n-1);
    	h0.redeg((n<<1)-1);
    	ntt.dft(h0.base(),t+1,1);
    	poly h0g0=zeroes((n<<1)-1);
    	rep(i,0,(n<<1)-1)h0g0[i]=1ll*h0[i]*nttg[i]%P;
    	ntt.dft(h0g0.base(),t+1,-1);
    	poly h0eps1=zeroes((n<<1)-1);
    	rep(i,0,(n<<1)-1)h0eps1[i]=1ll*h0[i]*eps1[i]%P;
    	ntt.dft(h0eps1.base(),t+1,-1);
    	rep(i,0,n-1)h0eps1[i]=reduce(operator[](i+n)-h0eps1[i]);
    	memset(h0eps1.base()+n,0,sizeof(int)<<t);
    	ntt.dft(h0eps1.base(),t+1,1);
    	rep(i,0,(n<<1)-1)h0eps1[i]=1ll*h0eps1[i]*nttg[i]%P;
    	ntt.dft(h0eps1.base(),t+1,-1);
    	memcpy(h0eps1.base()+n,h0eps1.base(),sizeof(int)<<t);
    	memset(h0eps1.base(),0,sizeof(int)<<t);
    	return (h0g0+h0eps1).slice(rhs.deg());
    }
    
    </details>
    
  • 相关阅读:
    理解爬虫原理
    中文词频统计与词云生成
    复合数据类型,英文词频统计
    字符串操作、文件操作,英文词频统计预处理
    了解大数据的特点、来源与数据呈现方式
    为Bootstrap模态对话框添加拖拽移动功能
    前端进阶学习笔记
    前端基础学习笔记
    MySQL学习笔记(模块二)
    MySQL学习笔记(模块一)
  • 原文地址:https://www.cnblogs.com/happydef/p/poly4.html
Copyright © 2011-2022 走看看