题解
首先我们从(d(i,j)=0)的情况考虑,我们把所有这样的(i,j)缩到一起,那么对于每个缩好的联通块来说,所有(0)边权的边是可以连通整个联通块的。
这个方案数我们可以dp出来,设f表示i个点的合法联通块个数,g表示没有限制的合法联通块的个数。
(g[i]=(k+1)^{frac{(i*(i-1))}{2}})
[f[i]=g[i]-sum_{j=1}^{j<i}f[j]*g[i-j]*inom{i-1}{j-1}*k^{j*(i-j)}
]
这个转移相当于是固定了一个点然后枚举一个和这个点不连通的联通块。
然后考虑联通块之间的边如何计算。
对于两个联通块之间的(dist),有两种情况,第一是这个(dist)是通过别的联通块转移得来的,还有可能是自己本身的边权。
如果是第一种情况,那只要边权大于等于最短路就(OK)了,那就是((k+dist+1)^{edge}),如果只由自己转移,那就是((k+dist+1)^{edge}-(k-dist)^{edge})。
到这里可能会有疑问,为什么前面要缩零边呢?因为我们发现下面的计算其实是有顺序的,如果图中有零边的话,最短路的转移就会出现环,那我们的计算就不对了。
代码
#include<bits/stdc++.h>
#define N 409
using namespace std;
typedef long long ll;
const int mod=998244353;
ll a[N][N],jie[N],ni[N],size[N],f[N],g[N];
bool vis[N][N];
int n,k,fa[N];
inline ll rd(){
ll x=0;char c=getchar();bool f=0;
while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return f?-x:x;
}
int find(int x){return fa[x]=fa[x]==x?x:find(fa[x]);}
inline ll C(int n,int m){return jie[n]*ni[m]%mod*ni[n-m]%mod;}
inline ll power(ll x,ll y){
ll ans=1;
while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;}
return ans;
}
inline ll ny(ll x){return power(x,mod-2);}
inline void MOD(ll &x){x=x>=mod?x-mod:x;}
int main(){
n=rd();k=rd();
jie[0]=1;
for(int i=1;i<=n;++i)jie[i]=jie[i-1]*i%mod;
ni[n]=ny(jie[n]);
for(int i=n-1;i>=0;--i)ni[i]=ni[i+1]*(i+1)%mod;
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)a[i][j]=rd();
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)
if(a[i][j]>k||a[i][j]!=a[j][i]){
puts("0");return 0;
}
for(int i=1;i<=n;++i)if(a[i][i]){puts("0");return 0;}
for(int l=1;l<=n;++l)
for(int i=1;i<=n;++i)if(i!=l)
for(int j=1;j<=n;++j)if(j!=i&&j!=l)
if(a[i][l]+a[l][j]<a[i][j]){puts("0");return 0;}
for(int i=1;i<=n;++i)size[i]=1,fa[i]=i;
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)if(a[i][j]==0){
int xx=find(i),yy=find(j);
if(xx!=yy){
fa[xx]=yy;
size[yy]+=size[xx];
}
}
for(int i=1;i<=n;++i){
f[i]=g[i]=power(k+1,i*(i-1)/2);
for(int j=1;j<i;++j)
f[i]=f[i]-f[j]*C(i-1,j-1)%mod*g[i-j]%mod*power(k,j*(i-j))%mod,MOD(f[i]=f[i]+mod);
}
ll ans=1;
for(int i=1;i<=n;++i)
for(int j=i+1;j<=n;++j){
int xx=find(i),yy=find(j);
if(vis[xx][yy]||xx==yy)continue;
vis[xx][yy]=vis[yy][xx]=1;
bool tag=0;
for(int l=1;l<=n;++l){
int zz=find(l);
if(zz==xx||zz==yy)continue;
if(a[i][l]+a[l][j]==a[i][j]){tag=1;break;}
}
if(!tag)(ans*=
(power(k-a[i][j]+1,size[xx]*size[yy])-power(k-a[i][j],size[xx]*size[yy])+mod)%mod)%=mod;
else ans=ans*power(k-a[i][j]+1,size[xx]*size[yy])%mod;
}
for(int i=1;i<=n;++i)if(find(i)==i)ans=ans*f[size[i]]%mod;
cout<<ans;
return 0;
}