思路:
推导后可得答案为 $b^{T}A^{-1}b \bmod 998244353$，因此只需用高斯消元（在模 998244353 意义下）求出矩阵 $A$ 的逆矩阵即可。
Code:
#pragma GCC optimize(3)
#pragma GCC optimize(2)
#include <map>
#include <set>
#include <array>
#include <queue>
#include <stack>
#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <sstream>
#include <iostream>
#include <stdlib.h>
#include <algorithm>
#include <unordered_map>
using namespace std;

// Solution: for each test case read n, an n x n integer matrix A and an
// integer vector b of length n, then print b^T * A^{-1} * b modulo 998244353.
// The inverse is obtained by Gauss-Jordan elimination on [A | I] over Z_mod.

typedef long long ll;

const int N = 200 + 10;     // array bound; assumes n <= N (presumably n <= 200 per the problem)
const int mod = 998244353;  // prime modulus, so a^(mod-2) is the inverse of a (Fermat)

int n;
ll a[N][N], b[N], p[N][N];  // a: input matrix (reduced mod), p: becomes A^{-1}
ll ans = 0;

// Fast modular exponentiation: returns a^b mod p for b >= 0.
// Operands are expected to be < p < 2^31 so the products fit in long long.
ll qmi(ll a, ll b, ll p) {
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return res;
}

// Gauss-Jordan elimination over Z_mod on the augmented system [a | p].
// Preconditions: a[0..n-1][0..n-1] holds values in [0, mod) and p holds the
// n x n identity. On success a becomes the identity and p holds a^{-1}.
// If some column has no non-zero pivot (a is singular mod `mod`) the function
// stops early and p is NOT a valid inverse; the caller does not check this
// (presumably the problem guarantees an invertible A — TODO confirm).
void gauss() {
    for (int c = 0, r = 0; c < n; c++, r++) {
        // Pick the first row at or below r with a non-zero entry in column c.
        int t = r;
        for (int i = r; i < n; i++) {
            if (a[i][c] != 0) {
                t = i;
                break;
            }
        }
        if (a[t][c] == 0) return;  // singular: no usable pivot in this column

        // Move the pivot row into position r, mirroring the swap on p.
        for (int j = 0; j < n; j++) {
            swap(a[t][j], a[r][j]);
            swap(p[t][j], p[r][j]);
        }

        // Scale the pivot row so the pivot becomes 1.
        ll k = qmi(a[r][c], mod - 2, mod);
        for (int j = 0; j < n; j++) {
            a[r][j] = k * a[r][j] % mod;
            p[r][j] = k * p[r][j] % mod;
        }

        // Clear column c in every other row (above and below: Gauss-Jordan).
        for (int i = 0; i < n; i++) {
            if (i == r || a[i][c] == 0) continue;
            k = a[i][c];
            for (int j = 0; j < n; j++) {
                a[i][j] = (a[i][j] - a[r][j] * k % mod + mod) % mod;
                p[i][j] = (p[i][j] - p[r][j] * k % mod + mod) % mod;
            }
        }
    }
}

int main() {
#ifndef ONLINE_JUDGE
    // Local debugging only: redirect I/O to the author's machine-specific files.
    freopen("/home/jungu/code/in.txt", "r", stdin);
    freopen("/home/jungu/code/out.txt", "w", stdout);
#endif
    // `== 1` (not `~scanf(...)`) so a matching failure ends the loop instead
    // of spinning forever on a return value of 0.
    while (scanf("%d", &n) == 1) {
        // p starts as the identity; gauss() turns it into A^{-1}.
        memset(p, 0, sizeof(p));
        for (int i = 0; i < n; i++) p[i][i] = 1;

        // Inputs may be negative: normalise every value into [0, mod).
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                scanf("%lld", &a[i][j]);
                a[i][j] = (a[i][j] % mod + mod) % mod;
            }
        }
        for (int i = 0; i < n; i++) {
            scanf("%lld", &b[i]);
            b[i] = (b[i] % mod + mod) % mod;
        }

        gauss();

        // ans = sum_{i,j} b[j] * (A^{-1})[j][i] * b[i] = b^T A^{-1} b (mod).
        ans = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                ans = (ans + b[j] * p[j][i] % mod * b[i] % mod) % mod;
            }
        }
        cout << ans << '\n';
    }
    return 0;
}