zoukankan      html  css  js  c++  java
  • 2020.3.1考试T1 多项式

    出题人很凉心的把算法写成了题目名

    首先我们可以发现每一维的贡献是独立的,这可以从 (solve1) 里看出来

    然后我们可以考虑转化为 (DP) ,这可以从 (solve2) 里看出来

    我们统计每一维能产生的贡献,就是 (a)(0) 面, (b)(1) 面, (c)(2) 面这种形式,能写成一个多项式 (ax^0+bx^1+cx^2),而我们最终显然就是把所有的多项式都乘起来。

    暴力一个一个乘就很 naive,分治 (NTT) 解决就好啦。

    不透彻的话把每个 (solve) 都看一遍就好啦。

    Warning:请不要学习本代码的分治 (NTT) 写法,考场上现想出来的,实现麻烦了不少,建议学一下其他大佬的写法。

    #include<iostream>
    #include<cstdio>
    #define int long long
    #define LL long long
    using namespace std;
    int n;
    const int N = 400010, mod = 469762049, G = 3, Ginv = (mod + 1) / 3;
    int a[N], b[N], c[N], ans[N];
    inline int read()
    {
    	int res = 0; char ch = getchar(); bool XX = false;
    	for (; !isdigit(ch); ch = getchar())(ch == '-') && (XX = true);
    	for (; isdigit(ch); ch = getchar())res = (res << 3) + (res << 1) + (ch ^ 48);
    	return XX ? -res : res;
    }
    void solve1()
    {
    	int tmp;
    	for (int i = 1; i <= a[1]; ++i)
    	{
    		tmp = 0;
    		if (i == 1)++tmp; if (i == a[1])++tmp;
    		++ans[tmp];
    	}
    	for (int i = 0; i <= 2 * n; ++i)printf("%lld
    ", ans[i]);
    }
    void solve2()
    {
    	int tmp;
    	for (int i = 1; i <= a[1]; ++i)
    		for (int j = 1; j <= a[2]; ++j)
    		{
    			tmp = 0;
    			if (i == 1)++tmp; if (i == a[1])++tmp;
    			if (j == 1)++tmp; if (j == a[2])++tmp;
    			++ans[tmp];
    		}
    	for (int i = 0; i <= 2 * n; ++i)printf("%lld
    ", ans[i]);
    }
    void solve3()
    {
    	int tmp;
    	for (int i = 1; i <= n; ++i)
    	{
    		tmp = a[i]; a[i] = b[i] = c[i] = 0;
    		if (tmp == 1)a[i] = 1, c[i] = 0;
    		else b[i] = 2, c[i] = tmp - b[i];
    	}
    	ans[0] = 1;
    	for (int i = 1; i <= n; ++i)
    	{
    		for (int j = 2 * n; j >= 2; --j)
    			ans[j] = ((LL)ans[j] * c[i] % mod + (LL)ans[j - 1] * b[i] % mod + (LL)ans[j - 2] * a[i] % mod) % mod;
    		ans[1] = ((LL)ans[1] * c[i] % mod + (LL)ans[0] * b[i] % mod) % mod;
    		ans[0] = (LL)ans[0] * c[i] % mod;
    	}
    	for (int i = 0; i <= 2 * n; ++i)printf("%lld
    ", ans[i]);
    }
    
    /*下边 solve4*/
    
    int last, top;
    int r[N], zhan[30], tmp[500];
    LL ksm(LL a, LL b, LL mod)
    {
    	LL res = 1;
    	for (; b; b >>= 1, a = a * a % mod)
    		if (b & 1)res = res * a % mod;
    	return res;
    }
    
    void NTT(LL *A, int lim, int opt)
    {
    	if (last != lim)
    	{
    		last = lim;
    		for (int i = 0; i < lim; ++i)
    			r[i] = (r[i >> 1] >> 1) | (i & 1 ? (lim >> 1) : 0);
    	}
    	for (int i = 0; i < lim; ++i)
    		if (i < r[i])swap(A[i], A[r[i]]);
    	int len;
    	LL wn, w, x, y;
    	for (int mid = 1; mid < lim; mid <<= 1)
    	{
    		len = mid << 1;
    		wn = ksm(opt == 1 ? G : Ginv, (mod - 1) / len, mod);
    		for (int j = 0; j < lim; j += len)
    		{
    			w = 1;
    			for (int k = j; k < j + mid; ++k, w = w * wn % mod)
    			{
    				x = A[k]; y = A[k + mid] * w % mod;
    				A[k] = (x + y) % mod;
    				A[k + mid] = (x - y + mod) % mod;
    			}
    		}
    	}
    	if (opt == 1)return;
    	int ni = ksm(lim, mod - 2, mod);
    	for (int i = 0; i < lim; ++i)A[i] = A[i] * ni % mod;
    }
    void MUL(LL *A, int n, LL *B, int m)
    {
    	if (n + m <= 115)
    	{
    		for (int i = 0, to = n + m; i <= to; ++i)tmp[i] = 0;
    		for (int i = 0; i <= n; ++i)
    			for (int j = 0; j <= m; ++j)
    				(tmp[i + j] += A[i] * B[j] % mod) %= mod;
    		for (int i = 0, to = n + m; i <= to; ++i)A[i] = tmp[i];
    		for (int i = 0; i <= m; ++i)B[i] = 0;
    	}
    	else
    	{
    		int lim = 1;
    		while (lim <= (n + m))lim <<= 1;
    		NTT(A, lim, 1); NTT(B, lim, 1);
    		for (int i = 0; i < lim; ++i)A[i] = A[i] * B[i] % mod, B[i] = 0;
    		NTT(A, lim, -1);
    	}
    }
    struct dxs
    {
    	int siz;
    	LL v[N];
    } A[30];
    int newdxs()
    {
    	return zhan[top--];
    }
    void huidxs(int x)
    {
    	A[x].siz = 0;
    	zhan[++top] = x;
    }
    int solve(int l, int r)
    {
    	if (l == r)
    	{
    		int k = newdxs();
    		A[k].siz = 2;
    		A[k].v[0] = c[l]; A[k].v[1] = b[l]; A[k].v[2] = a[l];
    		return k;
    	}
    	int mid = (l + r) >> 1;
    	int lson = solve(l, mid), rson = solve(mid + 1, r);
    	MUL(A[lson].v, A[lson].siz, A[rson].v, A[rson].siz);
    	A[lson].siz = A[lson].siz + A[rson].siz;
    	huidxs(rson);
    	return lson;
    }
    void solve4()
    {
    	int tmp;
    	for (int i = 1; i <= n; ++i)
    	{
    		tmp = a[i]; a[i] = b[i] = c[i] = 0;
    		if (tmp == 1)a[i] = 1, c[i] = 0;
    		else b[i] = 2, c[i] = tmp - b[i];
    	}
    	for (int i = 1; i <= 25; ++i)zhan[++top] = i;
    	int k = solve(1, n);
    	for (int i = 0; i <= 2 * n; ++i)printf("%lld
    ", A[k].v[i]);
    }
    
    /*上边 solve4*/
    
    signed main()
    {
    	freopen("poly.in", "r", stdin);
    	freopen("poly.out", "w", stdout);
    	cin >> n;
    	for (int i = 1; i <= n; ++i)
    	{
    		a[i] = read();
    	}
    	if (n == 1 && a[1] <= 1000)solve1();
    	else if (n == 2 && a[1] <= 1000 && a[2] <= 1000)solve2();
    	else if (n <= 5000)solve3();
    	else solve4();
    	fclose(stdin); fclose(stdout);
    	return 0;
    }
    
  • 相关阅读:
    文件的权限
    正则表达式
    软硬链接的学习
    linux系统中的文件类型和扩展名
    把数组排成最小的数
    整数中1出现的次数(从1到n整数中1出现的次数)
    最小的K个数
    连续子数组的最大和
    数组中出现次数超过一半的数字
    字符串的排列
  • 原文地址:https://www.cnblogs.com/wljss/p/12628999.html
Copyright © 2011-2022 走看看