逆推期望
// Reverse-expectation DP on a matrix.
//
// Input: n m, then the n*m matrix values (row by row), then a 1-based cell (r, c).
// For every cell, E(r, c) is defined as the average, over all k cells whose
// value is strictly smaller, of E(i, j) + (r - i)^2 + (c - j)^2, and 0 when
// k == 0. Equivalently: the expected total squared jump distance when the chip
// repeatedly jumps to a uniformly random strictly-smaller cell.
// Output: E(r, c) modulo 998244353.
//
// Expanding the sum gives
//   k * E(r, c) = sum E + k*(r^2 + c^2) + sum i^2 + sum j^2 - 2r*sum i - 2c*sum j,
// so processing cells in increasing value order while maintaining five running
// sums makes every transition O(1). Total cost: O(nm log nm) for the sort.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

using ll = long long;

constexpr ll kMod = 998244353;  // prime, so inverses come from Fermat's little theorem

// One matrix cell: 1-based row/column and its value.
struct Cell {
    ll row;
    ll col;
    ll val;
};

// (a * b) mod kMod; both operands must already lie in [0, kMod).
ll mul(ll a, ll b) { return a * b % kMod; }

// Reduce any (possibly negative) value into [0, kMod).
ll norm(ll x) {
    x %= kMod;
    return x < 0 ? x + kMod : x;
}

// base^exp mod kMod by binary exponentiation.
ll ksm(ll base, ll exp) {
    ll res = 1;
    base = norm(base);
    while (exp > 0) {
        if (exp & 1) res = mul(res, base);
        base = mul(base, base);
        exp >>= 1;
    }
    return res;
}

// Modular inverse of a residue that is nonzero mod kMod.
ll inv(ll a) { return ksm(a, kMod - 2); }

// Returns E(target_row, target_col) mod kMod as defined in the file header,
// or 0 if no cell in `cells` has those coordinates.
ll solve(std::vector<Cell> cells, ll target_row, ll target_col) {
    std::sort(cells.begin(), cells.end(),
              [](const Cell& a, const Cell& b) { return a.val < b.val; });

    // Running sums over every cell processed so far (all strictly smaller than
    // the current group): sum of E, of rows, rows^2, cols and cols^2 (mod kMod).
    ll sum_e = 0, sum_r = 0, sum_r2 = 0, sum_c = 0, sum_c2 = 0;

    const std::size_t total = cells.size();
    std::size_t lo = 0;
    while (lo < total) {
        // [lo, hi) is a maximal run of equal values; test the bound before indexing.
        std::size_t hi = lo;
        while (hi < total && cells[hi].val == cells[lo].val) ++hi;

        // Exactly `lo` cells are strictly smaller than every cell in this group.
        const ll k = static_cast<ll>(lo) % kMod;
        const ll inv_k = (lo > 0) ? inv(k) : 0;  // one inverse per group

        // Buffer this group's contributions so equal-valued cells never
        // transfer into each other.
        ll add_e = 0, add_r = 0, add_r2 = 0, add_c = 0, add_c2 = 0;
        for (std::size_t i = lo; i < hi; ++i) {
            const ll r = cells[i].row;
            const ll c = cells[i].col;
            const ll r2 = norm(r * r);
            const ll c2 = norm(c * c);

            ll e = 0;  // no strictly smaller cell => expectation 0
            if (lo > 0) {
                ll num = (sum_e + mul(k, (r2 + c2) % kMod)) % kMod;
                num = (num + sum_r2 + sum_c2) % kMod;
                num = norm(num - mul(norm(2 * r), sum_r) - mul(norm(2 * c), sum_c));
                e = mul(num, inv_k);
            }
            // E of the target depends only on strictly smaller cells: done.
            if (r == target_row && c == target_col) return e;

            add_e = (add_e + e) % kMod;
            add_r = (add_r + norm(r)) % kMod;
            add_r2 = (add_r2 + r2) % kMod;
            add_c = (add_c + norm(c)) % kMod;
            add_c2 = (add_c2 + c2) % kMod;
        }

        sum_e = (sum_e + add_e) % kMod;
        sum_r = (sum_r + add_r) % kMod;
        sum_r2 = (sum_r2 + add_r2) % kMod;
        sum_c = (sum_c + add_c) % kMod;
        sum_c2 = (sum_c2 + add_c2) % kMod;
        lo = hi;
    }
    return 0;
}

}  // namespace

int main() {
    ll n = 0, m = 0;
    if (std::scanf("%lld %lld", &n, &m) != 2) return 1;

    std::vector<Cell> cells;
    if (n > 0 && m > 0) cells.reserve(static_cast<std::size_t>(n * m));
    for (ll i = 1; i <= n; ++i) {
        for (ll j = 1; j <= m; ++j) {
            ll v = 0;
            if (std::scanf("%lld", &v) != 1) return 1;
            cells.push_back({i, j, v});
        }
    }

    ll r = 0, c = 0;
    if (std::scanf("%lld %lld", &r, &c) != 2) return 1;

    std::printf("%lld\n", solve(std::move(cells), r, c));
    return 0;
}

/*
Sample 1:
1 4
1 1 2 1
1 3
expected output: 2

Sample 2:
2 3
1 5 7
2 3 1
1 2
expected output: 665496238   (= 8/3 mod 998244353)
*/
这题是真的痛苦
从所有 val 严格小于指定位置 val 的点，向指定位置逆推期望：按 val 从小到大依次处理所有点，val 相等的点之间互不转移。
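至于为什么要维护 x、x² 等前缀和，把公式写出来展开即可看出：设 val 严格小于 (r, c) 的点共有 k 个，则 $E(r,c) = \frac{1}{k}\sum_{(i,j)}\left[E(i,j) + (r-i)^2 + (c-j)^2\right] = \frac{1}{k}\left[\sum E(i,j) + k(r^2+c^2) + \sum i^2 + \sum j^2 - 2r\sum i - 2c\sum j\right]$。因此只需维护 $\sum E$、$\sum i$、$\sum i^2$、$\sum j$、$\sum j^2$ 这几个前缀和，每个点即可 O(1) 转移（k = 0 时期望为 0）。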