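// 比赛时推出来了没写，血亏，赛后补上。
// (Worked this out during the contest but didn't code it in time -- a costly miss; implemented afterwards.)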
#include <bits/stdc++.h>
using namespace std;
#define LL long long
// Capacity of the NTT scratch arrays (Poly::r and the global buffers a, b);
// every transform length used by Poly::Conv must stay below this bound.
const int maxn = 2100000;
// Binary exponentiation: returns b^n mod MOD for n >= 0.
// Special case: every value is congruent to 0 modulo 1, so MOD == 1 yields 0.
LL qpow(LL b, LL n, LL MOD) {
    if (MOD == 1) return 0;
    LL result = 1;
    LL base = b % MOD;
    // Consume the exponent bit by bit, squaring the base at every step.
    for (; n != 0; n >>= 1) {
        if (n & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}
// NTT-friendly prime modulus: P = 119 * 2^23 + 1.
const LL P = 998244353;
// Generator used for the forward NTT roots of unity (3 is a primitive root mod P).
const LL G = 3;
// Inverse of G mod P (3 * Gi == P + 1), used for the inverse NTT roots of unity.
const LL Gi = 332748118;
namespace Poly {
// Bit-reversal permutation for the current transform length (filled by Conv).
int r[maxn];
// L = log2(limit); limit = current transform length (a power of two). Both are set by Conv.
int L, limit;

// Modular inverse modulo the prime P via Fermat's little theorem: x^(P-2) mod P.
LL pinv(LL x) { return qpow(x, P - 2, P); }

// Number-theoretic transform over Z/PZ, in place on A[0, limit).
// type == 1: forward transform; type == -1: inverse transform (including the 1/limit scaling).
// Requires r[] and limit to be set up for the current length (Conv does this).
// Entries of A are expected to lie in [0, P); outputs are reduced to [0, P).
void NTT(LL* A, int type) {
    for (int i = 0; i < limit; i++)
        if (i < r[i]) std::swap(A[i], A[r[i]]);
    for (int mid = 1; mid < limit; mid <<= 1) {
        // Root of unity of order 2*mid (its inverse for the inverse transform).
        LL Wn = qpow(type == 1 ? G : Gi, (P - 1) / (mid << 1), P);
        for (int j = 0; j < limit; j += (mid << 1)) {
            LL w = 1;
            for (int k = 0; k < mid; k++, w = w * Wn % P) {
                // Keep the butterfly operands in LL rather than narrowing A[] entries to int.
                LL x = A[j + k];
                LL y = w * A[j + k + mid] % P;
                A[j + k] = (x + y) % P;
                A[j + k + mid] = (x - y + P) % P;
            }
        }
    }
    if (type == 1) return;
    LL inv_limit = pinv(limit);
    for (int i = 0; i < limit; ++i)
        A[i] = A[i] * inv_limit % P;
}

// Polynomial multiplication modulo P: c = a * b.
// a holds N coefficients (degree N-1), b holds M coefficients (degree M-1);
// the product's N + M - 1 coefficients are written to c[0, N + M - 1).
// Side effects: sets L, limit and r[]; zero-pads a and b up to `limit` and overwrites
// them with their forward transforms. c may alias a or b, but a and b must be distinct
// buffers (each is transformed in place). All buffers need room for `limit` entries,
// where limit is the smallest power of two greater than N + M (and must not exceed maxn).
void Conv(LL* a, int N, LL* b, LL M, LL* c) {
    L = 0; limit = 1;
    while (limit <= N + M) limit <<= 1, L++;
    for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    // Clear the padding ourselves so stale buffer contents cannot leak into the product.
    std::fill(a + N, a + limit, 0LL);
    std::fill(b + M, b + limit, 0LL);
    NTT(a, 1); NTT(b, 1);
    for (int i = 0; i < limit; i++) c[i] = a[i] * b[i] % P;
    NTT(c, -1);
}
}
// Heap entry: refers to the polynomial stored in vec[id], which has `len` coefficients.
struct node { int len, id; };
// Ordering for std::priority_queue that puts the shortest polynomial on top (min-heap on len).
// operator() is const so the comparator can be invoked through a const object, which
// standard containers and algorithms are permitted to do.
struct cmp { bool operator()(const node& a, const node& b) const { return a.len > b.len; } };
// Heap of polynomial indices still to be multiplied; shortest polynomial on top.
priority_queue<node, vector<node>, cmp> Q;
// vec[i]: coefficients (mod P) of the polynomial built for the i-th cycle (indices 1..m).
vector<LL> vec[100010];
// vis[i]: element i has already been reached by a cycle walk.
bool vis[100010];
// nxt[i]: the input successor of i.
int nxt[100010];
// num[c]: number of elements found in the c-th cycle.
int num[100010];
// Modular inverses, factorials and inverse factorials mod P (filled by Init).
int inv[100010];
int fact[100010];
int finv[100010];
// Scratch buffers for NTT-based multiplication.
LL a[maxn];
LL b[maxn];
// n: number of input elements; m: number of cycles found.
int n;
int m;
// Precomputes, for k in [0, 100000]:
//   fact[k] = k! mod P, finv[k] = (k!)^{-1} mod P, and inv[k] = k^{-1} mod P for k >= 1.
// Inverses use the linear recurrence inv[k] = -(P / k) * inv[P mod k] (mod P),
// written with the non-negative multiplier (P - P / k).
void Init() {
    fact[0] = finv[0] = 1;
    inv[1] = fact[1] = finv[1] = 1;
    for (int k = 2; k <= 100000; ++k) {
        inv[k] = (P - P / k) * inv[P % k] % P;
        fact[k] = 1LL * fact[k - 1] * k % P;
        finv[k] = 1LL * finv[k - 1] * inv[k] % P;
    }
}
// Binomial coefficient C(n, m) mod P from the factorial tables built by Init()
// (which cover indices up to 100000). Returns 0 when m lies outside [0, n].
LL C(LL n, LL m) {
    if (0 <= m && m <= n) {
        const LL numerator = fact[n];
        const LL denominatorInv = 1LL * finv[m] * finv[n - m] % P;
        return numerator * denominatorInv % P;
    }
    return 0;
}
// Multiplies the polynomials stored in vec[u] and vec[v]; the product replaces vec[u]
// (vec[v] is left as is). Uses the global scratch buffers a and b.
void Convolution(int u, int v) {
    const int lenU = vec[u].size();
    const int lenV = vec[v].size();
    // Clear as much of each buffer as the transform length Poly::Conv will choose.
    int span = 1;
    while (span <= lenU + lenV) span <<= 1;
    std::fill(a, a + span, 0);
    std::fill(b, b + span, 0);
    std::copy(vec[u].begin(), vec[u].end(), a);
    std::copy(vec[v].begin(), vec[v].end(), b);
    Poly::Conv(a, lenU, b, lenV, a);
    // The product of a (lenU terms) and b (lenV terms) has lenU + lenV - 1 terms.
    vec[u].assign(a, a + lenU + lenV - 1);
}
// Multiplies all per-cycle polynomials together, always combining the two shortest first,
// then evaluates  ans = sum_i (-1)^i * (n - i)! * g[i]  (mod P),  where g is the product.
// Consumes the heap Q (one entry remains afterwards). Result is in [0, P).
LL solve() {
    // No polynomials at all: the product is the constant 1, so the sum reduces to n!.
    // (Guards against calling Q.top() on an empty heap, which is undefined behaviour.)
    if (Q.empty()) return fact[n];
    while (Q.size() > 1) {
        int u = Q.top().id; Q.pop();
        int v = Q.top().id; Q.pop();
        Convolution(u, v);
        // Plain aggregate initialization; the (node){...} compound literal is a GNU extension.
        node merged = { static_cast<int>(vec[u].size()), u };
        Q.push(merged);
    }
    const vector<LL>& g = vec[Q.top().id];
    const int terms = static_cast<int>(g.size());
    LL ans = 0;
    for (int i = 0; i < terms; ++i) {
        const LL term = 1LL * fact[n - i] * g[i] % P;
        // Alternate signs; ans is kept in [0, P) throughout.
        ans = (i & 1) ? (ans - term + P) % P : (ans + term) % P;
    }
    return ans;
}
int main() {
Init();
scanf("%d", &n);
for (int i = 1;i <= n;++i)
scanf("%d", &nxt[i]);
for (int i = 1;i <= n;++i) {
if (vis[i]) continue;
int u = i; num[++m] = 1;
vis[u] = true;
while (!vis[nxt[u]]) { u = nxt[u]; vis[u] = true; ++num[m]; }
}
for (int i = 1;i <= m;++i) {
vec[i].resize(num[i] + 1);
for (int j = 0;j <= num[i];++j)
vec[i][j] = C(num[i], j);
vec[i][num[i]] = (vec[i][num[i]] + P - 1) % P;
Q.push((node) { (int)vec[i].size(), i });
}
printf("%lld\n", solve());
return 0;
}