zoukankan      html  css  js  c++  java
  • P3803 【模板】多项式乘法(FFT)

    (color{#0066ff}{题目描述})

    给定一个n次多项式F(x),和一个m次多项式G(x)。

    请求出F(x)和G(x)的卷积。

    (color{#0066ff}{输入格式})

    第一行2个正整数n,m。

    接下来一行n+1个数字,从低到高表示F(x)的系数。

    接下来一行m+1个数字,从低到高表示G(x)的系数

    (color{#0066ff}{输出格式})

    一行n+m+1个数字,从低到高表示F(x)∗G(x)的系数。

    (color{#0066ff}{输入样例})

    1 2
    1 2
    1 2 1
    

    (color{#0066ff}{输出样例})

    1 4 5 2
    

    (color{#0066ff}{数据范围与提示})

    保证输入中的系数大于等于 0 且小于等于9。

    对于100%的数据:(n, m leq {10}^6) , 共计20个数据点,2s。

    数据有一定梯度。

    空间限制:256MB

    (color{#0066ff}{题解})

    对于两个多项式相乘,显然暴力是(O(n^2))

    如何优化呢?

    我们知道,n+1个点可以唯一确定一个n次多项式

    那么我们对于(A(x)*B(x)=C(x))

    可以拆成这样

    (left{egin{matrix}(a_1,b_1) \ (a_2,b_2) \(a_3,b_3) \ . \ . \ . \(a_{n+1},b_{n+1)}end{matrix} ight} * left{egin{matrix}(c_1,d_1) \ (c_2,d_2) \(c_3,d_3) \ . \ . \ . \(c_{n+1},d_{n+1)}end{matrix} ight} = left{egin{matrix}(a_1*c_1,b_1*d_1) \ (a_2*c_2,b_2*d_2) \(a_3*c_3,b_3*d_3) \ . \ . \ . \(a_{n+1}*c_{n+1},b_{n+1}*d_{n+1})end{matrix} ight})

    上面这个点值表达式的运算显然是(O(n))

    我们现在要把A和B转成点值表达式,然后乘过去,最后再转换回来

    现在考虑转成点值表达式

    比如(A(x)=a_0+a_1x+a_2x^2)

    那么点值表达式就是

    ((x_1,a_0+a_1x_1+a_2x_1^2)(x_2,a_0+a_1x_2+a_2x_2^2)(x_3,x_0+a_1x_3+a_2x_3^2))

    但是带入求y用秦九韶是(O(n))的,而且要带入n个点,所以就(O(n^2))

    现在考虑优化这个过程

    引入一个东西,单位复数根

    定义(omega_n^n=1,omega_n)有n个

    对于复数坐标系的两个点(可以理解为两个向量)

    两个单位向量A,B,模长为(a=b=1),角度为(alpha,eta)

    欧拉公式:(e^{ix}=cos x+i*sin x)

    所以二者相乘即为(ae^{ialpha}*be^{ieta}=abe^{i(alpha+eta)}=Ae^{iTheta})

    也就是说,复数相乘集合意义,模长相乘,极角相加

    因此,(w_n^n=1),则(nTheta\%2pi=0)

    (Theta=frac{2pi}{n})

    定义(omega_0=(1,0),omega_1=omega_n^1)

    (omega_n^k=omega_0*omega_1^k)

    而且有(omega_n^{k+frac{n}{2}}=-omega_n^k),因为关于原点对称

    还有(omega_{2n}^{2k}=omega_n^k),n份取k份等同于2n份取2k份

    回到多项式(A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1})

    我们硬凑一下,使得(A(x))共有(2^k)次方项(不足后面补系数0)

    然后开始分治,把(A(x))拆成(A_0(x))(A_1(x))

    (A_0(x)=a_0+a_2x+a_4x^2+...+a_{n-2}x^{frac{n-2}{2}})

    (A_1(x)=a_1+a_3x+a_5x^2+...+a_{n-1}x^{frac{n-2}{2}})

    不难发现(A(x)=A_0(x^2)+x*A_1(x^2))

    我们把(omega_n^k)带入

    (A(omega_n^k)=A_0(omega_n^{2k})+omega_n^k*A_1(omega_n^{2k})=A_0(omega_{frac{n}{2}}^k)+omega_n^k*A_1(omega_{frac{n}{2}}^k))

    (A(omega_n^{k+frac{n}{2}})=A_0(-omega_n^{k})-omega_n^k*A_1(-omega_n^{k})=A_0(omega_{frac{n}{2}}^k)-omega_n^k*A_1(omega_{frac{n}{2}}^k))

    两式子差距只在正负号!

    所以求出了(A_0,A_1),相加即为第一个,相减即为第二个

    也就是说,对于当前序列,我们求出了一半,另一半也出来了

    那么我们去找一找对应关系

    ({a_0} ||||| {a_1} ||||| {a_2} ||||| {a_3} ||||| {a_4} ||||| {a_5} ||||| {a_6} ||||| {a_7})

    ({000} ||| {001} ||| {010} ||| {011} ||| {100} ||| {101} ||| {110} ||| {111})

    (|||||||||||{a_0,a_2,a_4,a_6} ||||||||||||||||||||||||| {a_1,a_3,a_5,a_7}||||||||||||)

    (||||||{a_0,a_4} |||||||||| {a_2,a_6} ||||||||||||| {a_1,a_5} |||||||||| {a_3,a_7} |||||)

    ({a_0} ||||| {a_4} ||||| {a_2} ||||| {a_6} ||||| {a_1} ||||| {a_5} ||||| {a_3} ||||| {a_7})

    ({000} ||| {100} ||| {010} ||| {110} ||| {001} ||| {101} ||| {011} ||| {111})

    卧槽,这是二进制翻转啊

    我们用(r_i)代表数i翻转后是几

    (r[i]=r[i>>1]>>1|(i&1)*(len<<1))

    要求当前位置的翻转后的数,把当前最后一位删去

    那么当前的数的r已经求出来了

    将这个数翻转(注意,二进制位数固定,比如长度为8,那么(00000001 o 10000000)),再<<1,

    这样现在的数除了最高位其他都是当前的翻转数了

    这时只要考虑原来最后一位是不是1即可

    通过FFT,我们求出了A(x)的点值表达式

    B(x)同理,然(A(x)*B(x)=C(x))(O(n))

    现在的问题是怎么转回系数表达式

    定理

    (left{egin{matrix} y_0\y_1\y_2\.\.\.\y_{n-1}end{matrix} ight}=left{egin{matrix} 1 & 1 & 1 & ...& 1 \1 & omega_n & omega_n^2 & ... & omega_n^{n-1}\1 & omega_n^2 & omega_n^4 & ... & omega_n^{2(n-1)}\. & . & . & . & .\. & . & . & . & .\. & . & . & . & .\1 & omega_n^{n-1} & omega_n^{2(n-1)} & ... & omega_n^{(n-1)(n-1)}end{matrix} ight}*left{egin{matrix} a_0\a_1\a_2\.\.\.\a_{n-1}end{matrix} ight})

    左面是点值表达式,右面是系数表达式

    现在已知(Y=W*A)

    (A=Y*W^{-1})

    于是要矩阵求逆

    不过这个矩阵又有个定理

    它的逆矩阵为:指数全取负,再所有数/n之后的矩阵

    那么就简单了

    #include<bits/stdc++.h>
    #define LL long long
    LL in() {
    	char ch; LL x = 0, f = 1;
    	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
    	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
    	return x * f;
    }
    const double pi = acos(-1);
    const int maxn = 3e6 + 41;
    struct node {
    	double x, y;
    	node(double x = 0, double y = 0): x(x), y(y) {}
    	friend node operator + (const node &a, const node &b) { return node(a.x + b.x, a.y + b.y); }
    	friend node operator - (const node &a, const node &b) { return node(a.x - b.x, a.y - b.y); }
    	friend node operator * (const node &a, const node &b) { return node(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
    	friend node operator / (const node &a, const double &b) { return node(a.x / b, a.y / b); }
    }A[maxn], B[maxn], C[maxn];
    int len, n, m, r[maxn];
    void FFT(node *D, int flag) {
    	for(int i = 0; i < len; i++) if(i < r[i]) std::swap(D[i], D[r[i]]);
    	for(int l = 1; l < len; l <<= 1) {
    		node w0(cos(pi / l), flag * sin(pi / l));
    		for(int i = 0; i < len; i += (l << 1)) {
    			node w(1, 0), *a0 = D + i, *a1 = D + i + l;
    			for(int k = 0; k < l; k++, a0++, a1++, w = w * w0) {
    				node tmp = *a1 * w;
    				*a1 = *a0 - tmp;
    				*a0 = *a0 + tmp;
    			}
    		}
    	}
    	if(!(~flag)) for(int i = 0; i < len; i++) D[i] = D[i] / len;
    }
    int main() {
    	n = in(), m = in();
    	for(len = 1; len <= n + m; len <<= 1);
    	for(int i = 0; i <= n; i++) A[i] = in();
    	for(int i = 0; i <= m; i++) B[i] = in();
    	for(int i = 1; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    	FFT(A, 1), FFT(B, 1);
    	for(int i = 0; i < len; i++) C[i] = A[i] * B[i];
    	FFT(C, -1);
    	for(int i = 0; i <= n + m; i++) printf("%d%c", (int)round(C[i].x), i == n + m? '
    ' : ' ');
    	return 0;
    }
    

    NTT模数写法

    #include<bits/stdc++.h>
    #define LL long long
    LL in() {
        char ch; LL x = 0, f = 1;
        while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
        for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
        return x * f;
    }
    using std::vector;
    const int mod = 998244353;
    const int maxn = 3e6 + 10;
    LL ksm(LL x, LL y) {
        LL re = 1LL;
        while(y) {
            if(y & 1) re = re * x % mod;
            x = x * x % mod;
            y >>= 1;
        }	
        return re;
    }
            
    void FNTT(vector<int> &A, int len, int flag) { 
    	A.resize(len);
    	int *r = new int[maxn];
    	r[0] = 0;
    	for(int i = 0; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    	for(int i = 0; i < len; i++) if(i < r[i]) std::swap(A[i], A[r[i]]);
        for(int l = 1; l < len; l <<= 1) {
    		int w0 = ksm(3, (mod - 1) / (l << 1));
    		for(int i = 0; i < len; i += (l << 1)) {
    			int w = 1, a0 = i, a1 = l + i;
    			for(int k = 0; k < l; k++, a0++, a1++, w = 1LL * w * w0 % mod) {
                    int tmp = 1LL * A[a1] * w % mod;
                    A[a1] = ((A[a0] - tmp) % mod + mod) % mod;
                    A[a0] = (A[a0] + tmp) % mod;
    			}
    		}
    	}
    	if(!(~flag)) {
    		std::reverse(A.begin() + 1, A.end());
    		int inv = ksm(len, mod - 2);
    		for(int i = 0; i < len; i++) A[i] = 1LL * inv * A[i] % mod;
    	}
    	delete []r;
    }
    
    vector<int> operator * (vector<int> A, vector<int> B) {
        int tot = A.size() + B.size() - 1;
        int len = 1;
        while(len <= tot) len <<= 1;
        FNTT(A, len, 1);
        FNTT(B, len, 1);
        vector<int> ans;
        ans.resize(len);
        for(int i = 0; i < len; i++) ans[i] = 1LL * A[i] * B[i] % mod;
        FNTT(ans, len, -1);
        ans.resize(tot);
        return ans;
    }
    signed main() {
        int n = in(), m = in();
        vector<int> A, B, C;
        for(int i = 0; i <= n; i++) A.push_back(in());
        for(int i = 0; i <= m; i++) B.push_back(in());
        C = A * B;
        for(int i = 0; i <= n + m; i++) printf("%d%c", C[i], i == n + m? '
    ' : ' ');
        return 0;
    }
    
  • 相关阅读:
    git命令的使用
    动态生成表格的每一行的操作按钮如何获取当前行的index
    js判断一些时间范围是否有重复时间段
    infiniband install driver
    python之pip install
    KVM :vnc 远程控制kvm创建虚拟机
    如何设置UNIX/Linux中新创建目录或文件的默认权限
    python获取报文参考代码
    JAVA命名规范
    oracle常用知识随笔
  • 原文地址:https://www.cnblogs.com/olinr/p/10089226.html
Copyright © 2011-2022 走看看