zoukankan      html  css  js  c++  java
  • 【LOJ #2320】「清华集训 2017」生成树计数

    Description

    题目链接:

    在一个 (s) 个点的图中,存在 (s-n) 条边,使图中形成了 (n) 个连通块,第 (i) 个连通块中有 (a_i) 个点。

    现在我们需要再连接 (n-1) 条边,使该图变成一棵树。对一种连边方案,设原图中第 (i) 个连通块连出了 (d_i) 条边,那么这棵树 (T) 的价值为:

    [mathrm{val}(T) = left(prod_{i=1}^{n} {d_i}^m ight)left(sum_{i=1}^{n} {d_i}^m ight) ]

    你的任务是求出所有可能的生成树的价值之和,对 (998244353) 取模。

    (n leq 3 imes 10^4,m leq 30)

    时空限制:( exttt{5s/1GB})

    Solution

    算法一

    由于我比较菜,所以想了半天才会这个暴力。

    将每个连通块看成一个点,首先我们知道 Prufer 序列中每个点的出现次数就是度数减一,因此我们不妨考虑枚举度数序列计算。

    考虑在两个大小分别为 (a)(b) 的连通块之间连边有 (acdot b) 种选择,因此我们把所有边的贡献相乘,所以每种连通块的生成树对应的原树的方案数为 (prod_{i=1}^na_i^{d_i})

    (q_i) 表示 Prufer 序列中 (i) 的出现次数,即 (q_i=d_i-1)。如果确定了一个 (sum q_i=n-2),那么我们有

    [ ext{ans}=sum_{sum q_i=n-2}frac{(n-2)!}{prod_{i=1}^nq_i!}prod_{i=1}^na_i^{q_i+1}left(prod_{i=1}^{n} {(q_i+1)}^m ight)left(sum_{i=1}^{n} {(q_i+1)}^m ight) ]

    这个式子只需要 (q_i) 的信息即可计算,我们仔细观察可以发现这个式子是可以 DP 的。

    首先我们将奇怪的项先提出来,得到

    [ ext{ans}=(n-2)!sum_{sum q_i=n-2}prod_{i=1}^nfrac{a_i^{q_i+1}}{q_i!}left(prod_{i=1}^{n} {(q_i+1)}^m ight)left(sum_{i=1}^{n} {(q_i+1)}^m ight) ]

    考虑当前考虑到前 (n) 个点有 (sum_{i=1}^nq_i=s),需要考虑的式子是下面这样的,不妨设它为 (g(n,s))

    [g(n,s)=sum_{sum_{i=1}^nq_i=s}prod_{i=1}^nfrac{a_i^{q_i+1}}{q_i!}left(prod_{i=1}^{n} {(q_i+1)}^m ight)left(sum_{i=1}^{n} {(q_i+1)}^m ight) ]

    那么考虑新加入一个 (q_{n+1}=k),这个式子就变为

    [frac{(k+1)^mcdot a_{n+1}^{k+1}}{k!}sum_{sum_{i=1}^nq_i=s}prod_{i=1}^nfrac{a_i^{q_i+1}}{q_i!}left(prod_{i=1}^{n} {(q_i+1)}^m ight)left(sum_{i=1}^{n} {(q_i+1)}^m+(k+1)^m ight) ]

    再设

    [f(n,s)=sum_{sum_{i=1}^n q_i=s}prod_{i=1}^nfrac{a_i^{q_i+1}}{q_i!}left(prod_{i=1}^{n} {(q_i+1)}^m ight) ]

    容易发现

    [egin{aligned} f(n+1,s+k)&leftarrow f(n,s)cdot frac{(k+1)^mcdot a_{n+1}^{k+1}}{k!}\ g(n+1,s+k)&leftarrow g(n,s)cdot frac{(k+1)^mcdot a_{n+1}^{k+1}}{k!} +f(n,s)cdotfrac{(k+1)^{2m}cdot a_{n+1}^{k+1}}{k!} end{aligned} ]

    边界是 (f(0,0)=1,g(0,0)=0),这样我们就可以 (mathcal O(n^3)) DP 了。

    期望得分 (20) 分。

    算法二

    我们仔细观察,设 (f(i,*),g(i,*)) 的生成函数分别为 (F_i(x),G_i(x)),那么我们有

    [egin{aligned} F_i(x)&=F_{i-1}(x)cdotleft(sum_{j=0}^{n-1}frac{(j+1)^ma_i^{j+1}}{j!}x^j ight)\ G_i(x)&=G_{i-1}(x)cdotleft(sum_{j=0}^{n-1}frac{(j+1)^ma_i^{j+1}}{j!}x^j ight)+F_{i-1}(x)cdotleft(sum_{j=0}^{n-1}frac{(j+1)^{2m}{j+1}}{j!}x^j ight) end{aligned} ]

    那么就可以 (mathcal O(n^2log n)) FFT 了,常数有点大不太能过得去,可能要优化一下常数或者用些啥技巧。

    (或者可能这档分压根就不是这么做的 qwq)

    期望得分 (35sim 40) 分。假装它就是 (40) 吧。

    算法三

    所有 (a_i) 都一样的话,我们发现转移用到的生成函数也是一样的,因此不妨设

    [T_1=sum_{j=0}^{n-1}frac{(j+1)^ma_i^{j+1}}{j!}x^j\ T_2=sum_{j=0}^{n-1}frac{(j+1)^{2m}{j+1}}{j!}x^j ]

    多项式乘法是有交换律和结合律的,简单推导可以得到

    [F_i(x)=T_1^i\ G_i(x)=icdot T_1^{i-1}cdot T_2 ]

    因为我们只需要 ([x^{n-2}]G_n(x)),我们可以多项式快速幂一下。

    时间复杂度就是 (mathcal O(nlog n)) 或者 (mathcal O(nlog^2n))

    结合算法二可以获得 (60) 分。

    算法四

    剩下的部分就是一些牛逼(套路)操作了。

    仔细观察,转移用到的生成函数除了 (a_i),其它部分都很相似,我们不妨设

    [A(x)=sum_{i=0}^{n-1}frac{(i+1)^m}{i!}\ B(x)=sum_{i=0}^{n-1}frac{(i+1)^{2m}}{i!} ]

    那么有

    [egin{aligned} F_i(x)&=F_{i-1}(x)cdot a_iA(a_ix)\ G_i(x)&=G_{i-1}(x)cdot a_iA(a_ix)+F_{i-1}(x)cdot a_iB(a_ix) end{aligned} ]

    简单推导可以得到

    [egin{aligned} F_n(x)&=prod_{i=1}^na_iprod_{i=1}^nA(a_ix)\ G_n(x)&=prod_{i=1}^na_isum_{i=1}^nprod_{j=1}^negin{cases}A(a_jx) & i eq j\B(a_jx) & i=jend{cases}\ end{aligned} ]

    (G_n(x)) 的表达式写得好一点是

    [G_n(x)=prod_{i=1}^na_iprod_{i=1}^nA(a_ix)sum_{i=1}^nleft(frac{B}{A} ight)(a_ix)\ ]

    显然对于某个多项式 (F(x)),求 (sum_{i=1}^nF(a_ix)) 比求 (prod_{i=1}^nF(a_ix)) 容易得多,我们考虑先求 ln 再求 exp

    [G_n(x)=prod_{i=1}^na_ileft(e^{sum_{i=1}^n(ln A)(a_ix)}sum_{i=1}^nleft(frac{B}{A} ight)(a_ix) ight)\ ]

    整理一下,答案就是

    [egin{aligned} ext{ans}&=(n-2)![x^{n-2}]G_n(x)\&=(n-2)!prod_{i=1}^na_i[x^{n-2}]left(e^{sum_{i=1}^n(ln A)(a_ix)}sum_{i=1}^nleft(frac{B}{A} ight)(a_ix) ight) end{aligned} ]

    现在的问题转化为,对于一个多项式 (F(x)),求 (sum_{i=1}^n F(a_ix))

    因为是求和,我们可以写成

    [sum_{i=1}^n F(a_ix)=sum_{i=0}^{n-1}x^i[x^i]F(x)sum_{j=1}^na_j^i ]

    那么现在的问题就是,对于每个 (i),求出 (sum_{j=1}^na_j^i)

    众所周知,(frac{1}{1-ax}=sum_{igeq0}a^ix^i),因此上面的问题可以有如下转化

    [sum_{j=1}^na_j^i=[x^i]sum_{j=1}^nfrac{1}{1-a_jx} ]

    这是个经典问题。因为问题规模不允许我们对于每个 (1-a_jx) 求逆后相加,所以我们考虑直接从分式入手。我们尝试分治这个和式,然后合并两边的分式的时候,就模拟分式通分后相加的过程

    这样能保证分治的时候,该区间的多项式次数为该区间长度,从而保证复杂度。

    至此我们就解决了这个问题,时间复杂度 (mathcal O(nlog^2n+nlog m))。所以 (m) 其实可以出到 (10^{18})

    注意特判 (n=1),否则你会在 UOJ 上获得 97 分的好分数,别问我是怎么知道的

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &x)
    {
    	static char ch; 
    	while (!isdigit(ch = getchar()));
    	x = ch - '0'; 
    	while (isdigit(ch = getchar()))
    		x = x * 10 + ch - '0'; 
    }
    
    const int mod = 998244353; 
    
    inline int qpow(int x, int y)
    {
    	int res = 1; 
    	for (; y; y >>= 1, x = 1LL * x * x % mod)
    		if (y & 1)
    			res = 1LL * res * x % mod; 
    	return res; 
    }
    
    inline void add(int &x, const int &y)
    {
    	x += y; 
    	if (x >= mod)
    		x -= mod; 
    }
    
    inline void dec(int &x, const int &y)
    {
    	x -= y;
    	if (x < 0)
    		x += mod; 
    }
    
    typedef std::vector<int> vi; 
    typedef std::pair<vi, vi> pvi; 
    #define mp(x, y) std::make_pair(x, y)
    
    const int MaxN = 2e5 + 5; 
    const int INF = 0x3f3f3f3f; 
    
    int fac[MaxN], fac_inv[MaxN], pwm[MaxN], ind[MaxN]; 
    
    inline void fac_init(int n)
    {
    	ind[1] = 1; 
    	for (int i = 2; i <= n; ++i)
    		ind[i] = 1LL * ind[mod % i] * (mod - mod / i) % mod; 
    
    	fac[0] = 1; 
    	for (int i = 1; i <= n; ++i)
    		fac[i] = 1LL * fac[i - 1] * i % mod; 
    
    	fac_inv[n] = qpow(fac[n], mod - 2); 
    	for (int i = n - 1; i >= 0; --i)
    		fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod; 
    }
    
    namespace polynomial
    {
    	int P, L; 
    	int rev[MaxN]; 
    
    	inline void DFT_init(int n)
    	{
    		P = 0, L = 1; 
    		while (L < n)
    			L <<= 1, ++P; 
    		for (int i = 1; i < L; ++i)
    			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (P - 1)); 
    	}
    	
    	inline void DFT(vi &a, int n, int opt)
    	{
    		for (int i = 0; i < n; ++i)
    			if (i < rev[i])
    				std::swap(a[i], a[rev[i]]);
    
    		int g = opt == 1 ? 3 : (mod + 1) / 3; 
    		for (int k = 1; k < n; k <<= 1)
    		{
    			int omega = qpow(g, (mod - 1) / (k << 1)); 
    			for (int i = 0; i < n; i += k << 1)
    			{
    				int x = 1; 
    				for (int j = 0; j < k; ++j)
    				{
    					int u = a[i + j]; 
    					int v = 1LL * a[i + j + k] * x % mod; 
    					add(a[i + j] = u, v); 
    					dec(a[i + j + k] = u, v); 
    					x = 1LL * x * omega % mod; 
    				}
    			}
    		}
    		if (opt == -1)
    		{
    			int inv = ind[n]; 
    			for (int i = 0; i < n; ++i)
    				a[i] = 1LL * a[i] * inv % mod; 
    		}
    	}
    
    	inline vi plus(vi a, vi b)
    	{
    		int sze = std::max(a.size(), b.size()); 
    		a.resize(sze), b.resize(sze); 
    
    		for (int i = 0; i < sze; ++i)
    			add(a[i], b[i]); 
    		return a; 
    	}
    	inline vi mul(vi a, vi b, int lim = INF)
    	{
    		int sze = a.size() + b.size() - 1; 
    		DFT_init(sze), a.resize(L, 0), b.resize(L, 0); 
    
    		vi c(L); 
    		DFT(a, L, 1), DFT(b, L, 1); 
    		for (int i = 0; i < L; ++i)
    			c[i] = 1LL * a[i] * b[i] % mod; 
    		DFT(c, L, -1);
    
    		return c.resize(std::min(sze, lim)), c; 
    	}
    	inline vi inverse(vi a)
    	{
    		int n = a.size(), m = 1; 
    		vi b(1, qpow(a[0], mod - 2)), ta; 
    		while (m < n)
    		{
    			m <<= 1; 
    			DFT_init(m << 1); 
    
    			b.resize(L, 0); 
    			(ta = a).resize(m); 
    			ta.resize(L, 0); 
    
    			DFT(b, L, 1), DFT(ta, L, 1); 
    			for (int i = 0; i < L; ++i)
    				b[i] = 1LL * b[i] * (mod + 2 - 1LL * ta[i] * b[i] % mod) % mod; 
    			DFT(b, L, -1); 
    
    			b.resize(m, 0); 
    		}
    		return b.resize(n), b; 
    	}
    	inline vi derivative(vi a)
    	{
    		vi res(0); 
    		for (int i = 1, lim = a.size(); i < lim; ++i)
    			res.push_back(1LL * i * a[i] % mod); 
    		return res; 
    	}
    	inline vi anti_derivative(vi a)
    	{
    		vi res(1, 0); 
    		for (int i = 0, lim = a.size(); i < lim; ++i)
    			res.push_back(1LL * a[i] * ind[i + 1] % mod); 
    		return res; 
    	}
    	inline vi ln(vi a)
    	{
    		return anti_derivative(mul(derivative(a), inverse(a), a.size() - 1)); 
    	}
    	inline vi exp(vi a)
    	{
    		int n = a.size(), m = 1; 
    		vi b(1, 1), ta; 
    		while (m < n)
    		{
    			m <<= 1; 
    
    			b.resize(m, 0); 
    			vi ln_b = ln(b); 
    
    			(ta = a).resize(m); 
    			add(ta[0], 1); 
    			for (int i = 0; i < m; ++i)
    				dec(ta[i], ln_b[i]); 
    			b = mul(b, ta, m); 
    		}
    		return b.resize(n), b; 
    	}
    }
    
    vi sum; 
    int n, m; 
    int a[MaxN]; 
    
    inline pvi solve(int l, int r)
    {
    	using namespace polynomial; 
    	if (l == r)
    	{
    		vi t(1, 1); t.push_back(mod - a[l]); 
    		return mp(vi(1, 1), t); 
    	}
    	int mid = (l + r) >> 1; 
    	pvi lef = solve(l, mid), rit = solve(mid + 1, r); 
    	return mp(plus(mul(lef.first, rit.second), mul(rit.first, lef.second)), mul(lef.second, rit.second)); 
    }
    
    inline vi get_sum(vi a)
    {
    	vi res(0); int n = a.size(); 
    	for (int i = 0; i < n; ++i)
    		res.push_back(1LL * a[i] * sum[i] % mod); 
    	return res; 
    }
    
    int main()
    {
    	read(n), read(m), fac_init(MaxN - 1); 
    	for (int i = 0; i <= (n << 1); ++i)
    		pwm[i] = qpow(i, m); 
    
    	int prod = 1; 
    	for (int i = 1; i <= n; ++i)
    	{
    		read(a[i]);
    		prod = 1LL * prod * a[i] % mod; 
    	}
    
    	if (n == 1)
    		return puts(m ? "0" : "1"), 0; 
    
    	using namespace polynomial; 
    
    	pvi t = solve(1, n); 
    	sum = mul(t.first, inverse(t.second), n - 1); 
    
    	vi A(0), B(0); 
    	for (int i = 0; i < n - 1; ++i)
    	{
    		A.push_back(1LL * pwm[i + 1] * fac_inv[i] % mod); 
    		B.push_back(1LL * pwm[i + 1] * pwm[i + 1] % mod * fac_inv[i] % mod); 
    	}
    	B = get_sum(mul(B, inverse(A), n - 1)); 
    	A = exp(get_sum(ln(A))); 
    
    	int res = mul(A, B)[n - 2]; 
    	std::cout << 1LL * fac[n - 2] * prod % mod * res % mod << '
    '; 
    
    	return 0; 
    }
    
  • 相关阅读:
    如何说明白代码评审
    面试感悟----一名3年工作经验的程序员应该具备的技能(转载自@五月的仓颉)
    根据ip地址从第三方接口获取详细的地理位置
    linux安装telnet遇到的问题
    redis脑图
    数据库相关面试题
    logback系列一:名词解释
    java并发编程系列一、多线程
    logback系列二:logback在项目中的应用
    rocketmq特性(features)
  • 原文地址:https://www.cnblogs.com/cyx0406/p/LOJ2320.html
Copyright © 2011-2022 走看看