Address
Solution
记 (ans_i) 表示在给定的 (13) 张牌以外,再选出 (i) 张牌,使得这 (13+i) 张牌不存在 胡的子集的方案数,那么答案就是 ((frac{1}{(4n-13)!}sum_{i=1}^{4n-13}i!(4n-13-i)!ans_i)+1)
接下来,考虑给你一个牌的集合,怎么判断它是否存在一个胡的子集。
首先判断胡的第二个条件:记 (a_i) 表示集合中有多少张第 (i) 种牌。若 (sum [a_ige 2]ge 7),则存在胡的子集。
再判断第一个条件:考虑 (dp)。记 (f_{i,j,k}) 表示考虑前 (i) 种牌,拿走 (j) 对 ((i-1,i)),拿走 (k) 个 (i),剩下的牌最多能组成多少个面子。(注意 (j) 个 ((i-1,i)),(k) 个 (i) 拿出来必须跟后面的牌组成面子)特殊地,(f_{i,j,k}=-1) 表示不存在这种状态。
记 (g_{i,j,k}) 表示考虑前 (i) 种牌,拿走 (j) 对 ((i-1,i)),拿走 (k) 个 (i),再拿走一个对子,剩下的牌最多能组成多少个面子。
考虑到 (3) 个相同的顺子(形如 (x,x+1,x+2))可以变成 (3) 个相同的刻子 (形如 (x,x,x)),因此 (j,kin[0,2])。
记集合中最大的牌为 (m),如果存在 (g_{m,j,k}ge 4),那么存在胡的子集。
考虑转移,枚举加入 (x) 张大小为 (i+1) 的牌,枚举拿走 (h) 张 (i+1),那么要组成 (k) 对 ((i,i+1)),组成 (j) 对 ((i-1,i,i+1)),再枚举要不要拿走 (i+1) 当对子,有:
考虑建一个自动机,自动机上的每一个节点对应一些不存在胡的子集的集合。每个节点都记录信息:(f_{m,j,k},g_{m,j,k},cnt)。(f_{m,j,k},g_{m,j,k},cnt) 都相同的集合对应同一个节点,注意 (m) 可以不同,所以只要记 (f_{j,k},g_{j,k},cnt)。节点之间的转移边权 (x) 表示加入 (x) 张大小为 (m+1) 的牌。
初始节点:(f_{0,0}=cnt=0),其它为 (-1)。考虑用 (dfs) 构造自动机,枚举加入 (x(x∈[0,4])) 张新牌转移即可。转移可能成环,扩展出重复状态要剪枝。(dfs) 后可得节点数为 (2091)。
记 (ch_{x,y}) 表示节点 (x) 走转移边 (y) 到达的节点,(dp_{i,j,k}) 表示考虑前 (i) 种牌,总共取走 (j) 张,目前走到自动机上的节点 (k) 的方案数。枚举第 (i) 张牌取了 (h) 张,有:$$dp_{i,j+h,ch_{k,h}}+=dp_{i-1,j,k}*c_{4-b_i}^{h-b_i}$$
其中 (b_i) 表示给定的 (13) 张牌中,有多少张大小为 (i) 的牌。
(dp) 要使用滚动数组,时间复杂度 (O(2091×n^2))。
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int e = 505, o = 3005, mod = 998244353;
struct point
{
int cnt, f[3][3], g[3][3];
inline bool check()
{
if (cnt >= 7) return 1;
for (int i = 0; i <= 2; i++)
for (int j = 0; j <= 2; j++)
if (g[i][j] >= 4) return 1;
return 0;
}
inline point trans(int x)
{
point a;
int i, j, k;
a.cnt = min(7, cnt + (x >= 2));
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
a.f[i][j] = a.g[i][j] = -1;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
{
if (f[i][j] != -1)
{
for (k = 0; i + j + k <= x && k <= 2; k++)
a.f[j][k] = max(a.f[j][k], f[i][j] + i + (x - i - j - k >= 3));
for (k = 0; i + j + k <= x - 2; k++)
a.g[j][k] = max(a.g[j][k], f[i][j] + i);
}
if (g[i][j] != -1)
{
for (k = 0; i + j + k <= x && k <= 2; k++)
a.g[j][k] = max(a.g[j][k], g[i][j] + i + (x - i - j - k >= 3));
}
}
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
a.f[i][j] = min(a.f[i][j], 4), a.g[i][j] = min(a.g[i][j], 4);
return a;
}
};
inline bool operator < (point a, point b)
{
if (a.cnt != b.cnt) return a.cnt < b.cnt;
int i, j;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
{
if (a.f[i][j] != b.f[i][j]) return a.f[i][j] < b.f[i][j];
if (a.g[i][j] != b.g[i][j]) return a.g[i][j] < b.g[i][j];
}
return 0;
}
inline bool operator == (point a, point b)
{
if (a.cnt != b.cnt) return 0;
int i, j;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
{
if (a.f[i][j] != b.f[i][j]) return 0;
if (a.g[i][j] != b.g[i][j]) return 0;
}
return 1;
}
map<point, int> id;
int cnt, fac[e], inv[e], ch[o][6], n, m, a[e], dp[2][e][o], ans;
inline void dfs(point a)
{
int x = id[a];
for (int i = 0; i <= 4; i++)
{
point b = a.trans(i);
if (b.check()) continue;
int y = id[b];
if (y) ch[x][i] = y;
else
{
id[b] = ++cnt;
ch[x][i] = cnt;
dfs(b);
}
}
}
inline int ksm(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1) res = (ll)res * x % mod;
y >>= 1;
x = (ll)x * x % mod;
}
return res;
}
inline void init()
{
point s;
s.cnt = 0;
int i, j;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
s.f[i][j] = s.g[i][j] = -1;
s.f[0][0] = 0;
id[s] = cnt = 1;
dfs(s);
}
inline void add(int &x, int y)
{
(x += y) >= mod && (x -= mod);
}
inline int c(int x, int y)
{
return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}
inline void prepare()
{
int i;
fac[0] = 1;
for (i = 1; i <= m; i++) fac[i] = (ll)fac[i - 1] * i % mod;
inv[m] = ksm(fac[m], mod - 2);
for (i = m - 1; i >= 0; i--) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;
}
int main()
{
freopen("mahjong.in", "r", stdin);
freopen("mahjong.out", "w", stdout);
read(n); m = n << 2;
int i, j, k, h, x, y, sum = 0;
init(); prepare();
for (i = 1; i <= 13; i++) read(x), read(y), a[x]++;
dp[0][0][1] = 1;
for (i = 1; i <= n; i++)
{
int nxt = i & 1, lst = nxt ^ 1;
for (j = 0; j <= sum + 4; j++)
for (k = 1; k <= cnt; k++)
dp[nxt][j][k] = 0;
for (j = 0; j <= sum; j++)
for (k = 1; k <= cnt; k++)
if (dp[lst][j][k])
{
int v = dp[lst][j][k];
for (h = a[i]; h <= 4; h++)
if (ch[k][h])
dp[nxt][j + h][ch[k][h]] = (dp[nxt][j + h][ch[k][h]] + (ll)v
* c(4 - a[i], h - a[i])) % mod;
}
sum += 4;
}
for (i = 1; i <= m - 13; i++)
{
sum = 0;
for (j = 1; j <= cnt; j++) add(sum, dp[n & 1][13 + i][j]);
ans = (ans + (ll)sum * fac[i] % mod * fac[m - 13 - i]) % mod;
}
ans = (ll)ans * inv[m - 13] % mod;
add(ans, 1);
cout << ans << endl;
fclose(stdin);
fclose(stdout);
return 0;
}