zoukankan      html  css  js  c++  java
  • FFT入门

    可以先了解下用分治法求多项式的值 链接

    求 $A(x)$ 在 $w_j,j=0,1,2...,n-1$ 处的值:

    设多项式

    $$A(x)=a_0+a_1*x+a_2*{x^2}+a_3*{x^3}+a_4*{x^4}+a_5*{x^5}+ dots+a_{n-2}*x^{n-2}+a_{n-1}*x^{n-1}$$

    按下标的奇偶性分类,设

    $$A_1(x)=a_0+a_2*{x}+a_4*{x^2}+dots+a_{n-2}*x^{frac{n}{2}-1}$$

    $$A_2(x)=a_1+a_3*{x}+a_5*{x^2}+ dots+a_{n-1}*x^{frac{n}{2}-1}$$

    那么不难得到 $A(x)=A_1(x^2)+xA_2(x^2)$.

    我们将 $omega_n^k (k<frac{n}{2})$ 代入得

    $$egin{aligned}
    A(omega_n^k) &=A_1(omega_n^{2k})+omega_n^kA_2(omega_n^{2k}) \
    &=A_1(omega_{frac{n}{2}}^{k})+omega_n^kA_2(omega_{frac{n}{2}}^{k})
    end{aligned}$$

    同理,将 $omega_n^{k+frac{n}{2}}$ 代入得

    $$egin{aligned}
    A(omega_n^{k+frac{n}{2}}) &=A_1(omega_n^{2k+n})+omega_n^{k+frac{n}{2}}(omega_n^{2k+n}) \
    &=A_1(omega_n^{2k}*omega_n^n)-omega_n^kA_2(omega_n^{2k}*omega_n^n) \
    &=A_1(omega_n^{2k})-omega_n^kA_2(omega_n^{2k})
    end{aligned}$$

    两个式子只有符号不同,也就是说,算出在前 $n$ 个点的值就能得到后 $n$ 个点的值,相当于问题规模减半了。

    所以可以递归的实现,直到多项式仅剩一个常数项,这时候我们直接返回就好啦!

    这样时间复杂度为 $O(nlogn)$.

    FFT算法的伪代码:

    1.求值 $A(w_j), B(w_j)$,$j=0,1,..2n-1$

    2. 计算 $C(w_j)$,$j=0,1,...,2n-1$

    3. 构造多项式

    $D(x)=C(w_0) + C(w_1)x+...+C(w_{2n-1})x^{2n-1}$

    4. 计算所有的 $D(w_j)$,$j=0,1,...2n-1$

    5. 利用下式计算 $C(x)$ 的系数 $c_j$

    $D(w_j) = 2nc_{2n-j}, j=1,...,2n-1$

    $D(w_0) = 2nc_0$

    递归版

    fft函数 $type=1$ 进去的时候 $a$ 数组存的是系数,返回时存的是计算出来的值,$a[i]=A(w_i)$;

    $type=-1$ 时相反。

    这里预处理了sin和cos值,大概能快2、3倍。

    // luogu-judger-enable-o2
    #include<iostream>
    #include<cstdio>
    #include<cmath>
    using namespace std;
    const int MAXN = 4 * 1e6 + 10; //4倍空间
    inline int read() {
        char c = getchar(); int x = 0, f = 1;
        while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();}
        while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
        return x * f;
    }
    const double Pi = acos(-1.0);
    double ccos[MAXN], ssin[MAXN];
    struct complex {
        double x, y;
        complex (double xx = 0, double yy = 0) {x = xx, y = yy;}
    } a[MAXN], b[MAXN];
    complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);}
    complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);}
    complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分
    void fast_fast_tle(int limit, complex *a, int type) {
        if (limit == 1) return ; //只有一个常数项
        complex a1[limit >> 1], a2[limit >> 1];
        for (int i = 0; i < limit; i += 2) //根据下标的奇偶性分类
            a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
        fast_fast_tle(limit >> 1, a1, type);
        fast_fast_tle(limit >> 1, a2, type);
        complex Wn = complex(ccos[limit] , type * ssin[limit]), w = complex(1, 0);
        //complex Wn = complex(cos(2.0 * Pi / limit) , type * sin(2.0 * Pi / limit)), w = complex(1, 0);
        //Wn为单位根,w表示幂
        for (int i = 0; i < (limit >> 1); i++, w = w * Wn) //这里的w相当于公式中的k
        {
            complex tmp = w * a2[i];
            a[i] = a1[i] + tmp;
            a[i + (limit >> 1)] = a1[i] - tmp; //利用单位根的性质,O(1)得到另一部分
        }
    }
    int main() {
        int N = read(), M = read();     //N,M是秩最高次幂
        for (int i = 0; i <= N; i++) a[i].x = read();
        for (int i = 0; i <= M; i++) b[i].x = read();
        int limit = 1; while (limit <= N + M) limit <<= 1;
    
        for(int i = 1;i <= limit;i++)
        {
            ccos[i] = cos(2.0 * Pi / i);
            ssin[i] = sin(2.0 * Pi / i);
        }
    
        fast_fast_tle(limit, a, 1);
        fast_fast_tle(limit, b, 1);
        //后面的1表示要进行的变换是什么类型
        //1表示从系数变为点值
        //-1表示从点值变为系数
        //至于为什么这样是对的,可以参考一下c向量的推导过程,
        for (int i = 0; i <= limit; i++)
            a[i] = a[i] * b[i];
        fast_fast_tle(limit, a, -1);
        for (int i = 0; i <= N + M; i++) printf("%d ", (int)(a[i].x / limit + 0.5)); //按照我们推倒的公式,这里还要除以n
        return 0;
    }

    非递归版

    这个很容易发现点什么吧?

    • 每个位置分治后的最终位置为其二进制翻转后得到的位置

    这样的话我们可以先把原序列变换好,把每个数放在最终的位置上,再一步一步向上合并。

    一句话就可以 $O(n)$ 预处理出位置 $i$ 最终的位置 $rev[i]$:

    //原理也很简单,将高bit-1位(也就是i/2)反转,再将第一位补到最高位。

    fo(i,0,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    // luogu-judger-enable-o2
    #include<iostream>
    #include<cstdio>
    #include<cmath>
    using namespace std;
    const int MAXN = 4e6 + 10;  //开4倍空间
    inline int read() {
        char c = getchar(); int x = 0, f = 1;
        while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();}
        while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
        return x * f;
    }
    const double Pi = acos(-1.0);
    struct complex {
        double x, y;
        complex (double xx = 0, double yy = 0) {x = xx, y = yy;}
    } a[MAXN], b[MAXN];
    complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);}
    complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);}
    complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分
    int N, M;
    int bit, r[MAXN];
    int limit = 1;
    void fast_fast_tle(complex *A, int type) {
        for (int i = 0; i < limit; i++)
            if (i < r[i]) swap(A[i], A[r[i]]); //求出要迭代的序列
        for (int mid = 1; mid < limit; mid <<= 1) { //待合并区间的长度的一半
            complex Wn( cos(Pi / mid) , type * sin(Pi / mid) ); //单位根
            for (int R = mid << 1, j = 0; j < limit; j += R) { //R是区间的长度,j表示前已经到哪个位置了
                complex w(1, 0); //
                for (int k = 0; k < mid; k++, w = w * Wn) { //枚举左半部分
                    complex x = A[j + k], y = w * A[j + mid + k]; //蝴蝶效应
                    A[j + k] = x + y;
                    A[j + mid + k] = x - y;
                }
            }
        }
    }
    int main() {
        int N = read(), M = read();
        for (int i = 0; i <= N; i++) a[i].x = read();
        for (int i = 0; i <= M; i++) b[i].x = read();
        while (limit <= N + M) limit <<= 1, bit++;
        for (int i = 0; i < limit; i++)
            r[i] = ( r[i >> 1] >> 1 ) | ( (i & 1) << (bit - 1) ) ;
        fast_fast_tle(a, 1);
        fast_fast_tle(b, 1);
        for (int i = 0; i <= limit; i++) a[i] = a[i] * b[i];
        fast_fast_tle(a, -1);
        for (int i = 0; i <= N + M; i++)
            printf("%d ", (int)(a[i].x / limit + 0.5));
        return 0;
    }

    参考链接:

    (建议直接看大佬的,我都是copy过来的,整理一下思路而已)

    https://www.cnblogs.com/zwfymqz/p/8244902.html

    2. https://blog.csdn.net/enjoy_pascal/article/details/81478582

  • 相关阅读:
    linux安装mysql
    yum命令
    java启动jar包中的指定类
    linux系统配置参数修改
    iconfont阿里巴巴矢量图标库批量保存
    Python 使用Pandas读取Excel的学习笔记
    在Ubuntu18.04的Docker中安装Oracle镜像及简单使用
    Eclipse 安装PyDev开发Python及初步使用
    Python打包工具
    MacOS下打包Python应用
  • 原文地址:https://www.cnblogs.com/lfri/p/11575400.html
Copyright © 2011-2022 走看看