直接求不好求,考虑用总方案数减去不合法方案数。设 (f_{i,j}) 为至少有 (i) 行 (j) 列颜色相同的方案数,设 (g_{i,j}) 为恰好有 (i) 行 (j) 列颜色相同的方案数,得:
[large f(x,y)=sum_{i=x}^nsum_{j=y}^ninom{i}{x}inom{i}{y}g(i,j)
]
然后发现其可以高维二项式反演,高维反演即为每个维度的反演系数直接乘起来,进行高维二项式反演得:
[large g(x,y)=sum_{i=x}^nsum_{j=y}^n(-1)^{i+j-x-y}inom{i}{x}inom{i}{y}f(i,j)
]
即答案为:
[large ans=3^{n^2}-g(0,0)
]
考虑计算 (f),得:
[large f(i,j) =
egin{cases}
inom{n}{i}inom{n}{j}3^{(n-i)(n-j)+1} &ij
ot = 0 \
\
inom{n}{i+j}3^{n(n-i-j)+i+j} &ij=0,i+j
ot =0 \
\
3^{n^2} &i+j=0
end{cases}\
]
当行列都存在相同颜色时,因为行列有交叉部分,所以其只能取同一种颜色,而只有行存在相同颜色或只有列存在相同颜色时,其颜色的选取时不受限制的。
然后计算 (g(0,0)):
[largeegin{aligned}
g(0,0)&=sum_{i=0}^nsum_{j=0}^n(-1)^{i+j}f(i,j) \
&=sum_{i=1}^nsum_{j=1}^n(-1)^{i+j}f(i,j)+2sum_{i=1}^n(-1)^if(i,0)+f_{0,0}
end{aligned} \
]
因为第一个和式的原因,直接计算的复杂度是 (O(n^2)) 的,无法接受,考虑进行化简:
[largeegin{aligned}
&sum_{i=1}^nsum_{j=1}^n(-1)^{i+j}f(i,j) \
=&sum_{i=1}^nsum_{j=1}^n(-1)^{i+j}inom{n}{i}inom{n}{j}3^{(n-i)(n-j)+1} \
=&sum_{i=1}^n(-1)^{i}inom{n}{i}sum_{j=1}^n(-1)^{j}inom{n}{j}3^{n^2-(i+j)n+ij+1} \
=&3^{n^2+1}sum_{i=1}^n(-1)^{i}inom{n}{i}3^{-in}sum_{j=1}^ninom{n}{j}(-3^{i-n})^j \
=&3^{n^2+1}sum_{i=1}^n(-1)^{i}inom{n}{i}3^{-in}((1-3^{i-n})^n-1) \
end{aligned}\
]
最终答案为:
[large ans=3^{n^2}-g(0,0)=-3^{n^2+1}sum_{i=1}^n(-1)^{i}inom{n}{i}3^{-in}((1-3^{i-n})^n-1)-2sum_{i=1}^n(-1)^iinom{n}{i}3^{n(n-i)+i}
]
#include<bits/stdc++.h>
#define maxn 1000010
#define p 998244353
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
ll n,v1,v2;
ll fac[maxn],ifac[maxn];
ll qp(ll x,ll y)
{
ll v=1;
while(y)
{
if(y&1) v=v*x%p;
x=x*x%p,y>>=1;
}
return v;
}
ll C(int n,int m)
{
return fac[n]*ifac[m]%p*ifac[n-m]%p;
}
void init()
{
fac[0]=ifac[0]=1;
for(int i=1;i<=n;++i) fac[i]=fac[i-1]*i%p;
ifac[n]=qp(fac[n],p-2);
for(int i=n-1;i;--i) ifac[i]=ifac[i+1]*(i+1)%p;
}
int main()
{
read(n),init();
for(int i=1;i<=n;++i)
{
ll v=C(n,i)*qp(qp(3,i*n),p-2)%p*(qp((1-qp(qp(3,n-i),p-2)+p)%p,n)-1+p)%p;
if(i&1) v1=(v1-v+p)%p;
else v1=(v1+v+p)%p;
}
for(int i=1;i<=n;++i)
{
ll v=C(n,i)*qp(3,n*(n-i)%(p-1)+i)%p;
if(i&1) v2=(v2-v+p)%p;
else v2=(v2+v+p)%p;
}
printf("%lld",((-qp(3,(n*n+1)%(p-1))*v1%p+p)%p-2*v2%p+p)%p);
return 0;
}