Sky Full of Stars
题目链接:http://codeforces.com/problemset/problem/997/C
数据范围:略。
题解:
首先考虑拟对象,如果至少有一行完全相等即可。
这个的答案就需要多步容斥:$sumlimits_{i = 1} ^ n (-1)^{i + 1}cdot 3 ^ icdot 3 ^ {n cdot (n - i)}$。
那么至少有一列的答案跟这个一样。
把他俩加一起就是答案么?我们需要减去什么?
显然,需要减掉至少有一行且至少有一列的。
这个怎么弄?
是这样的,如果我们钦定了$i$行$j$列都必须相等之后,$x$行$y$列相等这种情况会被算$C_{x} ^ {i} imes C_{y} ^ {j}$次。
假设,$F_{(i,j)}$表示$i$行$j$列都相等的答案。
假设$A_{(i,j)}$表示$F_{(i,j)}$的容斥系数,发现当$A_{(i,j)} = (-1)^{i + j + 1}$时满足题意。
故此,至少有一列切至少有一行的答案是:
$sumlimits_{i = 1} ^ nsumlimits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} cdot 3cdot 3^{(n - i) imes (n - j)}$。
这个怎么看都是$O(n^2)$的对不对....
我们考虑把他转化转化:
$=sumlimits_{i = 1} ^ nsumlimits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} cdot 3^{(n - i) imes (n - j) + 1}$
$=sumlimits_{i = 1} ^ nsumlimits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} cdot 3^{n^2 - in - jn +ij + 1}$
$=sumlimits_{i = 1} ^ nsumlimits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} cdot 3^{n^2}cdot 3^{-in}cdot 3^{-jn}cdot 3^{ij}$
$=-3^{n^2}cdot sumlimits_{i = 1} ^ n (-1) ^ icdot C_{n} ^ {i}cdot 3^{-in} sumlimits_{j = 1} ^ n (-1) ^ j C_{n}^{j} cdot 3^{-jn}cdot 3^{ij}$
$=-3^{n^2}cdot sumlimits_{i = 1} ^ n (-1) ^ icdot C_{n} ^ {i}cdot 3^{-in} sumlimits_{j = 1} ^ n (-1) ^ j C_{n}^{j} cdot (3^{-n})^jcdot (3^i)^j$
$=-3^{n^2}cdot sumlimits_{i = 1} ^ n (-1) ^ icdot C_{n} ^ {i}cdot 3^{-in} sumlimits_{j = 1} ^ n C_{n}^{j} cdot (-3^{i - n})^j$
发现,第二个$sum$后面的式子类似于二项式反演,$(-3^{i - n})$相当于$(a-b)^n$中的$b$。
故此式子可以被化简成:
$=-3^{n^2}cdot sumlimits_{i = 1} ^ n (-1) ^ icdot C_{n} ^ {i}cdot 3^{-in} cdot (1 - (-3^{i - n}))^n$。
这个就能$O(n)$求了。
代码:
#include <bits/stdc++.h>
#define setIO(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)
#define N 1000010
using namespace std;
typedef long long ll;
const int mod = 998244353 ;
int fac[N], inv[N];
int qpow(int x, int y) {
int ans = 1;
while (y) {
if (y & 1) {
ans = (ll)ans * x % mod;
}
y >>= 1;
x = (ll)x * x % mod;
}
return ans;
}
inline int C(int x, int y) {
return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}
int main() {
// setIO("c");
int n;
cin >> n ;
fac[0] = inv[0] = 1;
for (int i = 1; i <= n; i ++ ) {
fac[i] = (ll)fac[i - 1] * i % mod;
inv[i] = qpow(fac[i], mod - 2);
}
int ans = 0;
for (int i = 1; i <= n; i ++ ) {
ans = (ans + (ll)qpow(3, ((ll)n * (n - i) % (mod - 1) + i) % (mod - 1))
* qpow(mod - 1, i + 1) % mod
* C(n, i) % mod) % mod;
}
ans = ans * 2 % mod;
int mdl = 0;
for (int i = 0; i < n; i ++ ) {
int t = (mod - qpow(3, i)) % mod;
mdl = (mdl + (ll)C(n, i)
* qpow(mod - 1, i + 1) % mod
* ( ( (ll) qpow(t + 1, n) + mod - qpow(t, n) ) % mod) % mod) % mod;
}
ans = (ans + (ll)mdl * 3) % mod;
cout << ans << endl ;
fclose(stdin), fclose(stdout);
return 0;
}