  • Algorithm: Polynomial Multiplication with Fast Fourier Transform / Number Theoretic Transform (English version)

    Intro:

    This blog will start with plain multiplication, go through Divide-and-conquer multiplication, and reach FFT and NTT.

    The aim is to enable the reader (and myself) to fully understand the idea.

    Template problem: Luogu P3803 【模板】多项式乘法(FFT) ([Template] Polynomial Multiplication (FFT))


    Plain multiplication

    Assumption: Two polynomials are \(A(x)=\sum_{i=0}^{n}a_ix^i,B(x)=\sum_{i=0}^{m}b_ix^i\)

    Prerequisite knowledge:

    Knowledge of junior high school mathematics

    The simplest method is to multiply term by term and then combine like terms, written as the formula:

    If \(C(x)=A(x)B(x)\), then \(C(x)=\sum_{i=0}^{n+m}c_ix^i\), where \(c_i=\sum_{j=0}^ia_jb_{i-j}\).

    So we get plain multiplication; see the code below (the \(b\) array is read one coefficient at a time instead of being stored, and the template contains some unused helper macros).

    //This program is written by Brian Peng.
    #pragma GCC optimize("Ofast","inline","no-stack-protector")
    #include<bits/stdc++.h>
    using namespace std;
    #define Rd(a) (a=read())
    #define Gc(a) (a=getchar())
    #define Pc(a) putchar(a)
    int read(){
    	register int x;register char c(getchar());register bool k;
    	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
    	if(c^'-')k=1,x=c&15;else k=x=0;
    	while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
    	return k?x:-x;
    }
    void wr(register int a){
    	if(a<0)Pc('-'),a=-a;
    	if(a<=9)Pc(a|'0');
    	else wr(a/10),Pc((a%10)|'0');
    }
    signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
    long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
    #define Ps Pc(' ')
    #define Pe Pc('\n')
    #define Frn0(i,a,b) for(register int i(a);i<(b);++i)
    #define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
    #define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
    #define Mst(a,b) memset(a,b,sizeof(a))
    #define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
    #define N (2000010)
    int n,m,a[N],b,c[N];
    signed main(){
    	Rd(n),Rd(m);
    	Frn1(i,0,n)Rd(a[i]);
    	Frn1(i,0,m){Rd(b);Frn1(j,0,n)c[i+j]+=b*a[j];}
    	Frn1(i,0,n+m)wr(c[i]),Ps;
    	exit(0);
    }
    

    Time complexity: \(O(nm)\) (if \(m=O(n)\), then \(O(n^2)\))

    Memory complexity: \(O(n)\)

    Results:

    As expected, this is too slow, so we need to optimize it.


    Divide-and-conquer multiplication (Fake)

    P.s This part describes the divide-and-conquer approach behind FFT, which still differs from the actual FFT, so you may skip it if you have already mastered the divide-and-conquer idea.

    Let \(n\) be the smallest positive integer power of \(2\) that is strictly greater than both degrees of \(A(x)\) and \(B(x)\), and write \(A(x)=\sum_{i=0}^{n-1}a_ix^i,B(x)=\sum_{i=0}^{n-1}b_ix^i\), where the missing coefficients are set to \(0\).

    Prerequisite knowledge:

      The idea of Divide-and-conquer

    Now consider how to optimize multiplication.

    Try to separate two polynomials according to the parity of the index of \(x\)

    \(A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2),B(x)=B^{[0]}(x^2)+xB^{[1]}(x^2)\),

    where \(A^{[0]}(x)=\sum_{i=0}^{n/2-1}a_{2i}x^i,A^{[1]}(x)=\sum_{i=0}^{n/2-1}a_{2i+1}x^i\), and \(B^{[0]}(x)\) and \(B^{[1]}(x)\) are similar.

    Therefore, the two polynomials are split into four polynomials, each with degree \(<n/2\).

    We let \(A=A(x),A^{[0]}=A^{[0]}(x^2),A^{[1]}=A^{[1]}(x^2)\), and similar for \(B\) and others,

    then \(AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}\).

    This suggests a divide-and-conquer algorithm: split the two polynomials in half, recursively do \(4\) polynomial multiplications, and finally combine the results (polynomial addition is only \(O(n)\) anyway).

    P.s As \(A^{[0]}=A^{[0]}(x^2)\) and \(A^{[1]}=A^{[1]}(x^2)\), the results have to be interleaved into even and odd positions when combining. Here is the code. (In the code, the \(n\) above is the variable s, and vector is used to save memory.)

    //This program is written by Brian Peng.
    #pragma GCC optimize("Ofast","inline","no-stack-protector")
    #include<bits/stdc++.h>
    using namespace std;
    #define Rd(a) (a=read())
    #define Gc(a) (a=getchar())
    #define Pc(a) putchar(a)
    int read(){
    	register int x;register char c(getchar());register bool k;
    	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
    	if(c^'-')k=1,x=c&15;else k=x=0;
    	while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
    	return k?x:-x;
    }
    void wr(register int a){
    	if(a<0)Pc('-'),a=-a;
    	if(a<=9)Pc(a|'0');
    	else wr(a/10),Pc((a%10)|'0');
    }
    signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
    long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
    #define Ps Pc(' ')
    #define Pe Pc('\n')
    #define Frn0(i,a,b) for(register int i(a);i<(b);++i)
    #define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
    #define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
    #define Mst(a,b) memset(a,b,sizeof(a))
    #define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
    typedef vector<int> Vct;
    int n,m,s; 
    Vct a,b,c;
    void add(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]+b[i];}
    void mlt(Vct&a,Vct&b,Vct&c,int n);
    signed main(){
    	Rd(n),Rd(m),a.resize(s=1<<int(log2(max(n,m))+1)),b.resize(s),c.resize(s<<1);
    	Frn1(i,0,n)Rd(a[i]);
    	Frn1(i,0,m)Rd(b[i]);
    	mlt(a,b,c,s);
    	Frn1(i,0,n+m)wr(c[i]),Ps;
    	exit(0);
    }
    void mlt(Vct&a,Vct&b,Vct&c,int n){
    	int n2(n>>1);
    	Vct a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n),abm(n);
    	if(n==1){c[0]=a[0]*b[0];return;}
    	Frn0(i,0,n2)a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
    	mlt(a0,b0,ab0,n2),mlt(a1,b1,ab1,n2);
    	Frn0(i,0,n)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
    	mlt(a0,b1,ab0,n2),mlt(a1,b0,ab1,n2),add(ab0,ab1,abm);
    	Frn0(i,0,n-1)c[i<<1|1]=abm[i];
    }
    

    Results:

    even worse

    Why's that? Because the time complexity is still \(O(n^2)\).

    \(\textit{Proof. } T(n)=4T(n/2)+f(n)\), where \(f(n)=O(n)\) is the cost of the polynomial additions.

    Using the Master Theorem with \(a=4,b=2,\log_ba=\log_2 4=2>1\), we have \(T(n)=O(n^{\log_ba})=O(n^2)\).

    So, let's continue optimizing


    Divide-and-conquer multiplication (Real)

    Let's consider how to optimize the "fake" one.

    An intro question: Try to find an algorithm to multiply linear expressions \(ax+b\) and \(cx+d\) with only \(3\) multiplication steps.

    Let's expand the product: \((ax+b)(cx+d)=acx^2+(ad+bc)x+bd\); there seem to be \(4\) multiplication steps.

    Hence, if we can only use \(3\) multiplication steps, then \(ad+bc\) should cost only one.

    Let's add all coefficients together: \(ac+ad+bc+bd=(a+b)(c+d)\),

    and here is the answer! Use \(3\) multiplications to calculate \(ac\), \(bd\), and \((a+b)(c+d)\) respectively; the \(x\) coefficient is then \(ad+bc=(a+b)(c+d)-ac-bd\).
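    To see the trick in runnable form, here is a minimal sketch (the helper name mul_linear is mine, not part of the solutions in this post):

    //Minimal sketch: multiply (ax+b)(cx+d) with only 3 multiplications.
    #include<cassert>
    void mul_linear(long long a,long long b,long long c,long long d,
                    long long&x2,long long&x1,long long&x0){
        long long p=a*c;             //coefficient of x^2
        long long q=b*d;             //constant term
        long long r=(a+b)*(c+d);     //the third and last multiplication
        x2=p,x0=q,x1=r-p-q;          //ad+bc recovered by subtraction
    }
    int main(){
        long long x2,x1,x0;
        mul_linear(2,3,4,5,x2,x1,x0);           //(2x+3)(4x+5)=8x^2+22x+15
        assert(x2==8&&x1==22&&x0==15);
    }

    For example, \((2x+3)(4x+5)\) uses only the products \(2\cdot4=8\), \(3\cdot5=15\), and \(5\cdot9=45\), and the middle coefficient is \(45-8-15=22\).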

    Let's go back to the original question

    As \(AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}\),

    we can use the similar method to reduce one multiplication step: \(A^{[1]}B^{[0]}+A^{[0]}B^{[1]}=(A^{[0]}+A^{[1]})(B^{[0]}+B^{[1]})-A^{[0]}B^{[0]}-A^{[1]}B^{[1]}\)

    Here is the code:

    //This program is written by Brian Peng.
    #pragma GCC optimize("Ofast","inline","no-stack-protector")
    #include<bits/stdc++.h>
    using namespace std;
    #define Rd(a) (a=read())
    #define Gc(a) (a=getchar())
    #define Pc(a) putchar(a)
    int read(){
    	register int x;register char c(getchar());register bool k;
    	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
    	if(c^'-')k=1,x=c&15;else k=x=0;
    	while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
    	return k?x:-x;
    }
    void wr(register int a){
    	if(a<0)Pc('-'),a=-a;
    	if(a<=9)Pc(a|'0');
    	else wr(a/10),Pc((a%10)|'0');
    }
    signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
    long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
    #define Ps Pc(' ')
    #define Pe Pc('\n')
    #define Frn0(i,a,b) for(register int i(a);i<(b);++i)
    #define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
    #define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
    #define Mst(a,b) memset(a,b,sizeof(a))
    #define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
    typedef vector<int> Vct;
    int n,m,s;
    Vct a,b,c;
    void add(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]+b[i];}
    void mns(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]-b[i];}
    void mlt(Vct&a,Vct&b,Vct&c);
    signed main(){
    	Rd(n),Rd(m),a.resize(s=1<<int(log2(max(n,m))+1)),b.resize(s),c.resize(s<<1);
    	Frn1(i,0,n)Rd(a[i]);
    	Frn1(i,0,m)Rd(b[i]);
    	mlt(a,b,c);
    	Frn1(i,0,n+m)wr(c[i]),Ps;
    	exit(0);
    }
    void mlt(Vct&a,Vct&b,Vct&c){
    	int n(a.size()),n2(a.size()>>1);
    	Vct a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n),abm(n);
    	if(n==1){c[0]=a[0]*b[0];return;}
    	Frn0(i,0,n2)a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
    	mlt(a0,b0,ab0),mlt(a1,b1,ab1);
    	Frn0(i,0,n)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
    	add(a0,a1,a0),add(b0,b1,b0),mlt(a0,b0,abm),mns(abm,ab0,abm),mns(abm,ab1,abm);
    	Frn0(i,0,n-1)c[i<<1|1]=abm[i];
    }
    

    Results

    Better than the fake DC multiplication, but still worse than plain multiplication...

    Let's calculate the time complexity of this algorithm:

    \(T(n)=3T(n/2)+f(n)\), in which \(f(n)=O(n)\).

    Using Master Theorem with \(a=3,b=2,\log_ba=\log_2 3\approx1.58>1\), so \(T(n)=O(n^{\log_ba})=O(n^{\log_2 3})\).

    Hmm...so why is it even worse than plain multiplication?

    Reason 1. The constant factor of DC multiplication is too high.

    Reason 2. In test case \(\#5\), \(n=1,m=3\cdot 10^6\); since both polynomials are padded to the same power-of-two length, \(O(n^{\log_2 3})\) (with \(n\) the padded length) is actually worse than the plain \(O(nm)\)...

    So, our FFT is eventually coming!


    Fast Fourier Transform

    Fairly Frightening Transform

    Let \(n\) be the smallest positive integer power of \(2\) greater than \(\deg A(x)+\deg B(x)\) and we write \(A(x)=\sum_{i=0}^{n-1}a_ix^i,B(x)=\sum_{i=0}^{n-1}b_ix^i\).

    Prerequisite knowledge:

      The idea of Divide-and-conquer

      Complex number basics

      Linear algebra basics (not strictly required)

    Part 1: Two representations of a polynomial

    1. Coefficient representation

    For a polynomial \(A(x)=\sum_{i=0}^{n-1}a_ix^i\), its coefficient representation is the vector \(\pmb{a}=\left[\begin{matrix}a_0\\a_1\\\vdots\\a_{n-1}\end{matrix} \right]\)

    In the coefficient representation, the time complexities of the following operations are:

    1. Evaluation at a point: \(O(n)\)

    2. Addition: \(O(n)\)

    3. Multiplication: plain \(O(n^2)\), DC \(O(n^{\log_2 3})\)

    P.s When calculating the polynomial multiplication \(C(x)=A(x)B(x)\), the corresponding coefficient vector \(\pmb{c}\) is the convolution of \(\pmb{a}\) and \(\pmb{b}\), written as \(\pmb{c}=\pmb{a}\otimes\pmb{b}\).

    2. Point-value representation

    The point-value representation of a polynomial \(A(x)\) with \(\deg A<n\) is a set of \(n\) points: \(\{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\}\)

    We can use \(n\) evaluations to convert a coefficient representation to a point-value representation at a chosen list of points \(x_0,x_1,\cdots,x_{n-1}\) in \(O(n^2)\) time, as shown:

    \(\left[\begin{matrix}1&x_0&x_0^2&\cdots&x_0^{n-1}\\1&x_1&x_1^2&\cdots&x_1^{n-1}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&x_{n-1}&x_{n-1}^2&\cdots&x_{n-1}^{n-1}\end{matrix} \right]\left[\begin{matrix}a_0\\a_1\\\vdots\\a_{n-1}\end{matrix} \right]=\left[\begin{matrix}y_0\\y_1\\\vdots\\y_{n-1}\end{matrix} \right]\)

    The matrix is written as \(V(x_0,x_1,\cdots,x_{n-1})\), named Vandermonde matrix, so the formula is simplified to \(V(x_0,x_1,\cdots,x_{n-1})\pmb{a}=\pmb{y}\).

    Using the Lagrange interpolation formula, a point-value representation can be converted back into a coefficient representation in \(O(n^2)\) time, a process called interpolation.
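    For reference, the Lagrange interpolation formula (a standard identity, stated here only for completeness) is

    \(A(x)=\sum_{i=0}^{n-1}y_i\prod_{j\neq i}\frac{x-x_j}{x_i-x_j},\)

    and expanding it term by term gives the \(O(n^2)\) bound above.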

    Given two polynomials in point-value representation at the same list of points \(x_0,\cdots,x_{n-1}\), the time complexities of the following operations are:

    1. Addition: \(O(n)\) (add the corresponding \(y_i\) values)

    2. Multiplication: \(O(n)\) (multiply the corresponding \(y_i\) values)

    This is the central idea of FFT-based polynomial multiplication: with carefully chosen \(x_i\) values, we can do the evaluation in \(O(n\log n)\), the multiplication in \(O(n)\), and the final interpolation in \(O(n\log n)\).

    So what are those \(x_i\) values?

    Part 2: Complex roots of unity

    The \(n\)-th roots of unity are exactly the \(n\) complex numbers \(\omega\) satisfying \(\omega^n=1\), written as:

    \(\omega_n^k=e^{2\pi ik/n}=\cos(2\pi k/n)+i\sin(2\pi k/n)\).

    We can plot the \(n\)-th roots of unity as the \(n\) vertices of a regular \(n\)-gon inscribed in the unit circle of the complex plane. For example, the \(8\)-th roots of unity form a regular octagon with one vertex at \(1\) (the snippet below prints them).
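    As a quick illustration (a toy sketch, not part of the solution code), std::polar generates them directly:

    #include<complex>
    #include<cstdio>
    #include<cmath>
    int main(){
        const int n(8);
        const double Pi(acos(-1));
        for(int k=0;k<n;++k){
            //omega_n^k = e^{2*pi*i*k/n}
            std::complex<double> w(std::polar(1.0,2*Pi*k/n));
            printf("w_%d^%d = %.3f%+.3fi\n",n,k,w.real(),w.imag());
        }
    }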

    There is a pattern: \(\omega_n^j\omega_n^k=\omega_n^{j+k}=\omega_n^{(j+k)\mod n}\). Specifically, \(\omega_n^{-1}=\omega_n^{n-1}\).

    Three other important lemmas.

    \(\text{Lemma 1. }\) For all integers \(n\geqslant 0,k\geqslant 0,d>0\), we have \(\omega_{dn}^{dk}=\omega_n^k\).

    \(\textit{Proof. }\omega_{dn}^{dk}=(e^{2\pi i/dn})^{dk}=(e^{2\pi i/n})^k=\omega_n^k.\square\)

    \(\text{Lemma 2. }\) For any even number \(n\) and any integer \(k\), we have \((\omega_n^k)^2=(\omega_n^{k+n/2})^2=\omega_{n/2}^k\).

    \(\textit{Proof. }(\omega_n^k)^2=\omega_n^{2k},(\omega_n^{k+n/2})^2=\omega_n^{2k+n}=\omega_n^{2k}\). Lastly, \(\omega_n^{2k}=\omega_{n/2}^k\) by \(\text{Lemma 1}.\square\)

    \(\text{Lemma 3. }\) For all integers \(n,k\geqslant 0\) such that \(n\nmid k\), we have \(\sum_{j=0}^{n-1}(\omega_n^k)^j=0\).

    \(\textit{Proof. }\) When \(n\nmid k\), we have \(\omega_n^k\neq 1\), so \(\sum_{j=0}^{n-1}(\omega_n^k)^j=\frac{1-(\omega_n^k)^n}{1-\omega_n^k}=\frac{1-\omega_n^{nk}}{1-\omega_n^k}=\frac{1-1}{1-\omega_n^k}=0.\square\) (Question: why is \(n\nmid k\) necessary?)

    The above properties of roots of unity are the essence of FFT optimization.

    Part 3: Discrete Fourier Transform

    Recall the definition of \(n\), which is a power of \(2\). The DFT is just the evaluation of \(A(x)\), given in coefficient representation, at the \(n\)-th roots of unity. We write the Vandermonde matrix as

    \(V_n=V(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1})=\left[\begin{matrix}1&1&1&1&\cdots&1\\1&\omega_n&\omega_n^2&\omega_n^3&\cdots&\omega_n^{n-1}\\1&\omega_n^2&\omega_n^4&\omega_n^6&\cdots&\omega_n^{2(n-1)}\\1&\omega_n^3&\omega_n^6&\omega_n^9&\cdots&\omega_n^{3(n-1)}\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\1&\omega_n^{n-1}&\omega_n^{2(n-1)}&\omega_n^{3(n-1)}&\cdots&\omega_n^{(n-1)(n-1)}\end{matrix} \right]\),

    then the formula of DFT is \(\pmb{y}=\text{DFT}_n(\pmb a)\): \(V_n\pmb{a}=\pmb{y}\). Specifically, \(y_i=\sum_{j=0}^{n-1}[V_n]_{ij}a_j=\sum_{j=0}^{n-1}\omega_n^{ij}a_j\).
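    To make the definition concrete, here is a direct \(O(n^2)\) evaluation of this matrix-vector product (a sketch for illustration only, with names of my own choosing; the whole point of FFT is to avoid computing it this way):

    #include<complex>
    #include<vector>
    #include<cmath>
    typedef std::complex<double> Cpx;
    //Naive DFT straight from the definition: y_i = sum_j omega_n^{ij} a_j, O(n^2).
    std::vector<Cpx> dft_naive(const std::vector<Cpx>&a){
        int n(a.size());
        const double Pi(acos(-1));
        std::vector<Cpx> y(n);
        for(int i=0;i<n;++i)
            for(int j=0;j<n;++j)
                y[i]+=std::polar(1.0,2*Pi*(1LL*i*j%n)/n)*a[j];
        return y;
    }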

    So, how can we achieve it in \(O(n\log n)\)?

    Part 4: FFT

    Like DC multiplication, we split the polynomial by parity: \(A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2)\), where \(A^{[0]}(x)=\sum_{i=0}^{n/2-1}a_{2i}x^i,A^{[1]}(x)=\sum_{i=0}^{n/2-1}a_{2i+1}x^i\).

    Then, our evaluation of \(A(x)\) on \(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}\) becomes

    1. Divide-and-conquer: evaluating \(A^{[0]}(x)\) and \(A^{[1]}(x)\) on \((\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2\).

    By \(\text{Lemma 2}\), the list \((\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2\) is exactly the list of \(n/2\)-th roots of unity, each appearing twice. (Why?)

    So we can recursively compute \(\text{DFT}_{n/2}(\pmb a^{[0]})=\pmb y^{[0]}\) and \(\text{DFT}_{n/2}(\pmb a^{[1]})=\pmb y^{[1]}\). The second step is

    2. Combining the answers.

    As \(\omega_n^{n/2}=e^{2\pi i (n/2)/n}=e^{\pi i}=-1\) (The beautiful Euler's formula!),

    we have \(\omega_n^{k+n/2}=\omega_n^k\omega_n^{n/2}=-\omega_n^k\),

    so \(y_i=y^{[0]}_i+\omega_n^i y^{[1]}_i,y_{i+n/2}=y^{[0]}_i-\omega_n^i y^{[1]}_i,\) for all \(i=0,1,\cdots,n/2-1\).

    Specifically, when \(n=1\), the transform is trivial: \(y_0=\omega_1^0 a_0=a_0\).
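    Putting the two steps together, here is a direct transcription into code (a readable sketch with my own naming; the actual solution in Part 6 below reuses the array in place instead of allocating new vectors):

    #include<complex>
    #include<vector>
    #include<cmath>
    typedef std::complex<double> Cpx;
    //Recursive FFT exactly as derived above: split by parity, recurse, combine.
    std::vector<Cpx> fft_sketch(const std::vector<Cpx>&a){
        int n(a.size());                       //n is assumed to be a power of 2
        if(n==1)return a;                      //trivial case: y_0 = a_0
        std::vector<Cpx> a0(n/2),a1(n/2);
        for(int i=0;i<n/2;++i)a0[i]=a[2*i],a1[i]=a[2*i+1];
        std::vector<Cpx> y0(fft_sketch(a0)),y1(fft_sketch(a1)),y(n);
        const double Pi(acos(-1));
        for(int i=0;i<n/2;++i){
            Cpx w(std::polar(1.0,2*Pi*i/n));   //w = omega_n^i
            y[i]=y0[i]+w*y1[i];                //y_i
            y[i+n/2]=y0[i]-w*y1[i];            //y_{i+n/2}
        }
        return y;
    }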

    Let's calculate the time complexity

    \(T(n)=2T(n/2)+f(n)\), in which \(f(n)=O(n)\) is the time used for combination.

    Using Master Theorem with \(a=2,b=2,\log_ba=\log_2 2=1\), we have \(T(n)=O(n^{\log_ba}\log n)=O(n\log n)\). Whooo!

    Part 5: Inverse DFT

    Don't celebrate too soon, there is still interpolation. Awww

    Since \(\pmb{y}=\text{DFT}_n(\pmb{a})=V_n\pmb{a}\), we have \(\pmb{a}=V_n^{-1}\pmb{y}\), written as \(\pmb{a}=\text{DFT}_n^{-1}(\pmb{y})\).

    \(\text{Theorem. }\) For all \(i,j=0,1,\cdots,n-1\), we have \([V_n^{-1}]_{ij}=\omega_n^{-ij}/n\).

    \(\textit{Proof. }\) We show that \(V_n^{-1}V_n=I_n\), the identity matrix:

    \([V_n^{-1}V_n]_{ij}=\sum_{k=0}^{n-1}(\omega_n^{-ik}/n)\omega_n^{kj}=\frac{\sum_{k=0}^{n-1}\omega_n^{-ik}\omega_n^{kj}}{n}=\frac{\sum_{k=0}^{n-1}\omega_n^{(j-i)k}}{n}\)

    If \(i=j\), then the sum is \(\frac{\sum_{k=0}^{n-1}\omega_n^0}{n}=n/n=1\). Otherwise, it is \(0/n=0\) by \(\text{Lemma 3}\). Therefore, the product is \(I_n\). \(\square\)

    Next, \(\pmb{a}=\text{DFT}_n^{-1}(\pmb{y})=V_n^{-1}\pmb{y}\), in which \(a_i=\sum_{j=0}^{n-1}[V_n^{-1}]_{ij}y_j=\sum_{j=0}^{n-1}(\omega_n^{-ij}/n)y_j=\frac{\sum_{j=0}^{n-1}\omega_n^{-ij}y_j}{n}\).

    Let's compare: in DFT, \(y_i=\sum_{j=0}^{n-1}\omega_n^{ij}a_j\).

    Therefore, we can convert DFT to IDFT by simply replacing \(\omega_n^k\) with \(\omega_n^{-k}\) and dividing the final answers by \(n\).
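    In code, one convenient way to realize this (a sketch reusing fft_sketch from the sketch above; the actual solution below instead flips the sign of the sine via its iv flag) is to conjugate the input, run the forward transform, conjugate again, and divide by \(n\):

    #include<complex>
    #include<vector>
    typedef std::complex<double> Cpx;
    std::vector<Cpx> fft_sketch(const std::vector<Cpx>&a);   //forward DFT from the earlier sketch
    //IDFT sketch: conjugation turns omega_n^k into omega_n^{-k}.
    std::vector<Cpx> idft_sketch(std::vector<Cpx> y){
        int n(y.size());
        for(Cpx&v:y)v=std::conj(v);
        std::vector<Cpx> a(fft_sketch(y));
        for(Cpx&v:a)v=std::conj(v)/double(n);
        return a;
    }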

    Part 6: Recursive Implementation

    According to the previous text, we just need to modify the code of DC multiplication.

    To save memory, we redistribute the coefficients of \(A^{[0]}\) to the left half and those of \(A^{[1]}\) to the right half.

    In the code, o \(=\omega_n\) and w \(=\omega_n^i\).

    P.s Don't forget the division by \(n\) in the IDFT. In the code, the +0.5 rounds the real part to the nearest integer, which improves accuracy for integer-coefficient FFT.

    //This program is written by Brian Peng.
    #pragma GCC optimize("Ofast","inline","no-stack-protector")
    #include<bits/stdc++.h>
    using namespace std;
    #define Rd(a) (a=read())
    #define Gc(a) (a=getchar())
    #define Pc(a) putchar(a)
    int read(){
    	register int u;register char c(getchar());register bool k;
    	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
    	if(c^'-')k=1,u=c&15;else k=u=0;
    	while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
    	return k?u:-u;
    }
    void wr(register int a){
    	if(a<0)Pc('-'),a=-a;
    	if(a<=9)Pc(a|'0');
    	else wr(a/10),Pc((a%10)|'0');
    }
    signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
    long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
    #define Ps Pc(' ')
    #define Pe Pc('\n')
    #define Frn0(i,a,b) for(register int i(a);i<(b);++i)
    #define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
    #define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
    #define Mst(a,b) memset(a,b,sizeof(a))
    #define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
    double const Pi(acos(-1));
    typedef complex<double> Cpx;
    #define N (2100000)
    Cpx o,w,a[N],b[N],tmp[N],x,y;
    int n,m,s;
    bool iv;
    void fft(Cpx*a,int n);
    signed main(){
    	Rd(n),Rd(m),s=1<<int(log2(n+m)+1);
    	Frn1(i,0,n)Rd(a[i]);
    	Frn1(i,0,m)Rd(b[i]);
    	fft(a,s),fft(b,s);
    	Frn0(i,0,s)a[i]*=b[i];
    	iv=1,fft(a,s);
    	Frn1(i,0,n+m)wr(a[i].real()/s+0.5),Ps;
    	exit(0);
    }
    void fft(Cpx*a,int n){
    	if(n==1)return;
    	int n2(n>>1);
    	Frn0(i,0,n2)tmp[i]=a[i<<1],tmp[i+n2]=a[i<<1|1];
    	copy(tmp,tmp+n,a),fft(a,n2),fft(a+n2,n2);
    	o={cos(Pi/n2),(iv?-1:1)*sin(Pi/n2)},w=1;
    	Frn0(i,0,n2)x=a[i],y=w*a[i+n2],a[i]=x+y,a[i+n2]=x-y,w*=o;
    }
    

    Time complexity: \(O(n\log n)\)

    Memory complexity: \(O(n)\)

    Results:

    Not fully AC, as the recursive implementation is not fast enough.

    Part 7: Iterative Implementation

    For \(n=\deg A+1,m=\deg B+1\), let \(l=\lceil\log_2(n+m+1)\rceil\) and \(s=2^l\); this \(s\) plays the role of the \(n\) in the previous parts.

    Similarly, we redistribute the coefficients of \(A^{[0]}\) to the left and \(A^{[1]}\) to the right.

    Observe the pattern of redistribution in each layer of recursion. Take \(s=8\) as an example:

    0-> 0 1 2 3 4 5 6 7
    1-> 0 2 4 6|1 3 5 7
    2-> 0 4|2 6|1 5|3 7
    end 0|4|2|6|1|5|3|7
    

    Still confused? Write them in base-2:

    0-> 000 001 010 011 100 101 110 111
    1-> 000 010 100 110|001 011 101 111
    2-> 000 100|010 110|001 101|011 111
    end 000|100|010|110|001|101|011|111
    

    The base-2 representations are bit-reversed in the last layer!

    A hint for the proof: each redistribution step splits by parity, i.e., by the lowest binary digit, and that digit becomes the highest digit of the position in the next layer.

    In the code, we use the array \(r_{0..s-1}\) to store these bit-reversed indices, built by the recurrence shown below.
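    The recurrence builds \(r_i\) from \(r_{\lfloor i/2\rfloor}\): the reversal of \(i\) without its lowest bit is already known, so shift it right by one and put the dropped bit at the top. A standalone toy sketch with \(s=8\):

    #include<cstdio>
    int main(){
        int l(3),s(1<<l),r[8]={0};
        //r[i] = reversal of the l-bit representation of i
        for(int i=1;i<s;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
        for(int i=0;i<s;++i)printf("%d -> %d\n",i,r[i]);   //prints 0 4 2 6 1 5 3 7
    }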

    Butterfly Operation

    It is already present in the recursive implementation, but let's spell it out:

    Still remember \(y_i=y^{[0]}_i+\omega_n^i y^{[1]}_i,y_{i+n/2}=y^{[0]}_i-\omega_n^i y^{[1]}_i,i=0,1,\cdots,n/2-1\)?

    To save memory, we do not create a separate array \(\pmb y\); the combination is done in place in the array \(\pmb a\).

    After redistribution, we have \(a^{[0]}_i=a_i\) and \(a^{[1]}_i=a_{i+n/2}\).

    Let \(x=a^{[0]}_i=a_i,y=\omega_n^i a^{[1]}_i=\omega_n^i a_{i+n/2}\),

    then the result of DFT is simply \(a_i=x+y,a_{i+n/2}=x-y\)!

    With Butterfly Operation, we just need to redistribute the coefficients according to \(r\), and then combine iteratively to implement FFT.

    //This program is written by Brian Peng.
    #pragma GCC optimize("Ofast","inline","no-stack-protector")
    #include<bits/stdc++.h>
    using namespace std;
    #define Rd(a) (a=read())
    #define Gc(a) (a=getchar())
    #define Pc(a) putchar(a)
    int read(){
    	register int u;register char c(getchar());register bool k;
    	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
    	if(c^'-')k=1,u=c&15;else k=u=0;
    	while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
    	return k?u:-u;
    }
    void wr(register int a){
    	if(a<0)Pc('-'),a=-a;
    	if(a<=9)Pc(a|'0');
    	else wr(a/10),Pc((a%10)|'0');
    }
    signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
    long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
    #define Ps Pc(' ')
    #define Pe Pc('\n')
    #define Frn0(i,a,b) for(register int i(a);i<(b);++i)
    #define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
    #define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
    #define Mst(a,b) memset(a,b,sizeof(a))
    #define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
    double const Pi(acos(-1));
    typedef complex<double> Cpx;
    #define N (2100000)
    Cpx a[N],b[N],o,w,x,y;
    int n,m,l,s,r[N];
    void fft(Cpx*a,bool iv);
    signed main(){
    	Rd(n),Rd(m),s=1<<(l=log2(n+m)+1);
    	Frn1(i,0,n)Rd(a[i]);
    	Frn1(i,0,m)Rd(b[i]);
    	Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    	fft(a,0),fft(b,0);
    	Frn0(i,0,s)a[i]*=b[i];
    	fft(a,1);
    	Frn1(i,0,n+m)wr(a[i].real()+0.5),Ps;
    	exit(0);
    }
    void fft(Cpx*a,bool iv){
    	Frn0(i,0,s)if(i<r[i])swap(a[i],a[r[i]]);
    	for(int i(2),i2(1);i<=s;i2=i,i<<=1){
    		o={cos(Pi/i2),(iv?-1:1)*sin(Pi/i2)};
    		for(int j(0);j<s;j+=i){
    			w=1;
    			Frn0(k,0,i2){
    				x=a[j+k],y=w*a[j+k+i2];
    				a[j+k]=x+y,a[j+k+i2]=x-y,w*=o;
    			}
    		}
    	}
    	if(iv)Frn0(i,0,s)a[i]/=s;
    }
    

    Time complexity: \(O(n\log n)\)

    Memory complexity: \(O(n)\)

    Results:

    Celebrate


    Extension: Number Theoretic Transform

    Although FFT has an excellent time complexity, floating-point inaccuracy inevitably arises because complex numbers are used.

    If the polynomial coefficients and results are non-negative integers in a certain range, NTT is a better choice in terms of both accuracy and speed.

    Prerequisite knowledge:

      FFT absolutely

      Modular arithmetic basics

    Primitive roots

    Assume that the following calculations are in the context of \(\bmod P\), where \(P\) is a prime number.

    For a positive integer \(g\), if the powers of \(g\) (taken \(\bmod P\)) run through every positive integer \(<P\), then we call \(g\) a primitive root \(\bmod P\). (Digression: in group theory, the residue class of \(g\) generates the multiplicative group \(\mathbb{Z}_P^*\).)

    E.g. For \(P=7\), we compute, for every positive integer \(<P\), the set of values attained by its powers \(\bmod 7\):

    1-> {1}
    2-> {1,2,4}
    3-> {1,2,3,4,5,6}
    4-> {1,2,4}
    5-> {1,2,3,4,5,6}
    6-> {1,6}
    

    Therefore, \(3,5\) are the primitive roots \(\bmod 7\).
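    A brute-force check of the table above (a toy sketch, not part of the solution code):

    #include<cstdio>
    #include<set>
    int main(){
        const int P(7);
        for(int g=1;g<P;++g){
            std::set<int> seen;
            int x(1);
            for(int e=0;e<P-1;++e)seen.insert(x),x=x*g%P;   //powers g^0..g^{P-2} mod P
            printf("%d -> reaches %zu residues%s\n",g,seen.size(),
                   seen.size()==P-1?" (primitive root)":"");
        }
    }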

    In the code, we commonly use \(P=998244353,g=3\).

    The special property of primitive root \(g\) is that its powers repeat with period \(P-1\).

    E.g Let \(P=7,g=3\), then the powers of \(g\) (beginning with \(g^0\)) are:\(1,3,2,6,4,5,1,3,2,6,4,5,\cdots\).

    This property is very similar to the roots of unity. If we take \(n=P-1\) and \(\omega_n=g\), then all three lemmas in the FFT part are satisfied.

    However, to complete NTT, there is one last step.

    The substitute for roots of unity

    In FFT, we use \(n\)-th roots of unity, where \(n\) is a power of \(2\).

    However, \(P-1\) is not necessarily \(n\). Hence, we cannot directly replace \(\omega_n\) with \(g\).

    Now, as the powers of \(g\) have a period of \(P-1\),

    if we take a factor \(k\) of \(P-1\), then the powers of \(g^k\) have a period of \(\frac{P-1}{k}\). (Why?)

    This means that if we take \(k=\frac{P-1}{n}\), then the powers of \(g^k\) have a period of exactly \(n\).

    But, how can we be sure that \(n\) is always a factor of \(P-1\)?

    This is why we choose \(P=998244353\), as \(P-1=998244352=2^{23}\cdot 7\cdot 17\), with a high multiplicity of \(2\).

    Therefore, \(g^{\frac{P-1}{n}}\) is just our substitute of \(\omega_n\).
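    A quick sanity check (a sketch with its own small fast-power helper, separate from the fpw in the solution below) that \(g^{\frac{P-1}{n}}\) really behaves like \(\omega_n\):

    #include<cstdio>
    const long long P(998244353),G(3);
    long long pw(long long a,long long p){          //plain iterative fast power mod P
        long long r(1);
        for(;p;p>>=1,a=a*a%P)if(p&1)r=r*a%P;
        return r;
    }
    int main(){
        int n(1<<20);                               //any power of 2 up to 2^23 works
        long long w(pw(G,(P-1)/n));                 //the substitute for omega_n
        printf("w^n = %lld\n",pw(w,n));             //prints 1
        printf("w^(n/2) = %lld\n",pw(w,n/2));       //prints 998244352, i.e. -1 mod P
    }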

    In the code, we use \(g^{-1}=332748118\) in place of \(g\) for the IDFT, and multiply by \(s^{-1}\) (computed as \(s^{P-2}\bmod P\) by Fermat's little theorem) at the end. Make sure that you include \(\bmod P\) in every operation.

    //This program is written by Brian Peng.
    #pragma GCC optimize("Ofast","inline","no-stack-protector")
    #include<bits/stdc++.h>
    using namespace std;
    #define int long long
    #define Rd(a) (a=read())
    #define Gc(a) (a=getchar())
    #define Pc(a) putchar(a)
    int read(){
    	register int u;register char c(getchar());register bool k;
    	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
    	if(c^'-')k=1,u=c&15;else k=u=0;
    	while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
    	return k?u:-u;
    }
    void wr(register int a){
    	if(a<0)Pc('-'),a=-a;
    	if(a<=9)Pc(a|'0');
    	else wr(a/10),Pc((a%10)|'0');
    }
    signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
    long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
    #define Ps Pc(' ')
    #define Pe Pc('\n')
    #define Frn0(i,a,b) for(register int i(a);i<(b);++i)
    #define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
    #define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
    #define Mst(a,b) memset(a,b,sizeof(a))
    #define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
    #define P (998244353)
    #define G (3)
    #define Gi (332748118)
    #define N (2100000)
    int n,m,l,s,r[N],a[N],b[N],o,w,x,y,siv;
    int fpw(int a,int p){return p?a>>1?(p&1?a:1)*fpw(a*a%P,p>>1)%P:a:1;}
    void ntt(int*a,bool iv);
    signed main(){
    	Rd(n),Rd(m),siv=fpw(s=1<<(l=log2(n+m)+1),P-2);
    	Frn1(i,0,n)Rd(a[i]);
    	Frn1(i,0,m)Rd(b[i]);
    	Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    	ntt(a,0),ntt(b,0);
    	Frn0(i,0,s)a[i]=a[i]*b[i]%P;
    	ntt(a,1);
    	Frn1(i,0,n+m)wr(a[i]),Ps;
    	exit(0);
    }
    void ntt(int*a,bool iv){
    	Frn0(i,0,s)if(i<r[i])swap(a[i],a[r[i]]);
    	for(int i(2),i2(1);i<=s;i2=i,i<<=1){
    		o=fpw(iv?Gi:G,(P-1)/i);
    		for(int j(0);j<s;j+=i){
    			w=1;
    			Frn0(k,0,i2){
    				x=a[j+k],y=w*a[j+k+i2]%P;
    				a[j+k]=(x+y)%P,a[j+k+i2]=(x-y+P)%P,w=w*o%P;
    			}
    		}
    	}
    	if(iv)Frn0(i,0,s)a[i]=a[i]*siv%P;
    }
    

    Time complexity: \(O(n\log n)\)

    Memory complexity: \(O(n)\)

    Results

    No significant improvement in time, but the memory cost is halved, as 64-bit integers are used instead of complex<double>.


    The End:

    Translating is sooooo time-consuming...

    Another year with Cnblogs! Happy new year!

    Thanks for your support! ありがとう! (Thank you!)


    Reference:

    Introduction to Algorithms

    自为风月马前卒: 快速傅里叶变换(FFT)详解 (a detailed explanation of the Fast Fourier Transform (FFT))

    自为风月马前卒: 快速数论变换(NTT)小结 (a summary of the Number Theoretic Transform (NTT))

  • Original article: https://www.cnblogs.com/BrianPeng/p/15761230.html