zoukankan      html  css  js  c++  java
  • CF1444B Divide and Sum

    题目来源:Codeforces Round #680 (Div. 1, based on Moscow Team Olympiad)/Codeforces Round #680 (Div. 2, based on Moscow Team Olympiad),CF1444B/CF1445D,Divide and Sum

    说明:这个做法用到了 NTT,从时间、空间常数,以及实现的简单性来看,都不是最优解法。但是我个人挺喜欢这个做法的,感觉一步一步推下来,很有逻辑。不依赖灵感,不需要恍然大悟地想到什么巧妙的结论。自有一种朴实之美。

    题目大意

    题目链接

    给定一个长度为 (2n) 的序列 (a)。现在要将 (a) 划分为两个子序列 (p,q),两个子序列长度都恰好为 (n),且无公共元素。

    划分完成后,我们将 (p,q) 分别按从小到大从大到小排序,记得到的两个序列分别为 (x,y)

    我们规定一种划分方式的权值为:(f(p,q) = sum_{i = 1}^{n}|x_i - y_i|)

    请求出所有划分方式的权值之和。答案对 (998244353) 取模。

    数据范围:(1leq nleq 150000)(1leq a_ileq 10^9)

    本题题解

    首先,因为得到的两个子序列是要排序的,所以元素的初始顺序其实不重要。那么可以先将 (a) 序列排序。以下讨论 (a) 序列都是指排好序后的序列。

    考虑 (p,q) 的第 (i) 位对答案的贡献 ((1leq ileq n))。枚举 (p)(i) 位上的元素,记为 (a_j);枚举 (q)(i) 位上的元素,记为 (a_k) ((k eq j))。那么 (a_j) 必须恰好是 (p) 序列里第 (i) 小的,(a_k) 必须恰好是 (q) 序列里第 (n-i+1) 小的,也就是说,它们前面分别恰有 (i-1) 个 / (n-i) 个自己序列的元素。那么不难用组合数求出,(|a_j-a_k|) 在第 (i) 位上对答案贡献的方案数。具体来说:

    • (k<j) 时,对答案的贡献是:({k-1choose n-i}{j-k-1choose(i - 1) - (k - 1 - (n - i))}{2n-jchoose n-i}(a_j-a_k)),化简一下,等于:({k-1choose n-i}{j-k-1choose n - k}{2n-jchoose n-i}(a_j-a_k))。这三个组合数,含义分别是:在 (k) 前面选出 (q) 序列里的元素(剩下的都在 (p) 序列里);在 (k,j) 之间选出 (p) 序列里的元素;在 (j) 后面选出 (p) 序列里的元素(剩下的都在 (q) 序列里)。
    • (j<k) 时,对答案的贡献是:({j - 1choose i - 1}{k - j - 1choose (n - i) - (j - 1 - (i - 1))}{2n-kchoose i-1}(a_k-a_j)),化简一下,等于:({j-1choose i-1}{k-j-1choose n - j}{2n-kchoose i-1}(a_k-a_j))。这三个组合数,和前面类似,含义分别是:在 (j) 前面选出 (p) 序列里的元素;在 (j,k) 之间选出 (q) 序列里的元素;在 (k) 后面选出 (q) 序列里的元素。

    暴力枚举 (i,j,k),按此式子计算答案,时间复杂度 (O(n^3))。这个暴力做法的代码片段附在了参考代码部分。


    继续优化。发现枚举 (j) 再枚举 (k) 这件事比较愚蠢。考虑将它们拆开来,也就是分别枚举 (j,k)。以枚举 (j) 为例。考虑一个 (j) 的贡献,有两种情况:

    1. (k<j),此时这个 (j) 对答案的贡献是 (a_j) 乘以一个系数。
    2. (j<k),此时这个 (j) 对答案的贡献是 (-a_j) 乘以一个系数。

    我们要求出这个系数。对于情况 1,相当于要求 (j) 前面存在一个合法的 (k)。考虑 (j) 前面要有哪些东西:

    • (p) 序列前 (i-1) 小的元素(恰好这么多,否则 (j) 就不是第 (i) 了)。
    • (q) 序列前 (n-i+1) 小的元素(或者更多元素)。

    因此,发现 (j) 一定要大于等于 ((i-1)+(n-i+1)+1=n+1)。依次枚举 (jin[n+1,2n]),每个 (j) 对答案的贡献就是:({j - 1choose i - 1}cdot {2n-jchoose n-i}cdot a_j)

    同理,情况 2 中,(jin[1,n]),对答案的贡献是:({j - 1choose i - 1}cdot {2n-jchoose n-i}cdot (-a_j))

    (k) 的贡献也是类似的。分别是:(kin[n+1,2n])({k-1choose n-i}cdot{2n-kchoose i-1}cdot a_k)(kin[1,n])({k-1choose n-i}cdot {2n-kchoose i-1}cdot (-a_k))

    这样,我们只需要先枚举 (i),再分别枚举 (j,k)(而不是套起来)。时间复杂度 (O(n^2))。这个做法的代码片段附在了参考代码部分。


    最后一步,我们把上述 (n^2) 的式子拆开,写成卷积的形式,就可以了。

    (jin[n+1,2n])({j - 1choose i - 1}cdot {2n-jchoose n-i}cdot a_j) 为例,可以写成:

    [sum_{j=n+1}^{2n}a_jcdot (j-1)!cdot (2n-j)!sum_{i=1}^{n}frac{1}{(i-1)!(n-i)!}cdot frac{1}{(j-i)!(n-(j-i))!} ]

    (f_i=frac{1}{(i-1)!(n-i)!})(g_i=frac{1}{i!(n-i)!}),则后半部分就是 (fcdot g) (多项式乘法)的第 (j) 项。我们对 (f,g) 做 NTT 即可。

    我们求 (j,k) 的贡献时,是两个不同的式子,所以各要做一次 NTT。总时间复杂度 (O(nlog n))

    参考代码

    内含一个精细优化后的 NTT 模板(namespace SuperNTT),因为太长了,我将其单独取出并附在后面

    实际提交时,建议使用快速输入、输出,详见本博客公告。

    // problem: CF1444B
    #include <bits/stdc++.h>
    using namespace std;
    
    #define pb push_back
    #define mk make_pair
    #define lob lower_bound
    #define upb upper_bound
    #define fi first
    #define se second
    #define SZ(x) ((int)(x).size())
    
    typedef unsigned int uint;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int, int> pii;
    
    template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
    template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
    
    namespace SuperNTT {
    // ...
    } // namespace SuperNTT
    
    const int MAXN = 3e5;
    const int MOD = 998244353;
    inline int mod1(int x) { return x < MOD ? x : x - MOD; }
    inline int mod2(int x) { return x < 0 ? x + MOD : x; }
    inline void add(int &x, int y) { x = mod1(x + y); }
    inline void sub(int &x, int y) { x = mod2(x - y); }
    inline int pow_mod(int x, int i) {
    	int y = 1;
    	while (i) {
    		if (i & 1) y = (ll)y * x % MOD;
    		x = (ll)x * x % MOD;
    		i >>= 1;
    	}
    	return y;
    }
    
    int fac[MAXN + 5], ifac[MAXN + 5];
    inline int comb(int n, int k) {
    	if (n < k) return 0;
    	return (ll)fac[n] * ifac[k] % MOD * ifac[n - k] % MOD;
    }
    void facinit(int lim = MAXN) {
    	fac[0] = 1;
    	for (int i = 1; i <= lim; ++i) fac[i] = (ll)fac[i - 1] * i % MOD;
    	ifac[lim] = pow_mod(fac[lim], MOD - 2);
    	for (int i = lim - 1; i >= 0; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % MOD;
    }
    
    int n, a[MAXN + 5];
    
    int f[MAXN * 4 + 5], g[MAXN * 4 + 5], res[MAXN * 4 + 5];
    int main() {
    	facinit();
    	cin >> n;
    	for (int i = 1; i <= n * 2; ++i) {
    		cin >> a[i];
    	}
    	sort(a + 1, a + n * 2 + 1);
    	int ans = 0;
    	
    	for (int i = 0; i <= n; ++i) {
    		if (i != 0) f[i] = (ll)ifac[i - 1] * ifac[n - i] % MOD;
    		g[i] = (ll)ifac[i] * ifac[n - i] % MOD;
    	}
    	
    	SuperNTT :: work(f, g, n + 1, n + 1, res);
    	for (int j = n + 1; j <= n * 2; ++j) {
    		add(ans, (ll)a[j] * fac[j - 1] % MOD * fac[2 * n - j] % MOD * res[j] % MOD);
    	}
    	for (int j = 1; j <= n; ++j) {
    		sub(ans, (ll)a[j] * fac[j - 1] % MOD * fac[2 * n - j] % MOD * res[j] % MOD);
    	}
    	
    	memset(f, 0, sizeof(f));
    	memset(g, 0, sizeof(g));
    	for (int i = 1; i <= n; ++i) {
    		f[i] = (ll)ifac[i - 1] * ifac[n - i] % MOD;
    	}
    	for (int i = n + 1; i <= n * 2 + 1; ++i) {
    		g[i] = (ll)ifac[i - n - 1] * ifac[2 * n + 1 - i] % MOD;
    	}
    	reverse(f, f + n + 1);
    	SuperNTT :: work(f, g, n + 1, n * 2 + 2, res);
    	for (int j = n + 1; j <= n * 2; ++j) {
    		add(ans, (ll)a[j] * fac[j - 1] % MOD * fac[n * 2 - j] % MOD * res[n + j] % MOD);
    	}
    	for (int j = 1; j <= n; ++j) {
    		sub(ans, (ll)a[j] * fac[j - 1] % MOD * fac[n * 2 - j] % MOD * res[n + j] % MOD);
    	}
    	cout << ans << endl;
    	return 0;
    }
    

    NTT 模板(namespace SuperNTT):

    typedef unsigned int uint;
    typedef long long unsigned int uint64;
    
    constexpr uint Max_size = 1 << 21 | 5;
    constexpr uint g = 3, Mod = 998244353;
    
    inline uint norm_2(const uint x)
    {
    	return x < Mod * 2 ? x : x - Mod * 2;
    }
    
    inline uint norm(const uint x)
    {
    	return x < Mod ? x : x - Mod;
    }
    
    struct Z
    {
    	uint v;
    	Z() { }
    	Z(const uint _v) : v(_v) { }
    };
    
    inline Z operator+(const Z &x1, const Z &x2) { return x1.v + x2.v < Mod ? x1.v + x2.v : x1.v + x2.v - Mod; }
    inline Z operator-(const Z &x1, const Z &x2) { return x1.v >= x2.v ? x1.v - x2.v : x1.v + Mod - x2.v; }
    inline Z operator*(const Z &x1, const Z &x2) { return static_cast<uint64>(x1.v) * x2.v % Mod; }
    inline Z &operator*=(Z &x1, const Z &x2) { x1.v = static_cast<uint64>(x1.v) * x2.v % Mod; return x1; }
    
    Z Power(Z Base, int Exp)
    {
    	Z res = 1;
    	for (; Exp; Base *= Base, Exp >>= 1)
    		if (Exp & 1)
    			res *= Base;
    	return res;
    }
    
    inline uint mf(uint x)
    {
    	return (static_cast<uint64>(x) << 32) / Mod;
    }
    
    int size;
    uint w[Max_size], w_[Max_size];
    
    inline uint mult_Shoup_2(const uint x, const uint y, const uint y_)
    {
    	uint q = static_cast<uint64>(x) * y_ >> 32;
    	return x * y - q * Mod;
    }
    
    inline uint mult_Shoup(const uint x, const uint y, const uint y_)
    {
    	return norm(mult_Shoup_2(x, y, y_));
    }
    
    inline void init(int n)
    {
    	for (size = 2; size < n; size <<= 1)
    		;
    	Z pr = Power(g, (Mod - 1) / size);
    	size >>= 1;
    	w[size] = 1, w_[size] = (static_cast<uint64>(w[size]) << 32) / Mod;
    	if (size <= 8)
    	{
    		for (int i = 1; i < size; ++i)
    			w[size + i] = (w[size + i - 1] * pr).v, w_[size + i] = (static_cast<uint64>(w[size + i]) << 32) / Mod;
    	}
    	else
    	{
    		for (int i = 1; i < 8; ++i)
    			w[size + i] = (w[size + i - 1] * pr).v, w_[size + i] = (static_cast<uint64>(w[size + i]) << 32) / Mod;
    		pr *= pr, pr *= pr, pr *= pr;
    		for (int i = 8; i < size; i += 8)
    		{ 
    			w[size + i + 0] = (w[size + i - 8] * pr).v, w_[size + i + 0] = (static_cast<uint64>(w[size + i + 0]) << 32) / Mod;
    			w[size + i + 1] = (w[size + i - 7] * pr).v, w_[size + i + 1] = (static_cast<uint64>(w[size + i + 1]) << 32) / Mod;
    			w[size + i + 2] = (w[size + i - 6] * pr).v, w_[size + i + 2] = (static_cast<uint64>(w[size + i + 2]) << 32) / Mod;
    			w[size + i + 3] = (w[size + i - 5] * pr).v, w_[size + i + 3] = (static_cast<uint64>(w[size + i + 3]) << 32) / Mod;
    			w[size + i + 4] = (w[size + i - 4] * pr).v, w_[size + i + 4] = (static_cast<uint64>(w[size + i + 4]) << 32) / Mod;
    			w[size + i + 5] = (w[size + i - 3] * pr).v, w_[size + i + 5] = (static_cast<uint64>(w[size + i + 5]) << 32) / Mod;
    			w[size + i + 6] = (w[size + i - 2] * pr).v, w_[size + i + 6] = (static_cast<uint64>(w[size + i + 6]) << 32) / Mod;
    			w[size + i + 7] = (w[size + i - 1] * pr).v, w_[size + i + 7] = (static_cast<uint64>(w[size + i + 7]) << 32) / Mod;
    		} 
    	}
    	for (int i = size - 1; i; --i)
    		w[i] = w[i * 2], w_[i] = w_[i * 2];
    	size <<= 1;
    }
    
    inline void DFT_fr_2(Z _A[], const int L)
    {
    	if (L == 1)
    		return;
    	uint *A = reinterpret_cast<uint *>(_A);
    #define butterfly1(a, b)
    	do
    	{
    		uint _a = a, _b = b;
    		uint x = norm_2(_a + _b), y = norm_2(_a + Mod * 2 - _b);
    		a = x, b = y;
    	} while (0)
    	if (L == 2)
    	{
    		butterfly1(A[0], A[1]);
    		return;
    	}
    #define butterfly(a, b, _w, _w_)
    	do
    	{
    		uint _a = a, _b = b;
    		uint x = norm_2(_a + _b), y = mult_Shoup_2(_a + Mod * 2 - _b, _w, _w_);
    		a = x, b = y;
    	} while (0)
    	if (L == 4)
    	{
    		butterfly1(A[0], A[2]);
    		butterfly(A[1], A[3], w[3], w_[3]);
    		butterfly1(A[0], A[1]);
    		butterfly1(A[2], A[3]);
    		return; 
    	}
    	for (int d = L >> 1; d != 4; d >>= 1)
    		for (int i = 0; i != L; i += d << 1)
    			for (int j = 0; j != d; j += 4)
    			{
    				butterfly(A[i + j], A[i + d + j], w[d + j], w_[d + j]);
    				butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1], w_[d + j + 1]);
    				butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2], w_[d + j + 2]);
    				butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3], w_[d + j + 3]);
    			}
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 4]);
    		butterfly(A[i + 1], A[i + 5], w[5], w_[5]);
    		butterfly(A[i + 2], A[i + 6], w[6], w_[6]);
    		butterfly(A[i + 3], A[i + 7], w[7], w_[7]);
    	}
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 2]);
    		butterfly(A[i + 1], A[i + 3], w[3], w_[3]);
    		butterfly1(A[i + 4], A[i + 6]);
    		butterfly(A[i + 5], A[i + 7], w[3], w_[3]);
    	}
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 1]);
    		butterfly1(A[i + 2], A[i + 3]);
    		butterfly1(A[i + 4], A[i + 5]);
    		butterfly1(A[i + 6], A[i + 7]);
    	}
    #undef butterfly1
    #undef butterfly
    }
    
    inline void IDFT_fr(Z _A[], const int L)
    {
    	if (L == 1)
    		return;
    	uint *A = reinterpret_cast<uint *>(_A);
    #define butterfly1(a, b)
    	do
    	{
    		uint _a = a, _b = b;
    		uint x = norm_2(_a), t = norm_2(_b);
    		a = x + t, b = x + Mod * 2 - t;
    	} while (0)
    	if (L == 2)
    	{
    		butterfly1(A[0], A[1]);
    		A[0] = norm(norm_2(A[0])), A[0] = A[0] & 1 ? A[0] + Mod : A[0], A[0] /= 2;
    		A[1] = norm(norm_2(A[1])), A[1] = A[1] & 1 ? A[1] + Mod : A[1], A[1] /= 2;
    		return;
    	}
    #define butterfly(a, b, _w, _w_)
    	do
    	{
    		uint _a = a, _b = b;
    		uint x = norm_2(_a), t = mult_Shoup_2(_b, _w, _w_);
    		a = x + t, b = x + Mod * 2 - t;
    	} while (0)
    	if (L == 4)
    	{
    		butterfly1(A[0], A[1]);
    		butterfly1(A[2], A[3]);
    		butterfly1(A[0], A[2]);
    		butterfly(A[1], A[3], w[3], w_[3]);
    		std::swap(A[1], A[3]);
    		for (int i = 0; i != L; ++i)
    		{
    			uint64 m = -A[i] & 3;
    			A[i] = norm((A[i] + m * Mod) >> 2);
    		}
    		return; 
    	}
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 1]);
    		butterfly1(A[i + 2], A[i + 3]);
    		butterfly1(A[i + 4], A[i + 5]);
    		butterfly1(A[i + 6], A[i + 7]);
    	}
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 2]);
    		butterfly(A[i + 1], A[i + 3], w[3], w_[3]);
    		butterfly1(A[i + 4], A[i + 6]);
    		butterfly(A[i + 5], A[i + 7], w[3], w_[3]);
    	}
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 4]);
    		butterfly(A[i + 1], A[i + 5], w[5], w_[5]);
    		butterfly(A[i + 2], A[i + 6], w[6], w_[6]);
    		butterfly(A[i + 3], A[i + 7], w[7], w_[7]);
    	}
    	for (int d = 8; d != L; d <<= 1)
    		for (int i = 0; i != L; i += d << 1)
    			for (int j = 0; j != d; j += 4)
    			{
    				butterfly(A[i + j], A[i + d + j], w[d + j], w_[d + j]);
    				butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1], w_[d + j + 1]);
    				butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2], w_[d + j + 2]);
    				butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3], w_[d + j + 3]);
    			}
    #undef butterfly1
    #undef butterfly
    	std::reverse(A + 1, A + L);
    	int k = __builtin_ctz(L);
    	for (int i = 0; i != L; ++i)
    	{
    		uint64 m = -A[i] & (L - 1);
    		A[i] = norm((A[i] + m * Mod) >> k);
    	}
    }
    
    int N, M, L;
    Z A[Max_size], B[Max_size];
    
    void work(int f[], int g[], int n, int m, int res[]) {
    	N = n; M = m;
    	memset(A, 0, sizeof(A));
    	memset(B, 0, sizeof(B));
    	for(int i = 0; i < n; ++i) A[i].v = f[i];
    	for(int i = 0; i < m; ++i) B[i].v = g[i];
    	for (L = 2; L <= N + M - 2; L <<= 1)
    		;
    	init(L);
    	
    	DFT_fr_2(A, L), DFT_fr_2(B, L);
    	for (int i = 0; i != L; ++i)
    		A[i] *= B[i];
    	IDFT_fr(A, L);
    	
    	for(int i = 0; i < n + m - 1; ++i) res[i] = A[i].v;
    }
    

    (O(n^3)) 做法片段:

    sort(a + 1, a + n * 2 + 1);
    int ans = 0;
    for (int i = 1; i <= n; ++i) {
    	for (int j = i; j <= 2 * n; ++j) {
    		// a[j] -> p[i]
    		for (int k = n - i + 1; k <= n * 2; ++k) {
    			// a[k] -> q[i]
    			if (a[j] == a[k]) continue;
    			if (k < j) {
    				if (k - 1 <= (i - 1) + n - i)
    					add(ans, (ll)comb(k - 1, n - i) * comb(j - k - 1, (i - 1) - (k - 1 - (n - i))) % MOD * comb(n * 2 - j, n - i) % MOD * (a[j] - a[k]) % MOD);
    			} else {
    				if (j - 1 <= (i - 1) + n - i)
    					add(ans, (ll)comb(j - 1, i - 1) * comb(k - j - 1, (n - i) - (j - 1 - (i - 1))) % MOD * comb(n * 2 - k, i - 1) % MOD * (a[k] - a[j]) % MOD);
    			}
    		}
    	}
    }
    cout << ans << endl;
    

    (O(n^2)) 做法片段:

    sort(a + 1, a + n * 2 + 1);
    int ans = 0;
    for (int i = 1; i <= n; ++i) {
    	for (int j = n + 1; j <= n * 2; ++j) {
    		// k < j
    		add(ans, (ll)comb(j - 1, i - 1) * comb(n * 2 - j, n - i) % MOD * a[j] % MOD);
    	}
    	for (int j = n; j >= 1; --j) {
    		// j < k
    		sub(ans, (ll)comb(n * 2 - j, n - i) * comb(j - 1, i - 1) % MOD * a[j] % MOD);
    	}
    	for (int k = n + 1; k <= n * 2; ++k) {
    		// j < k
    		add(ans, (ll)comb(k - 1, n - i) * comb(n * 2 - k, i - 1) % MOD * a[k] % MOD);
    	}
    	for (int k = n; k >= 1; --k) {
    		// k < j
    		sub(ans, (ll)comb(n * 2 - k, i - 1) * comb(k - 1, n - i) % MOD * a[k] % MOD);
    	}
    }
    
  • 相关阅读:
    [go]go addressable 详解
    [go]灵活的处理json与go结构体
    [django]django内置的用户模型
    [go]文件读写&io操作
    *2.3.2_加入env
    UVM_INFO
    uvm_config_db在UVM验证环境中的应用
    *2.2.4 加入virtual interface
    *2.2.3 加入objection机制
    2.2.2 加入factory机制
  • 原文地址:https://www.cnblogs.com/dysyn1314/p/13912158.html
Copyright © 2011-2022 走看看