Description
令 (p = 998244353)。
给你整数 (n,c,d) 。现有整数 (x_1,dots,x_n) 和 (b_1, dots , b_n) 满足 (0 le x_1 , dots , x_n , b_1 , dots , b_n < p) ,且对于 (1le i le n) 满足:
[sum_{j = 1}^{n} gcd(i, j)^c cdot operatorname{lcm}(i, j)^d cdot x_j equiv b_i pmod{p}
]
有 (q) 个询问,每次给出 (b_1, dots, b_n),请你解出 (x_1, dots, x_n) 的值。
Solution
先丢两个链接:
首先,根据小学奥数知识我们可以知道 (operatorname{lcm}(i,j)=frac{ij}{gcd(i,j)}),于是原式子变为
[sum_{j = 1}^{n} gcd(i, j)^{c-d} i^d j^d x_j equiv b_i pmod{p}
]
其实写成函数的形式后这种解法都可用
[sum_{j = 1}^{n} f(gcd(i, j)) h(i) h(j) x_j equiv b_i pmod{p}
]
枚举 (gcd)
[egin{split}
b_i &= sum limits_{d=1}^{n}f(d) sum limits_{j=1}^{n} left[ gcd(i,j) = d
ight] h(i) h(j) x_j \
&= sum limits_{d=1}^{n}f(d) sum limits_{j=1}^{n} left[ frac{gcd(i,j)}{d} = 1
ight] h(i) h(j) x_j \
&= sum limits_{d=1}^{n}f(d) sum limits_{d mid j}^{n} h(i) h(j) sum limits_{kmid frac{gcd(i,j)}{d}} mu(k) x_j
end{split}
]
令 (T = kd),再变换求和顺序,则有
[egin{split}
b_i &= sum limits_{d=1}^{n}f(d) sum limits_{d mid j}^{n} h(i) h(j) sum limits_{Tmid gcd(i,j)} mu(frac{T}{d}) x_j \
&= h(i) sum limits_{T mid i} sum limits_{T mid j} sum limits_{d mid T} f(d) h(j) mu(frac{T}{d}) x_j \
&= h(i) sum limits_{T mid i} sum limits_{T mid j} h(j) x_j sum limits_{d mid T} f(d)mu(frac{T}{d})
end{split}
]
后半部分可以提前预处理,记作 (f_r(T))
[egin{split}
b_i &= h(i) sum limits_{T mid i} sum limits_{T mid j} h(j) x_j f_r(T) \
&= h(i) sum limits_{T mid i} f_r(T) sum limits_{T mid j} h(j) x_j
end{split}
]
后面那部分可以提前算出来,记作 (g(T))
[egin{split}
b_i &= h(i) sum limits_{T mid i} f_r(T) g(T)
end{split}
]
令 (g_r(T) = f_r(T) g(T))
[egin{split}
b_i &= h(i) sum limits_{T mid i} g_r(T)
end{split}
]
再莫比乌斯反演一次
[egin{split}
frac{b_i}{h(i)} &= sum limits_{T mid i} g_r(T) \
g_r(i) &= sum limits_{T mid i} frac{b_T}{h(T)} mu(frac{i}{T})
end{split}
]
那么 (g_r) 可以算出来,(f_r) 也可以算出来,于是可以算出 (g)
又由于 (g(T)=sum limits_{T mid j} h(j) x_j),所以再进行一次莫比乌斯反演
[h(j)x_j = sum limits_{Tmid j} g(T) mu(frac{j}{T})
]
就可以把 (h(j)x_j) 算出来,就求出了 (x_j)。
所以其实本质上就是用 (b_i) 除以 (h(i)) 然后莫比乌斯反演,然后再除以 (f) 的莫比乌斯反演,再莫比乌斯反演,再除以 (h(j))。
三个莫比乌斯反演掷地有声。XDXDXD
sto vfleaking orz
Code
Talk is cheap. Show me the code.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int _ = 1e5 + 10;
int N, c, d, Q, mu[_], b[_];
ll h[_], f[_], fr[_], invh[_], invfr[_], t[_], gr[_], g[_];
inline int ty() {
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9') {
if (ch == '-') f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
x = x * 10 + ch - '0';
ch = getchar();
}
return x * f;
}
ll fastPow(ll a, ll b) {
ll ret = 1;
for (; b; b >>= 1) {
if (b & 1) ret = ret * a % mod;
a = a * a % mod;
}
return ret;
}
void getMu(int lim = 1e5) {
static int cnt, p[_], vis[_];
mu[1] = 1;
for (int i = 2; i <= lim; ++i) {
if (!vis[i]) p[++cnt] = i, mu[i] = -1;
for (int j = 1; j <= cnt && i * p[j] <= lim; ++j) {
vis[i * p[j]] = 1;
if (i % p[j] == 0) {
mu[i * p[j]] = 0;
break;
}
mu[i * p[j]] = -mu[i];
}
}
}
int main() {
#ifndef ONLINE_JUDGE
freopen("run.in", "r", stdin);
freopen("run.out", "w", stdout);
#endif
getMu();
N = ty(), c = ty(), d = ty(), Q = ty();
for (int i = 1; i <= N; ++i) {
h[i] = fastPow(i, d);
invh[i] = fastPow(h[i], mod - 2);
f[i] = fastPow(i, (c - d + (mod - 1)) % (mod - 1)); // 扩展欧拉定理
}
for (int i = 1; i <= N; ++i)
for (int j = i; j <= N; j += i)
fr[j] = (fr[j] + mu[j / i] * f[i] % mod) % mod;
for (int i = 1; i <= N; ++i) invfr[i] = fastPow(fr[i], mod - 2);
while (Q--) {
bool flag = true;
for (int i = 1; i <= N; ++i) b[i] = ty();
for (int i = 1; i <= N; ++i) {
t[i] = b[i] * invh[i] % mod;
if (b[i] && !h[i]) flag = false;
}
memset(gr, 0, sizeof(ll) * (N + 1));
for (int i = 1; i <= N; ++i)
for (int j = i; j <= N; j += i)
gr[j] = (gr[j] + mu[j / i] * t[i] % mod + mod) % mod;
for (int i = 1; i <= N; ++i) {
g[i] = gr[i] * invfr[i] % mod;
if (gr[i] && !invfr[i]) flag = false;
}
memset(t, 0, sizeof(ll) * (N + 1));
for (int i = 1; i <= N; ++i)
for (int j = i; j <= N; j += i)
t[i] = (t[i] + mu[j / i] * g[j] % mod + mod) % mod;
for (int i = 1; i <= N; ++i) t[i] = t[i] * invh[i] % mod;
if (flag)
for (int i = 1; i <= N; ++i) printf("%lld ", t[i]);
else
cout << -1;
puts("");
}
return 0;
}