Description
题目链接:
在一个 (s) 个点的图中,存在 (s-n) 条边,使图中形成了 (n) 个连通块,第 (i) 个连通块中有 (a_i) 个点。
现在我们需要再连接 (n-1) 条边,使该图变成一棵树。对一种连边方案,设原图中第 (i) 个连通块连出了 (d_i) 条边,那么这棵树 (T) 的价值为:
你的任务是求出所有可能的生成树的价值之和,对 (998244353) 取模。
(n leq 3 imes 10^4,m leq 30)
时空限制:( exttt{5s/1GB})
Solution
算法一
由于我比较菜,所以想了半天才会这个暴力。
将每个连通块看成一个点,首先我们知道 Prufer 序列中每个点的出现次数就是度数减一,因此我们不妨考虑枚举度数序列计算。
考虑在两个大小分别为 (a) 和 (b) 的连通块之间连边有 (acdot b) 种选择,因此我们把所有边的贡献相乘,所以每种连通块的生成树对应的原树的方案数为 (prod_{i=1}^na_i^{d_i})。
设 (q_i) 表示 Prufer 序列中 (i) 的出现次数,即 (q_i=d_i-1)。如果确定了一个 (sum q_i=n-2),那么我们有
这个式子只需要 (q_i) 的信息即可计算,我们仔细观察可以发现这个式子是可以 DP 的。
首先我们将奇怪的项先提出来,得到
考虑当前考虑到前 (n) 个点有 (sum_{i=1}^nq_i=s),需要考虑的式子是下面这样的,不妨设它为 (g(n,s))
那么考虑新加入一个 (q_{n+1}=k),这个式子就变为
再设
容易发现
边界是 (f(0,0)=1,g(0,0)=0),这样我们就可以 (mathcal O(n^3)) DP 了。
期望得分 (20) 分。
算法二
我们仔细观察,设 (f(i,*),g(i,*)) 的生成函数分别为 (F_i(x),G_i(x)),那么我们有
那么就可以 (mathcal O(n^2log n)) FFT 了,常数有点大不太能过得去,可能要优化一下常数或者用些啥技巧。
(或者可能这档分压根就不是这么做的 qwq)
期望得分 (35sim 40) 分。假装它就是 (40) 吧。
算法三
所有 (a_i) 都一样的话,我们发现转移用到的生成函数也是一样的,因此不妨设
多项式乘法是有交换律和结合律的,简单推导可以得到
因为我们只需要 ([x^{n-2}]G_n(x)),我们可以多项式快速幂一下。
时间复杂度就是 (mathcal O(nlog n)) 或者 (mathcal O(nlog^2n))。
结合算法二可以获得 (60) 分。
算法四
剩下的部分就是一些牛逼(套路)操作了。
仔细观察,转移用到的生成函数除了 (a_i),其它部分都很相似,我们不妨设
那么有
简单推导可以得到
把 (G_n(x)) 的表达式写得好一点是
显然对于某个多项式 (F(x)),求 (sum_{i=1}^nF(a_ix)) 比求 (prod_{i=1}^nF(a_ix)) 容易得多,我们考虑先求 ln 再求 exp
整理一下,答案就是
现在的问题转化为,对于一个多项式 (F(x)),求 (sum_{i=1}^n F(a_ix))。
因为是求和,我们可以写成
那么现在的问题就是,对于每个 (i),求出 (sum_{j=1}^na_j^i)。
众所周知,(frac{1}{1-ax}=sum_{igeq0}a^ix^i),因此上面的问题可以有如下转化
这是个经典问题。因为问题规模不允许我们对于每个 (1-a_jx) 求逆后相加,所以我们考虑直接从分式入手。我们尝试分治这个和式,然后合并两边的分式的时候,就模拟分式通分后相加的过程。
这样能保证分治的时候,该区间的多项式次数为该区间长度,从而保证复杂度。
至此我们就解决了这个问题,时间复杂度 (mathcal O(nlog^2n+nlog m))。所以 (m) 其实可以出到 (10^{18})。
注意特判 (n=1),否则你会在 UOJ 上获得 97 分的好分数,别问我是怎么知道的。
#include <bits/stdc++.h>
template <class T>
inline void read(T &x)
{
static char ch;
while (!isdigit(ch = getchar()));
x = ch - '0';
while (isdigit(ch = getchar()))
x = x * 10 + ch - '0';
}
const int mod = 998244353;
inline int qpow(int x, int y)
{
int res = 1;
for (; y; y >>= 1, x = 1LL * x * x % mod)
if (y & 1)
res = 1LL * res * x % mod;
return res;
}
inline void add(int &x, const int &y)
{
x += y;
if (x >= mod)
x -= mod;
}
inline void dec(int &x, const int &y)
{
x -= y;
if (x < 0)
x += mod;
}
typedef std::vector<int> vi;
typedef std::pair<vi, vi> pvi;
#define mp(x, y) std::make_pair(x, y)
const int MaxN = 2e5 + 5;
const int INF = 0x3f3f3f3f;
int fac[MaxN], fac_inv[MaxN], pwm[MaxN], ind[MaxN];
inline void fac_init(int n)
{
ind[1] = 1;
for (int i = 2; i <= n; ++i)
ind[i] = 1LL * ind[mod % i] * (mod - mod / i) % mod;
fac[0] = 1;
for (int i = 1; i <= n; ++i)
fac[i] = 1LL * fac[i - 1] * i % mod;
fac_inv[n] = qpow(fac[n], mod - 2);
for (int i = n - 1; i >= 0; --i)
fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
}
namespace polynomial
{
int P, L;
int rev[MaxN];
inline void DFT_init(int n)
{
P = 0, L = 1;
while (L < n)
L <<= 1, ++P;
for (int i = 1; i < L; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (P - 1));
}
inline void DFT(vi &a, int n, int opt)
{
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
int g = opt == 1 ? 3 : (mod + 1) / 3;
for (int k = 1; k < n; k <<= 1)
{
int omega = qpow(g, (mod - 1) / (k << 1));
for (int i = 0; i < n; i += k << 1)
{
int x = 1;
for (int j = 0; j < k; ++j)
{
int u = a[i + j];
int v = 1LL * a[i + j + k] * x % mod;
add(a[i + j] = u, v);
dec(a[i + j + k] = u, v);
x = 1LL * x * omega % mod;
}
}
}
if (opt == -1)
{
int inv = ind[n];
for (int i = 0; i < n; ++i)
a[i] = 1LL * a[i] * inv % mod;
}
}
inline vi plus(vi a, vi b)
{
int sze = std::max(a.size(), b.size());
a.resize(sze), b.resize(sze);
for (int i = 0; i < sze; ++i)
add(a[i], b[i]);
return a;
}
inline vi mul(vi a, vi b, int lim = INF)
{
int sze = a.size() + b.size() - 1;
DFT_init(sze), a.resize(L, 0), b.resize(L, 0);
vi c(L);
DFT(a, L, 1), DFT(b, L, 1);
for (int i = 0; i < L; ++i)
c[i] = 1LL * a[i] * b[i] % mod;
DFT(c, L, -1);
return c.resize(std::min(sze, lim)), c;
}
inline vi inverse(vi a)
{
int n = a.size(), m = 1;
vi b(1, qpow(a[0], mod - 2)), ta;
while (m < n)
{
m <<= 1;
DFT_init(m << 1);
b.resize(L, 0);
(ta = a).resize(m);
ta.resize(L, 0);
DFT(b, L, 1), DFT(ta, L, 1);
for (int i = 0; i < L; ++i)
b[i] = 1LL * b[i] * (mod + 2 - 1LL * ta[i] * b[i] % mod) % mod;
DFT(b, L, -1);
b.resize(m, 0);
}
return b.resize(n), b;
}
inline vi derivative(vi a)
{
vi res(0);
for (int i = 1, lim = a.size(); i < lim; ++i)
res.push_back(1LL * i * a[i] % mod);
return res;
}
inline vi anti_derivative(vi a)
{
vi res(1, 0);
for (int i = 0, lim = a.size(); i < lim; ++i)
res.push_back(1LL * a[i] * ind[i + 1] % mod);
return res;
}
inline vi ln(vi a)
{
return anti_derivative(mul(derivative(a), inverse(a), a.size() - 1));
}
inline vi exp(vi a)
{
int n = a.size(), m = 1;
vi b(1, 1), ta;
while (m < n)
{
m <<= 1;
b.resize(m, 0);
vi ln_b = ln(b);
(ta = a).resize(m);
add(ta[0], 1);
for (int i = 0; i < m; ++i)
dec(ta[i], ln_b[i]);
b = mul(b, ta, m);
}
return b.resize(n), b;
}
}
vi sum;
int n, m;
int a[MaxN];
inline pvi solve(int l, int r)
{
using namespace polynomial;
if (l == r)
{
vi t(1, 1); t.push_back(mod - a[l]);
return mp(vi(1, 1), t);
}
int mid = (l + r) >> 1;
pvi lef = solve(l, mid), rit = solve(mid + 1, r);
return mp(plus(mul(lef.first, rit.second), mul(rit.first, lef.second)), mul(lef.second, rit.second));
}
inline vi get_sum(vi a)
{
vi res(0); int n = a.size();
for (int i = 0; i < n; ++i)
res.push_back(1LL * a[i] * sum[i] % mod);
return res;
}
int main()
{
read(n), read(m), fac_init(MaxN - 1);
for (int i = 0; i <= (n << 1); ++i)
pwm[i] = qpow(i, m);
int prod = 1;
for (int i = 1; i <= n; ++i)
{
read(a[i]);
prod = 1LL * prod * a[i] % mod;
}
if (n == 1)
return puts(m ? "0" : "1"), 0;
using namespace polynomial;
pvi t = solve(1, n);
sum = mul(t.first, inverse(t.second), n - 1);
vi A(0), B(0);
for (int i = 0; i < n - 1; ++i)
{
A.push_back(1LL * pwm[i + 1] * fac_inv[i] % mod);
B.push_back(1LL * pwm[i + 1] * pwm[i + 1] % mod * fac_inv[i] % mod);
}
B = get_sum(mul(B, inverse(A), n - 1));
A = exp(get_sum(ln(A)));
int res = mul(A, B)[n - 2];
std::cout << 1LL * fac[n - 2] * prod % mod * res % mod << '
';
return 0;
}