zoukankan      html  css  js  c++  java
  • 牛客7501H「2020ICPC-小米 Round1」Grouping

    题目来源:牛客 2020ICPC-小米 Round1。H. Grouping。ACM赛,小米邀请赛,牛客网,杜瑜皓,吉如一。

    题目大意

    题目链接

    给出 $2n$ 个整数 $a_1,a_2,\dots ,a_{2n}$。现在要将它们分成 $n$ 组,每组恰好两个数。每组的权值是两数之差(大减小)。

    一个分组方案的权值是【所有组的权值的方差】。可重集 $x_1,x_2,\dots ,x_m$ 的方差定义为:$\frac{1}{m}\sum_{i=1}^{m}(x_i-\overline{x})^2$,其中 $\overline{x}$ 是这 $m$ 个数的平均值。

    求分组方案的权值的期望。

    数据范围:$1\leq n\leq 10^5$,$1\leq a_i\leq 10^6$。

    本题题解

    首先,$2n$ 个人的分组方案数,可以 DP 求出。设 $dp[i]$ 表示 $2i$ 个人的分组方案数。转移时要先拿出来一组,剩下 $2(i-1)$ 个人的方案数为 $dp[i-1]$。为了避免算重,我们强制拿出来的这组包含编号为 $1$ 的人。那么这组的另一个人有 $2i-1$ 种选择。于是得到转移式:$dp[i] = dp[i-1] \times (2i-1)$。DP 的时间复杂度为 $O(n)$。

    答案要求期望,期望等于总和除以方案数,方案数就是 $dp[n]$。也就是说,设每组差分别为 $d_1,d_2,\dots,d_n$,则:

    $$\text{ans} = \frac{\sum_{\text{方案}}\frac{1}{n}\sum_{i = 1}^{n}(d_i-\overline{d})^2}{dp[n]}$$

    分母已经求出。我们只需要考虑分子:

    $$\begin{align}
    &\sum_{\text{方案}}\frac{1}{n}\sum_{i = 1}^{n}(d_i-\overline{d})^2\\
    =&\sum_{\text{方案}}\left(\frac{\sum_{i=1}^{n}d_i^2}{n}-\frac{2\sum_{i=1}^{n}d_i\cdot \overline{d}}{n} + \frac{n\cdot \overline{d}^2}{n}\right)\\
    =&\sum_{\text{方案}}\left(\frac{\sum_{i=1}^{n}d_i^2}{n} - \frac{2\sum_{i=1}^{n}d_i\cdot(\frac{1}{n}\sum_{j=1}^{n}d_j)}{n} + \frac{n\cdot (\frac{1}{n}\sum_{j=1}^{n}d_j)^2}{n}\right)\\
    =&\sum_{\text{方案}}\left(\frac{\sum_{i=1}^{n}d_i^2}{n} - \frac{2\cdot (\sum_{i=1}^{n}d_i)\cdot(\sum_{i=1}^{n}d_i)}{n^2} + \frac{(\sum_{i=1}^{n}d_i)^2}{n^2}\right)\\
    =&\sum_{\text{方案}}\left(\frac{\sum_{i = 1}^{n}d_i^2}{n} - \frac{(\sum_{i = 1}^{n}d_i)^2}{n^2}\right)\\
    =&\sum_{\text{方案}}\left(\frac{\sum_{i = 1}^{n}d_i^2}{n} - \frac{\sum_{i = 1}^{n}d_i^2}{n^2} - \frac{\sum_{i \neq j}d_id_j}{n^2}\right)\\
    =&\frac{\sum_{\text{方案}}\sum_{i = 1}^{n}d_i^2}{n} - \frac{\sum_{\text{方案}}\sum_{i = 1}^{n}d_i^2}{n^2} - \frac{\sum_{\text{方案}}\sum_{i \neq j}d_id_j}{n^2}
    \end{align}$$

    考虑对这三项分别求。


    前两项是类似的,关键在于求出 $\sum_{\text{方案}}\sum_{i = 1}^{n}d_i^2$。可以分别计算每一对数对答案的贡献:固定一对 $(a_i,a_j)$ 作为一组后,其余 $2(n-1)$ 个数还有 $dp[n-1]$ 种分组方案。也就是:$\sum d^2 \times dp[n-1]$,它等于:$dp[n-1] \times \sum_{i=1}^{2n}\sum_{j = i + 1}^{2n}(a_i - a_j)^2$。暴力计算,时间复杂度是 $O(n^2)$ 的。

    设 $A = \max_{i = 1}^{2n}{a_i}$,则 $A\leq 10^6$。

    考虑枚举 $a_i,a_j$ 的差。求出一个数组 $h_k = \sum_{1\leq i<j\leq 2n}[|a_i-a_j| = k]$($k\in[0,A]$)。则 $\sum_{i=1}^{2n}\sum_{j = i + 1}^{2n}(a_i - a_j)^2$ 就等于 $\sum_{k = 0}^{A} h_k \times k^2$。

    问题转化为如何求 $h$ 数组。先用一个桶记录每个值的出现次数,即 $f_j = \sum_{i=1}^{2n}[a_i=j]$($j\in[0,A]$),则对 $k\geq 1$ 有 $h_k = \sum_{i = k}^{A} f_{i} \times f_{i-k}$($k=0$ 时该式得到的是 $\sum_i f_i^2$,与定义不同,但 $h_0$ 在后面总是乘以 $0$,不影响答案)。为了把它搞成卷积的形式,我们把 $f$ 反过来,设 $g_i = f_{A-i}$($i\in[0,A]$),则 $h_k = \sum_{i = k}^{A}g_{A-i} \times f_{i-k} = \sum_{i = 0}^{A-k}g_{A-k-i} \times f_{i}$。

    对 $f$, $g$ 做一次卷积,把卷积结果的前 $A+1$ 项反转,就是 $h$ 数组了。进而能够求出答案式的前两项:$\frac{\sum_{\text{方案}}\sum_{i = 1}^{n}d_i^2}{n} - \frac{\sum_{\text{方案}}\sum_{i = 1}^{n}d_i^2}{n^2}$。

    时间复杂度 $O(A\log A + n)$,因为 $A$ 高达 $10^6$ 而 NTT 常数较大,所以要注意卡常。


    答案式的第三项:$\frac{\sum_{\text{方案}}\sum_{i \neq j}d_id_j}{n^2}$。求分子。固定两组(共 $4$ 个互不相同的点)之后,其余 $2(n-2)$ 个数还有 $dp[n-2]$ 种分组方案。把它用 $a$ 表示出来后,暴力求是 $O(n^4)$ 的(也就是枚举 $4$ 个互不相同的点)。

    考虑利用我们预处理好的 $h$ 数组来求。先忽略 $4$ 个点互不相同这一要求,那么算出是:$\sum_{i=0}^{A}\sum_{j = 0}^{A}h_i\cdot h_j\cdot i\cdot j = \left(\sum_{i=0}^{A}h_i \times i\right)^2$,其含义是枚举两对点,第一对点差为 $i$,第二对点差为 $j$,把这两对点的差分别作为两组的权值。

    对于互不相同的要求,我们考虑把有重复点的两对去掉。

    先把 $a$ 数组排序。称每对点里下标较小的为“前面点”,下标较大的为“后面点”。

    枚举这个重复点 $i$($1\leq i\leq 2n$),此时分三种情况讨论:

    1. 两对点都是以 $i$ 为“前面点”,则还需要在 $i$ 后面再选两个点 $p,q$,求出 $\sum_{p,q>i}(a_p-a_i)(a_q-a_i)$
    2. 两对点都是以 $i$ 为“后面点”,则还需要在 $i$ 前面再选两个点 $p,q$,求出 $\sum_{p,q<i}(a_i-a_p)(a_i-a_q)$
    3. 一对点以 $i$ 为“前面点”,另一对点以 $i$ 为“后面点”,则贡献是 $2 \times \sum_{p<i,q>i}(a_i-a_p)(a_q-a_i)$

    对 $a$ 序列做前、后缀和,则这三种情况的贡献都能轻松算出来,将它们减掉即可。

    但是,发现两对点完全相同(前面点、后面点都相同)的情况,在分类 1 和分类 2 里都被计算到了(枚举到它的前面点时算一次,枚举到它的后面点时又算一次),所以要再加回来一次。这种情况的贡献是 $\sum_{i=0}^{A}h_i\cdot i\cdot i$。

    这样,经过简单的容斥(总方案 - 有公共点的 + 两个都是公共点的),我们就求出了答案的第三部分!

    时间复杂度 $O(n+A)$。


    总时间复杂度 $O(A\log A + n)$。

    参考代码

    内含一个精细优化后的 NTT 模板(namespace SuperNTT),因为太长了,我将其单独取出并附在后面

    实际提交时,建议使用快速输入、输出,详见本博客公告。

    // problem: H
    #include <bits/stdc++.h>
    using namespace std;
    
    #define pb push_back
    #define mk make_pair
    #define lob lower_bound
    #define upb upper_bound
    #define fi first
    #define se second
    #define SZ(x) ((int)(x).size())
    
    typedef unsigned int uint;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int, int> pii;
    
    // Raise x to y when y is strictly greater than x.
    template<typename T>
    inline void ckmax(T& x, T y) {
    	if (y > x) x = y;
    }
    // Lower x to y when y is strictly smaller than x.
    template<typename T>
    inline void ckmin(T& x, T y) {
    	if (y < x) x = y;
    }
    
    namespace SuperNTT {
    // ...
    } // namespace SuperNTT
    
    // Size limits: n <= MAXN groups, 2n <= MAXM values, every value <= MAXA.
    const int MAXN = 100000;
    const int MAXM = 2 * MAXN;
    const int MAXA = 1000000;
    // Prime modulus under which the answer is computed.
    const int MOD = 998244353;
    
    // Reduce a value from [0, 2*MOD) into [0, MOD) with one conditional subtraction.
    inline int mod1(int x) {
    	if (x >= MOD) x -= MOD;
    	return x;
    }
    // Lift a value from (-MOD, MOD) into [0, MOD) with one conditional addition.
    inline int mod2(int x) {
    	if (x < 0) x += MOD;
    	return x;
    }
    // x <- (x + y) mod MOD; both operands are expected to be in [0, MOD).
    inline void add(int &x, int y) {
    	x += y;
    	x = mod1(x);
    }
    // x <- (x - y) mod MOD; both operands are expected to be in [0, MOD).
    inline void sub(int &x, int y) {
    	x -= y;
    	x = mod2(x);
    }
    // Binary exponentiation: returns x^i mod MOD for a non-negative exponent i.
    inline int pow_mod(int x, int i) {
    	int result = 1;
    	for (; i; i >>= 1) {
    		// Fold the current power of x into the result when this exponent bit is set.
    		if (i & 1) result = (ll)result * x % MOD;
    		x = (ll)x * x % MOD;
    	}
    	return result;
    }
    
    // n = number of groups, m = 2n = number of values. a[1..m] holds the input values
    // (sorted in main before the prefix sums are built); pre[i] / suf[i] are the sums of
    // a[1..i] / a[i..m], reduced mod MOD.
    int n, m, a[MAXM + 5], pre[MAXM + 5], suf[MAXM + 5];
    // f[v] counts how often value v occurs, g is f mirrored (g[MAXA - v] = f[v]), and res
    // receives the first MAXA + 1 coefficients of the product f * g, reversed afterwards in main.
    int f[MAXA + 5], g[MAXA + 5], res[MAXA + 5];
    // Modular inverse of n.
    int inv_n;
    // dp[i] = dp[i - 1] * (2i - 1) mod MOD, with dp[0] = 1.
    int dp[MAXN + 5];
    // Running value of the answer (mod MOD).
    int ans;
    
    int main() {
    	cin >> n;
    	
    	dp[0] = 1;
    	for (int i = 1; i <= n; ++i) {
    		dp[i] = (ll)(i * 2 - 1) * dp[i - 1] % MOD;
    	}
    	// cerr << dp[n] << endl;
    	
    	m = n * 2;
    	for (int i = 1; i <= m; ++i) {
    		cin >> a[i];
    		f[a[i]]++;
    	}
    	for (int i = 0; i <= MAXA; ++i) g[MAXA - i] = f[i];
    	SuperNTT :: work(f, g, MAXA + 1, MAXA + 1, res);
    	reverse(res, res + MAXA + 1);
    	int sum = 0;
    	for (int i = 1; i <= MAXA; ++i) {
    		add(sum, (ll)i * i % MOD * res[i] % MOD);
    	}
    	sum = (ll)sum * dp[n - 1] % MOD;
    	inv_n = pow_mod(n, MOD - 2);
    	ans = mod2((ll)sum * inv_n % MOD - (ll)sum * inv_n % MOD * inv_n % MOD);
    	
    	sort(a + 1, a + m + 1);
    	for (int i = 1; i <= m; ++i) pre[i] = mod1(pre[i - 1] + a[i]);
    	for (int i = m; i >= 1; --i) suf[i] = mod1(suf[i + 1] + a[i]);
    	if (n >= 2) {
    		int sum = 0;
    		int tmp = 0;
    //		for (int i = 0; i <= MAXA; ++i) {
    //			for(int j = 0; j <= MAXA; ++j) {
    //				add(sum, (ll)res[i] * res[j] % MOD * i % MOD * j % MOD);
    //			}
    //		}
    		for (int i = 0; i <= MAXA; ++i) add(tmp, (ll)res[i] * i % MOD);
    		for (int i = 0; i <= MAXA; ++i) add(sum, (ll)tmp * res[i] % MOD * i % MOD);
    		// cerr << sum << endl;
    		for (int i = 1; i <= m; ++i) {
    			int x = mod2(suf[i + 1] - (ll)a[i] * (m - i) % MOD);
    			int y = mod2((ll)a[i] * (i - 1) % MOD - pre[i - 1]);
    			sub(sum, (ll)x * x % MOD);
    			sub(sum, (ll)y * y % MOD);
    			sub(sum, (ll)x * y * 2 % MOD);
    		}
    		// cerr << sum << endl;
    		for(int i = 0; i <= MAXA; ++i) {
    			add(sum, (ll)i * i % MOD * res[i] % MOD);
    		}
    		// cerr << sum << endl;
    		sum = (ll)sum * dp[n - 2] % MOD;
    		sub(ans, (ll)sum * inv_n % MOD * inv_n % MOD);
    	}
    	ans = (ll)ans * pow_mod(dp[n], MOD - 2) % MOD;
    	cout << ans << endl;
    	return 0;
    }
    

    NTT 模板(namespace SuperNTT):

    using uint = unsigned int;
    using uint64 = unsigned long long;
    
    // Capacity of the transform buffers and twiddle tables: 2^21 plus a little slack.
    constexpr uint Max_size = (1u << 21) | 5u;
    // NTT modulus and the root used by init() to derive the twiddle factors.
    constexpr uint g = 3;
    constexpr uint Mod = 998244353;
    
    // Subtract 2*Mod once when x has reached it; leaves smaller values untouched.
    inline uint norm_2(const uint x)
    {
    	constexpr uint two_mod = Mod * 2;
    	if (x >= two_mod)
    		return x - two_mod;
    	return x;
    }
    
    // Subtract Mod once when x has reached it; leaves smaller values untouched.
    inline uint norm(const uint x)
    {
    	if (x >= Mod)
    		return x - Mod;
    	return x;
    }
    
    struct Z
    {
    	uint v;
    	Z() { }
    	Z(const uint _v) : v(_v) { }
    };
    
    inline Z operator+(const Z &x1, const Z &x2) { return x1.v + x2.v < Mod ? x1.v + x2.v : x1.v + x2.v - Mod; }
    inline Z operator-(const Z &x1, const Z &x2) { return x1.v >= x2.v ? x1.v - x2.v : x1.v + Mod - x2.v; }
    inline Z operator*(const Z &x1, const Z &x2) { return static_cast<uint64>(x1.v) * x2.v % Mod; }
    inline Z &operator*=(Z &x1, const Z &x2) { x1.v = static_cast<uint64>(x1.v) * x2.v % Mod; return x1; }
    
    // Square-and-multiply: returns Base^Exp for a non-negative exponent.
    Z Power(Z Base, int Exp)
    {
    	Z acc = 1;
    	while (Exp)
    	{
    		if (Exp & 1)
    			acc *= Base;
    		Base *= Base;
    		Exp >>= 1;
    	}
    	return acc;
    }
    
    // Returns floor(x * 2^32 / Mod), truncated to 32 bits: the precomputed quotient that
    // mult_Shoup_2 expects as its third argument.
    inline uint mf(uint x)
    {
    	const uint64 shifted = static_cast<uint64>(x) << 32;
    	return shifted / Mod;
    }
    
    // Transform length chosen by the most recent init() call (a power of two, at least 2).
    int size;
    // Twiddle tables filled by init(): w[] holds powers of the root of unity and w_[] the
    // matching floor(w * 2^32 / Mod) quotients consumed by mult_Shoup_2.
    uint w[Max_size], w_[Max_size];
    
    // Shoup-style multiplication by a constant: y_ must be floor(y * 2^32 / Mod).
    // Returns x * y minus an estimated multiple of Mod, computed in wrapping 32-bit
    // arithmetic; no final normalisation is done here (mult_Shoup applies norm afterwards).
    inline uint mult_Shoup_2(const uint x, const uint y, const uint y_)
    {
    	const uint64 wide = static_cast<uint64>(x) * y_;
    	const uint quotient = static_cast<uint>(wide >> 32);
    	return x * y - quotient * Mod;
    }
    
    // Same as mult_Shoup_2, followed by one conditional subtraction of Mod.
    inline uint mult_Shoup(const uint x, const uint y, const uint y_)
    {
    	const uint lazy = mult_Shoup_2(x, y, y_);
    	return norm(lazy);
    }
    
    inline void init(int n)
    {
    	for (size = 2; size < n; size <<= 1)
    		;
    	Z pr = Power(g, (Mod - 1) / size);
    	size >>= 1;
    	w[size] = 1, w_[size] = (static_cast<uint64>(w[size]) << 32) / Mod;
    	if (size <= 8)
    	{
    		for (int i = 1; i < size; ++i)
    			w[size + i] = (w[size + i - 1] * pr).v, w_[size + i] = (static_cast<uint64>(w[size + i]) << 32) / Mod;
    	}
    	else
    	{
    		for (int i = 1; i < 8; ++i)
    			w[size + i] = (w[size + i - 1] * pr).v, w_[size + i] = (static_cast<uint64>(w[size + i]) << 32) / Mod;
    		pr *= pr, pr *= pr, pr *= pr;
    		for (int i = 8; i < size; i += 8)
    		{ 
    			w[size + i + 0] = (w[size + i - 8] * pr).v, w_[size + i + 0] = (static_cast<uint64>(w[size + i + 0]) << 32) / Mod;
    			w[size + i + 1] = (w[size + i - 7] * pr).v, w_[size + i + 1] = (static_cast<uint64>(w[size + i + 1]) << 32) / Mod;
    			w[size + i + 2] = (w[size + i - 6] * pr).v, w_[size + i + 2] = (static_cast<uint64>(w[size + i + 2]) << 32) / Mod;
    			w[size + i + 3] = (w[size + i - 5] * pr).v, w_[size + i + 3] = (static_cast<uint64>(w[size + i + 3]) << 32) / Mod;
    			w[size + i + 4] = (w[size + i - 4] * pr).v, w_[size + i + 4] = (static_cast<uint64>(w[size + i + 4]) << 32) / Mod;
    			w[size + i + 5] = (w[size + i - 3] * pr).v, w_[size + i + 5] = (static_cast<uint64>(w[size + i + 5]) << 32) / Mod;
    			w[size + i + 6] = (w[size + i - 2] * pr).v, w_[size + i + 6] = (static_cast<uint64>(w[size + i + 6]) << 32) / Mod;
    			w[size + i + 7] = (w[size + i - 1] * pr).v, w_[size + i + 7] = (static_cast<uint64>(w[size + i + 7]) << 32) / Mod;
    		} 
    	}
    	for (int i = size - 1; i; --i)
    		w[i] = w[i * 2], w_[i] = w_[i * 2];
    	size <<= 1;
    }
    
    // In-place forward transform of _A[0..L), L a power of two, using the tables built by
    // init(). Stages run from the widest butterfly distance (L/2) down to 1 and no
    // reordering pass is performed; IDFT_fr is written to consume the result as-is.
    // Values are only lazily reduced (norm_2 / mult_Shoup_2 instead of norm), so the
    // output is not guaranteed to be below Mod.
    //
    // The two butterflies are local lambdas rather than multi-line macros: they need no
    // line continuations and are type-checked, while still being trivially inlinable.
    inline void DFT_fr_2(Z _A[], const int L)
    {
    	if (L == 1)
    		return;
    	uint *A = reinterpret_cast<uint *>(_A);
    	// Twiddle-free butterfly: (a, b) <- (a + b, a - b), each reduced by norm_2.
    	const auto butterfly1 = [](uint &a, uint &b) {
    		const uint _a = a, _b = b;
    		const uint x = norm_2(_a + _b), y = norm_2(_a + Mod * 2 - _b);
    		a = x, b = y;
    	};
    	// General butterfly: (a, b) <- (a + b, (a - b) * _w), with _w_ the Shoup quotient of _w.
    	const auto butterfly = [](uint &a, uint &b, const uint _w, const uint _w_) {
    		const uint _a = a, _b = b;
    		const uint x = norm_2(_a + _b), y = mult_Shoup_2(_a + Mod * 2 - _b, _w, _w_);
    		a = x, b = y;
    	};
    	if (L == 2)
    	{
    		butterfly1(A[0], A[1]);
    		return;
    	}
    	if (L == 4)
    	{
    		butterfly1(A[0], A[2]);
    		butterfly(A[1], A[3], w[3], w_[3]);
    		butterfly1(A[0], A[1]);
    		butterfly1(A[2], A[3]);
    		return; 
    	}
    	// Stages with distance d >= 8, processed four butterflies at a time.
    	for (int d = L >> 1; d != 4; d >>= 1)
    		for (int i = 0; i != L; i += d << 1)
    			for (int j = 0; j != d; j += 4)
    			{
    				butterfly(A[i + j], A[i + d + j], w[d + j], w_[d + j]);
    				butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1], w_[d + j + 1]);
    				butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2], w_[d + j + 2]);
    				butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3], w_[d + j + 3]);
    			}
    	// Distance 4: the first twiddle of the level is 1 (init sets w[size/2] = 1 and copies it down).
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 4]);
    		butterfly(A[i + 1], A[i + 5], w[5], w_[5]);
    		butterfly(A[i + 2], A[i + 6], w[6], w_[6]);
    		butterfly(A[i + 3], A[i + 7], w[7], w_[7]);
    	}
    	// Distance 2.
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 2]);
    		butterfly(A[i + 1], A[i + 3], w[3], w_[3]);
    		butterfly1(A[i + 4], A[i + 6]);
    		butterfly(A[i + 5], A[i + 7], w[3], w_[3]);
    	}
    	// Distance 1: no twiddles needed.
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 1]);
    		butterfly1(A[i + 2], A[i + 3]);
    		butterfly1(A[i + 4], A[i + 5]);
    		butterfly1(A[i + 6], A[i + 7]);
    	}
    }
    
    // In-place inverse transform of _A[0..L), L a power of two, using the tables built by
    // init(). Stages run from butterfly distance 1 up to L/2 with the same twiddles as the
    // forward pass; reversing A[1..L) afterwards turns that into the inverse transform.
    // The final loop divides by L exactly: it first adds the multiple of Mod that makes the
    // value divisible by L, then shifts. Every output is passed through norm().
    //
    // The two butterflies are local lambdas rather than multi-line macros: they need no
    // line continuations and are type-checked, while still being trivially inlinable.
    inline void IDFT_fr(Z _A[], const int L)
    {
    	if (L == 1)
    		return;
    	uint *A = reinterpret_cast<uint *>(_A);
    	// Twiddle-free butterfly: (a, b) <- (a + b, a - b) on norm_2-reduced inputs.
    	const auto butterfly1 = [](uint &a, uint &b) {
    		const uint _a = a, _b = b;
    		const uint x = norm_2(_a), t = norm_2(_b);
    		a = x + t, b = x + Mod * 2 - t;
    	};
    	// General butterfly: (a, b) <- (a + b * _w, a - b * _w), with _w_ the Shoup quotient of _w.
    	const auto butterfly = [](uint &a, uint &b, const uint _w, const uint _w_) {
    		const uint _a = a, _b = b;
    		const uint x = norm_2(_a), t = mult_Shoup_2(_b, _w, _w_);
    		a = x + t, b = x + Mod * 2 - t;
    	};
    	if (L == 2)
    	{
    		butterfly1(A[0], A[1]);
    		// Halve modulo Mod: make the value even by adding Mod if necessary, then divide.
    		A[0] = norm(norm_2(A[0])), A[0] = A[0] & 1 ? A[0] + Mod : A[0], A[0] /= 2;
    		A[1] = norm(norm_2(A[1])), A[1] = A[1] & 1 ? A[1] + Mod : A[1], A[1] /= 2;
    		return;
    	}
    	if (L == 4)
    	{
    		butterfly1(A[0], A[1]);
    		butterfly1(A[2], A[3]);
    		butterfly1(A[0], A[2]);
    		butterfly(A[1], A[3], w[3], w_[3]);
    		std::swap(A[1], A[3]);
    		for (int i = 0; i != L; ++i)
    		{
    			// m is the multiple of Mod that makes A[i] + m * Mod divisible by 4.
    			uint64 m = -A[i] & 3;
    			A[i] = norm((A[i] + m * Mod) >> 2);
    		}
    		return; 
    	}
    	// Distance 1: no twiddles needed.
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 1]);
    		butterfly1(A[i + 2], A[i + 3]);
    		butterfly1(A[i + 4], A[i + 5]);
    		butterfly1(A[i + 6], A[i + 7]);
    	}
    	// Distance 2.
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 2]);
    		butterfly(A[i + 1], A[i + 3], w[3], w_[3]);
    		butterfly1(A[i + 4], A[i + 6]);
    		butterfly(A[i + 5], A[i + 7], w[3], w_[3]);
    	}
    	// Distance 4.
    	for (int i = 0; i != L; i += 8)
    	{
    		butterfly1(A[i], A[i + 4]);
    		butterfly(A[i + 1], A[i + 5], w[5], w_[5]);
    		butterfly(A[i + 2], A[i + 6], w[6], w_[6]);
    		butterfly(A[i + 3], A[i + 7], w[7], w_[7]);
    	}
    	// Stages with distance d >= 8, processed four butterflies at a time.
    	for (int d = 8; d != L; d <<= 1)
    		for (int i = 0; i != L; i += d << 1)
    			for (int j = 0; j != d; j += 4)
    			{
    				butterfly(A[i + j], A[i + d + j], w[d + j], w_[d + j]);
    				butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1], w_[d + j + 1]);
    				butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2], w_[d + j + 2]);
    				butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3], w_[d + j + 3]);
    			}
    	std::reverse(A + 1, A + L);
    	int k = __builtin_ctz(L);
    	for (int i = 0; i != L; ++i)
    	{
    		// m is the multiple of Mod that makes A[i] + m * Mod divisible by L = 2^k.
    		uint64 m = -A[i] & (L - 1);
    		A[i] = norm((A[i] + m * Mod) >> k);
    	}
    }
    
    // Sizes of the two operands of the latest work() call and the transform length used for it.
    int N, M, L;
    // Transform buffers for the two operands; A also receives the product.
    Z A[Max_size], B[Max_size];
    
    // Multiplies the polynomials f[0..n) and g[0..m) modulo Mod and stores the first n
    // coefficients of the product in res[0..n).
    //
    // The transform length L is the smallest power of two greater than n + m - 2, so the
    // full product (n + m - 1 coefficients) fits without wrap-around. The buffers hold
    // Max_size entries, so the caller must keep L within that capacity.
    //
    // Both buffers are zero-padded up to L on every call: the transforms overwrite them in
    // place, so without the padding a second call would pick up leftovers of the first.
    void work(int f[], int g[], int n, int m, int res[]) {
    	N = n; M = m;
    	for (L = 2; L <= N + M - 2; L <<= 1)
    		;
    	for(int i = 0; i < n; ++i) A[i].v = f[i];
    	for(int i = n; i < L; ++i) A[i].v = 0;
    	for(int i = 0; i < m; ++i) B[i].v = g[i];
    	for(int i = m; i < L; ++i) B[i].v = 0;
    	init(L);
    	
    	DFT_fr_2(A, L), DFT_fr_2(B, L);
    	for (int i = 0; i != L; ++i)
    		A[i] *= B[i];
    	IDFT_fr(A, L);
    	
    	for(int i = 0; i < n; ++i) res[i] = A[i].v;
    }
    
  • 相关阅读:
    js之iframe子页面与父页面通信
    js的event对象
    整洁代码的4个条件
    PYTHON 自然语言处理
    如何检测浏览器是否支持CSS3
    BootStrap前端框架使用方法详解
    如何使用repr调试python程序
    Python编程快速上手——Excel到CSV的转换程序案例分析
    C++和JAVA传统中积极的一面
    20个LINUX相关的网站
  • 原文地址:https://www.cnblogs.com/dysyn1314/p/13874488.html
Copyright © 2011-2022 走看看