zoukankan      html  css  js  c++  java
  • 【知识总结】多项式全家桶(三点五)(拆系数解决任意模数多项式卷积)

    上一篇:【知识总结】多项式全家桶(三)(任意模数 NTT)

    (请无视此图)

    我最近学了一个常数小还不用背三个模数的做法:拆系数法

    (以下默认多项式项数 (N=10^5) ,系数不超过 (M=10^9) 且为非负整数)

    我们放弃「数论变换」「利用原根性质」之类的想法,来点简单粗暴的:用实数 FFT 把原始结果算出来,然后直接取模。

    为什么过去我们没有这样做呢?因为卷积结果的系数最大可能达到 (NM^2=10^{24}) ,long double 的精度也不够(通常情况下 double 的精度约为 15 位十进制,long double 的精度约为 19 位十进制)。考虑「拆系数」来牺牲时间保证精度。

    设相乘的两个多项式为 (A(x)=sum a_ix^i)(B(x)=sum b_ix^i) ,结果为 (C(x)=sum c_ix^i) 。把(A(x)) 拆成两个多项式 (A_0(x)=sum {a_0}_ix^i)(A_1(x)=sum {a_1}_ix^i) ,其中 ({a_1}_i=lfloorfrac{a_i}{S} floor)({a_0}_i=a_i-Scdot {a_1}_i) ,即 (a_i=Scdot {a_1}_i+{a_0}_i) 。对 (B(x)) 也作同样的操作。这样,就有

    [egin{aligned} a_ib_j&=(Scdot {a_1}_i+{a_0}_i)(Scdot {b_1}_j+{b_0}_j)\ &={a_1}_i{b_1}_jS^2+({a_1}_i{b_0}_j+{a_0}_i{b_1}_j)S+{a_0}_i{b_0}_j end{aligned}]

    于是直接计算 (C_1=A_1*B_1)(C_2=A_1*B_0+A_0*B_1)(C_3=A_0*B_0) ,然后 (c_i={c_1}_iS^2+{c_2}_iS+{c_3}_i) ,算最后一步的时候对「任意模数」取模即可。

    这样,大致估计一下 (C_1) 的最大系数是 (Ncdot(frac{M}{S})^2=frac{NM^2}{S^2})(C_2) 最大系数是 (2Ncdotfrac{M}{S}cdot S=2NM)(C_3) 最大系数是 (NS^2) 。当 (S=sqrt{M}) 时,以上三项均为 (NM) 级别,即 (10^{14}) 左右,足以保证精度(跟瓜学的一般偷懒直接 (S=32768) )。

    代码:

    事实上由于只跟 7 个多项式有关,所以只需要进行 7 次 FFT 。我写的常数特别大了不要跟我学 ……

    题目:洛谷 4239 多项式求逆(加强版)

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <cctype>
    #include <cmath>
    using namespace std;
    
    namespace zyt
    {
    	template<typename T>
    	inline bool read(T &x)
    	{
    		char c;
    		bool f = false;
    		x = 0;
    		do
    			c = getchar();
    		while (c != EOF && c != '-' && !isdigit(c));
    		if (c == EOF)
    			return false;
    		if (c == '-')
    			f = true, c = getchar();
    		do
    			x = x * 10 + c - '0', c = getchar();
    		while (isdigit(c));
    		if (f)
    			x = -x;
    		return true;
    	}
    	template<typename T>
    	inline void write(T x)
    	{
    		static char buf[20];
    		char *pos = buf;
    		if (x < 0)
    			putchar('-'), x = -x;
    		do
    			*pos++ = x % 10 + '0';
    		while (x /= 10);
    		while (pos > buf)
    			putchar(*--pos);
    	}
    	const int N = 1e5 + 10, S = 1 << 15, p = 1e9 + 7, B = 18;
    	typedef long double ld;
    	typedef long long ll;
    	int power(int a, int b)
    	{
    		int ans = 1;
    		while (b)
    		{
    			if (b & 1)
    				ans = (ll)ans * a % p;
    			a = (ll)a * a % p;
    			b >>= 1;
    		}
    		return ans;
    	}
    	int get_inv(const int a)
    	{
    		return power(a, p - 2);
    	}
    	ll dtol(const ld x)
    	{
    		return ll(fabs(x) + 0.5) * (x < 0 ? -1 : 1);
    	}
    	namespace Polynomial
    	{
    		const int LEN = 1 << B;
    		const ld PI = acos(-1.0L);
    		struct cpx
    		{
    			ld x, y;
    			cpx(const ld _x = 0, const ld _y = 0)
    				: x(_x), y(_y) {}
    			cpx conj()
    			{
    				return cpx(x, -y);
    			}
    		}omega[LEN], winv[LEN];
    		ll ctol(const cpx &a)
    		{
    			return dtol(a.x);
    		}
    		int rev[LEN];
    		cpx operator + (const cpx &a, const cpx &b)
    		{
    			return cpx(a.x + b.x, a.y + b.y);
    		}
    		cpx operator - (const cpx &a, const cpx &b)
    		{
    			return cpx(a.x - b.x, a.y - b.y);
    		}
    		cpx operator * (const cpx &a, const cpx &b)
    		{
    			return cpx(a.x * b.x - a.y * b.y, a.y * b.x + a.x * b.y);
    		}
    		void init(const int n, const int lg2)
    		{
    			cpx w = cpx(cos(2.0L * PI / n), sin(2.0L * PI / n)), wi = w.conj();
    			omega[0] = winv[0] = 1;
    			for (int i = 1; i < n; i++)
    				omega[i] = omega[i - 1] * w, winv[i] = winv[i - 1] * wi;
    			for (int i = 0; i < n; i++)
    				rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
    		}
    		void FFT(cpx *const a, const cpx * const w, const int n)
    		{
    			for (int i = 0; i < n; i++)
    				if (i < rev[i])
    					swap(a[i], a[rev[i]]);
    			for (int l = 1; l < n; l <<= 1)
    				for (int i = 0; i < n; i += (l << 1))
    					for (int k = 0; k < l; k++)
    					{
    						cpx x = a[i + k], y = a[i + l + k] * w[n / (l << 1) * k];
    						a[i + k] = x + y, a[i + l + k] = x - y;
    					}
    		}
    		void mul(const cpx *const a, const cpx *const b, cpx *const c, const int n)
    		{
    			static cpx x[LEN], y[LEN];
    			int m = 1, lg2 = 0;
    			while (m < (n + n - 1))
    				m <<= 1, ++lg2;
    			init(m, lg2);
    			memcpy(x, a, sizeof(cpx[n]));
    			memcpy(y, b, sizeof(cpx[n]));
    			for (int i = n; i < m; i++)
    				x[i] = y[i] = 0;
    			FFT(x, omega, m), FFT(y, omega, m);
    			for (int i = 0; i < m; i++)
    				x[i] = x[i] * y[i];
    			FFT(x, winv, m);
    			for (int i = 0; i < n; i++)
    				c[i] = cpx(x[i].x / m, 0.0);
    		}
    		void MTT(const int *const a, const int *const b, int *const ans, const int n)
    		{
    			const int S = 1 << 15;
    			static cpx a0[LEN], a1[LEN], b0[LEN], b1[LEN], c1[LEN], c2[LEN], c3[LEN], c4[LEN];
    			for (int i = 0; i < n; i++)
    			{
    				a0[i] = cpx(a[i] % S, 0), a1[i] = cpx(a[i] / S, 0);
    				b0[i] = cpx(b[i] % S, 0), b1[i] = cpx(b[i] / S, 0);
    			}
    			mul(a0, b0, c1, n), mul(a0, b1, c2, n), mul(a1, b0, c3, n), mul(a1, b1, c4, n);
    			for (int i = 0; i < n; i++)
    			{
    				int x1 = ctol(c1[i]) % p, x2 = ctol(c2[i]) % p, x3 = ctol(c3[i]) % p, x4 = ctol(c4[i]) % p;
    				ans[i] = (x1 + ll(x2 + x3) * S % p + ll(x4) * S % p * S % p) % p;
    			}
    		}
    		void _inv(const int *const a, int *b, const int n)
    		{
    			if (n == 1)
    				return void(b[0] = get_inv(a[0]));
    			static int tmp[LEN];
    			_inv(a, b, (n + 1) >> 1);
    			memset(b + ((n + 1) >> 1), 0, sizeof(int[n - ((n + 1) >> 1)]));
    			MTT(a, b, tmp, n), MTT(tmp, b, tmp, n);
    			for (int i = 0; i < n; i++)
    				b[i] = (2LL * b[i] % p - tmp[i] + p) % p;
    		}
    		void inv(const int *const a, int *b, const int n)
    		{
    			static int tmp[LEN];
    			memcpy(tmp, a, sizeof(int[n]));
    			_inv(tmp, b, n);
    		}
    	}
    	int A[N << 1];
    	int work()
    	{
    		using namespace Polynomial;
    		int n;
    		read(n);
    		for (int i = 0; i < n; i++)
    			read(A[i]);
    		inv(A, A, n);
    		for (int i = 0; i < n; i++)
    			write(A[i]), putchar(' ');
    		return 0;
    	}
    }
    int main()
    {
    #ifdef BlueSpirit
    	freopen("4239.in", "r", stdin);
    #endif
    	return zyt::work();
    }
    
  • 相关阅读:
    将SqlServer的数据导出到Excel/csv中的各种方法 .
    SqlServer: 单用户模式下查杀相关进程实现单/多用户转换 .
    SQL Server游标的使用【转】
    由几行代码浅析C#的方法参数传递
    脑力风暴之小毛驴历险记(1)好多胡萝卜(下)
    关于sql_mode=only_full_group_by问题的解决方法
    如何同一个controller下添加新页面
    UNIAPP全局变量的实现方法
    Ztree点击节点选中复选框的相关操作
    一沙框架(yishaadmin) 前端引入VUE的实现方法
  • 原文地址:https://www.cnblogs.com/zyt1253679098/p/11160807.html
Copyright © 2011-2022 走看看