Problem
给定一个大小为 $n$ 的图，求所有生成树权值和的 $k$ 次方之和。
Sol
经典题。把边权 $w$ 设成 $e^{wx}$ 即可：一棵生成树的边权乘积为 $e^{Sx}$（$S$ 为该树的边权和），其 $x^k$ 项系数恰为 $S^k/k!$。因此最终答案为 $k!\,[x^k]A(x)$，其中 $A(x)$ 为（矩阵树定理）求行列式得到的多项式，只需保留到 $x^k$ 项。复杂度 $\mathcal O(n^3k^2)$。
2020 年省选的 D2T3 也用到该 trick，那里只需维护 $1+wx$（即 $e^{wx}$ 展开式的前两项）即可。
Code
#include <bits/stdc++.h>
using std::vector;
using std::min;
using std::max;

// Truncated power series: poly[i] is the coefficient of x^i, taken mod P.
using poly = vector<int>;

const int N = 35;          // capacity of the global arrays (indices 0..N-1)
const int P = 998244353;   // prime modulus
int n, k;                  // n: vertex count read in main; k: target power (series kept up to x^k)
int fac[N], ifac[N];       // factorials 0!..k! and their inverses, filled in main

// Binary exponentiation: returns base^e mod P.
// Intended for e >= 0 (e == 0 yields 1).
int qpow(int base, int e) {
    long long result = 1;
    long long sq = base;
    while (e != 0) {
        if (e & 1) result = result * sq % P;
        sq = sq * sq % P;
        e >>= 1;
    }
    return static_cast<int>(result);
}
// Global series storage (1-indexed in every loop that touches it):
//   e[i][j] : exp(w) for the weight w read at row i, column j (see main)
//   d[j]    : sum over i of e[i][j]
//   a[i][j] : matrix assembled by Mtree and reduced in place by Det
poly a[N][N];
poly d[N];
poly e[N][N];
// Normalize a series in place: strip trailing zero coefficients (keeping at
// least one term when `a` is non-empty), then cap the length at k + 1 so
// only the coefficients of x^0..x^k remain.
void fix(poly &a) {
    int len = static_cast<int>(a.size());
    while (len > 1 && a[len - 1] == 0) --len;
    const int cap = k + 1;
    a.resize(len < cap ? len : cap);
}
// Coefficient-wise sum a + b (mod P), normalized with fix().
poly operator + (poly a, poly b) {
    if (a.size() < b.size()) a.resize(b.size());
    // Combine the overlapping prefix; terms of `a` beyond b.size() are kept as-is.
    std::transform(b.begin(), b.end(), a.begin(), a.begin(),
                   [](int rhs, int lhs) { return (lhs + rhs) % P; });
    fix(a);
    return a;
}
// Coefficient-wise difference a - b (mod P), normalized with fix().
// NOTE(review): the "+ P" keeps results non-negative only when both operands'
// coefficients are already in [0, P).
poly operator - (poly a, poly b) {
    a.resize(max(a.size(), b.size()));
    std::size_t idx = 0;
    for (const int coef : b) {
        a[idx] = (a[idx] - coef + P) % P;
        ++idx;
    }
    fix(a);
    return a;
}
// Product of two series (schoolbook convolution, coefficients mod P),
// normalized with fix() so only the coefficients of x^0..x^k are kept.
// An empty operand is treated as the zero series.
poly operator * (poly a, poly b) {
    // Guard: for two empty operands a.size() + b.size() - 1 wraps around to
    // SIZE_MAX and the allocation below would throw.
    if (a.empty() || b.empty()) return poly{0};
    poly c(a.size() + b.size() - 1);
    for (std::size_t i = 0; i < a.size(); i++)
        for (std::size_t j = 0; j < b.size(); j++)
            c[i + j] = (c[i + j] + 1LL * a[i] * b[j]) % P;
    fix(c);
    return c;
}
// Inverse of a power series modulo x^(k+1): returns b with a * b == 1 + O(x^(k+1)).
// Requires a[0] != 0 (mod P). Otherwise qpow(0, P - 2) is 0 and the result is
// silently the zero series.
// Recurrence: b[0] = a[0]^(-1), b[i] = -b[0] * sum_{j=1..i} a[j] * b[i-j].
poly operator ~ (poly a) {
    a.resize(k + 1);  // absent high-order coefficients are treated as 0
    poly inv(k + 1);
    inv[0] = qpow(a[0], P - 2);
    for (int i = 1; i <= k; ++i) {
        long long acc = 0;  // accumulates -sum_{j=1..i} a[j] * inv[i-j], kept in [0, P)
        for (int j = 1; j <= i; ++j)
            acc = (acc + P - 1LL * a[j] * inv[i - j] % P) % P;
        inv[i] = static_cast<int>(1LL * inv[0] * acc % P);
    }
    return inv;
}
// Series of e^(x * z) truncated after z^k: entry i is x^i / i! (mod P),
// i.e. x^i * ifac[i]. Requires ifac[0..k] to be filled (done in main).
poly exp(int x) {
    // Reduce into [0, P) first. A negative weight would otherwise produce
    // negative coefficients that propagate through +, -, * and can surface
    // as a negative final answer. Weights already in [0, P) are unaffected.
    x %= P;
    if (x < 0) x += P;
    poly a(k + 1);
    for (int i = 0, t = 1; i <= k; i++) {
        a[i] = 1LL * t * ifac[i] % P;  // t == x^i mod P here
        t = 1LL * t * x % P;
    }
    return a;
}
// Determinant of the n x n matrix held in the global a[1..n][1..n], computed
// over power series truncated after x^k (coefficients mod P). Overwrites a.
//
// Fraction-free-style Gaussian elimination: a series is invertible mod x^(k+1)
// exactly when its constant term is non-zero, so pivots are chosen by a[r][i][0].
// NOTE(review): if no remaining row has a non-zero constant term in column i,
// this returns 0. That only establishes that the determinant's constant term is
// 0, not its higher coefficients. With the caller in this file it cannot happen:
// every e[i][j] = exp(w) has constant term 1, so the constant-term matrix is the
// reduced Laplacian of the complete graph, whose determinant (its spanning-tree
// count) is non-zero mod P.
poly Det(int n) {
    poly res = {1};
    for (int i = 1; i <= n; i++) {
        if (!a[i][i][0]) {
            // Find the first row below whose entry in column i is invertible.
            int r = i + 1;
            while (r <= n && !a[r][i][0]) r++;
            if (r > n) return poly{0};
            // Columns < i are never read again, so only swap from column i on.
            for (int c = i; c <= n; c++)
                std::swap(a[i][c], a[r][c]);
            res = poly{0} - res;  // a row swap flips the determinant's sign
        }
        res = res * a[i][i];
        // The pivot row is not modified below, so its inverse is computed once.
        const poly inv = ~a[i][i];
        for (int j = i + 1; j <= n; j++) {
            const poly factor = a[j][i] * inv;
            // `c` (not `k`) so the global truncation order k stays unshadowed.
            for (int c = i; c <= n; c++)
                a[j][c] = a[j][c] - factor * a[i][c];
        }
    }
    return res;
}
int Mtree(int n) {
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
a[i][j] = (i == j ? d[i] : poly{0}) - e[i][j];
return 1LL * fac[k] * Det(n)[k] % P;
}
// Reads n and k, then an n x n matrix of integer weights (row-major).
// Prints k! * [x^k] of the matrix-tree determinant, mod P.
int main() {
    scanf("%d%d", &n, &k);

    // Factorials 0!..k!, then their inverses: ifac[k] via Fermat, the rest by
    // ifac[i - 1] = ifac[i] * i.
    fac[0] = 1;
    for (int i = 1; i <= k; ++i)
        fac[i] = static_cast<int>(1LL * fac[i - 1] * i % P);
    ifac[k] = qpow(fac[k], P - 2);
    for (int i = k; i >= 1; --i)
        ifac[i - 1] = static_cast<int>(1LL * ifac[i] * i % P);

    // Each weight w becomes the series exp(w); d[col] accumulates column sums.
    for (int row = 1; row <= n; ++row) {
        for (int col = 1; col <= n; ++col) {
            int w;
            scanf("%d", &w);
            e[row][col] = exp(w);
            d[col] = d[col] + e[row][col];
        }
    }

    const int answer = Mtree(n - 1);
    printf("%d", answer);
    return 0;
}