题目大意
(T) 组数据,每次给出 (n,s,a_0,a_1,a_2,a_3),求以下式子的值:
[left[sum_{i=0}^nleft(inom{n}{i}cdot s^icdot a_{imathrm{ mod } 4}
ight)
ight]mathrm{ mod } 998244353
]
(1leq nleq 10^{18},1leq s,a_0,a_1,a_2,a_3leq 10^8)
题解
首先观察式子,式子中带着组合数,可能需要使用二项式定理来消除组合数。乘上的 (a_i) 与 (imathrm{ mod } 4) 的值有关,可以枚举 (imathrm{ mod } 4) 的值来分别计算每个 (a_i) 的贡献。然后就来推以下式子,推的过程中尽可能凑出 (inom{n}{i}x^iy^{n-i}) 这样的二项式定理形式。
[sum_{i=0}^nleft(inom{n}{i}cdot s^icdot a_{imathrm{ mod } 4}
ight)=sum_{k=0}^3sum_{i=0}^ninom{n}{i}s^ia_k[4|i-k]\
=frac{1}{4}sum_{k=0}^3sum_{i=0}^ninom{n}{i}s^ia_ksum_{j=0}^3omega_4^{(i-k)j}\
=frac{1}{4}sum_{k=0}^3a_ksum_{j=0}^3sum_{i=0}^ninom{n}{i}s^iomega_4^{ij}omega_4^{-kj}\
=frac{1}{4}sum_{k=0}^3a_ksum_{j=0}^3omega_4^{-kj}sum_{i=0}^ninom{n}{i}s^i(omega_4^{j})^i\
=frac{s^n}{4}sum_{k=0}^3a_ksum_{j=0}^3omega_4^{-kj}sum_{i=0}^ninom{n}{i}left(frac{1}{s}
ight)^{n-i}(omega_4^{j})^i\
=frac{s^n}{4}sum_{k=0}^3a_ksum_{j=0}^3omega_4^{-kj}left(frac{1}{s}+omega_4^j
ight)^n
]
(998244353) 的原根是 (3),所以可以用 (3^{frac{998244353-1}{4}}) 来代替 (omega_4),然后直接计算即可。
时间复杂度 (O(log n))。
Code
#include <bits/stdc++.h>
using namespace std;
#define RG register int
#define LL long long
template<typename elemType>
inline void Read(elemType& T) {
elemType X = 0, w = 0; char ch = 0;
while (!isdigit(ch)) { w |= ch == '-';ch = getchar(); }
while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
T = (w ? -X : X);
}
const LL MOD = 998244353;
LL qpow(LL b, LL n) {
LL x = 1, Power = b % MOD;
while (n) {
if (n & 1) x = x * Power % MOD;
Power = Power * Power % MOD;
n >>= 1;
}
return x;
}
const LL g = qpow(3, (MOD - 1) / 4);
LL n, s, a[4];
int T;
LL calc() {
LL sinv = qpow(s, MOD - 2), res = 0;
for (int k = 0;k < 4;++k) {
LL sum = 0;
for (int j = 0;j < 4;++j) {
LL temp = qpow(sinv + qpow(g, j), n) * qpow(qpow(g, k * j), MOD - 2) % MOD;
sum = (sum + temp) % MOD;
}
sum = sum * a[k] % MOD;
res = (res + sum) % MOD;
}
res = res * qpow(s, n) % MOD * qpow(4, MOD - 2) % MOD;
return res;
}
int main() {
Read(T);
while (T--) {
Read(n);Read(s);
for (int i = 0;i < 4;++i) Read(a[i]);
LL ans = calc();
printf("%lld
", ans);
}
return 0;
}