题目大意
题解
怎么又不是正解啊
考虑算重的情况:
有一个格子(i,j),(i,1..j)和(1..i-1,j)刚好被算了一次,横竖就可以在(i,j)上有两种放法
硬点一下,当第i行选了ki时(i,ki +1)不能被竖列放,这样就不会算重
把每一列的生成函数搞出来是这样:
(A(x)=sum frac{n+1-i}{i!}x^i)
最后(A(x)[x^i])表示有i行确定,那么就有n-i行刚好放了m(要除(n-i)!),所以答案就是
(ans=sum A^m(x)[x^i]/(n-i)!)
(A^m(x))可以快速幂求,但是(应该)过不了
这个i!看着就很EGF,用泰勒公式搞♂一下
泰勒公式:(e^x=1+x+frac{x^2}{2!}+frac{x^3}{3!}+...=sum frac{x^i}{i!})
(A(x)=sum frac{n+1-i}{i!}x^i)
(=sum frac{n+1}{i!}x^i-sum frac{i}{i!}x^i)
(=(n+1)e^x-xsum_{i<n} frac{x^i}{i!})
(=(n+1)e^x-xe^x)
(=(n+1-x)e^x)
那么(A^m(x))就是
(A^m(x)=(n+1-x)^me^{mx})
左边二项式展开,右边泰勒展开,卷一下即可
简单又自然
code
#include <bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define C(n,m) (jc[n]*Jc[m]%998244353*Jc[(n)-(m)]%998244353)
#define min(a,b) (a<b?a:b)
#define mod 998244353
#define Mod 998244351
#define ll long long
#define G 3
//#define file
using namespace std;
ll A[1048576],B[1048576],a[1048576],b[1048576],w[500001],jc[500001],Jc[500001],ans;
int N,len,n,m,i,j,k,l;
ll qpower(ll a,int b) {ll ans=1;while (b) {if (b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;} return ans;}
void swap(int &x,int &y) {int z=x;x=y;y=z;}
ll dft(ll *a,int type)
{
int i,j,k,l,S=N,s1=2,s2=1;
fo(i,0,N-1)
{
j=i;k=0;
fo(l,1,len) k=k*2+(j&1),j>>=1;
A[k]=a[i];
}
memcpy(a,A,N*8);
fo(i,1,len)
{
ll W=(type==1)?qpower(G,(mod-1)/s1):qpower(G,(mod-1)-(mod-1)/s1);
S>>=1;
fo(j,0,S-1)
{
ll w=1;
fo(k,0,s2-1)
{
ll u=a[j*s1+k],v=a[j*s1+k+s2]*w;
a[j*s1+k]=(u+v)%mod;
a[j*s1+k+s2]=(u-v)%mod;
w=w*W%mod;
}
}
s1<<=1,s2<<=1;
}
}
void mul(ll *a,ll *b)
{
ll s=qpower(N,Mod);
int i;
memset(B,0,sizeof(B));
memcpy(B,b,4*N);
dft(a,1);
dft(B,1);
fo(i,0,N-1) a[i]=a[i]*B[i]%mod;
dft(a,-1);
fo(i,0,N/2-1) a[i]=a[i]*s%mod;
fo(i,N/2,N-1) a[i]=0;
}
int main()
{
#ifdef file
freopen("agc035F.in","r",stdin);
#endif
scanf("%d%d",&n,&m);len=ceil(log2(n+1))+1;N=qpower(2,len);
if (n>m) swap(n,m);
jc[0]=jc[1]=Jc[0]=Jc[1]=w[1]=1;fo(i,2,500000) w[i]=mod-w[mod%i]*(mod/i)%mod,jc[i]=jc[i-1]*i%mod,Jc[i]=Jc[i-1]*w[i]%mod;
fo(i,0,n) a[i]=qpower(n+1,m-i)*C(m,i)*qpower(-1,i)%mod; //or min(n,m)
fo(i,0,n) b[i]=qpower(m,i)*Jc[i]%mod;
mul(a,b);
fo(i,0,n) ans=(ans+Jc[n-i]*a[i])%mod;
fo(i,1,n) ans=ans*i%mod;
printf("%lld
",(ans+mod)%mod);
fclose(stdin);
fclose(stdout);
return 0;
}