题目描述
给定一个长度为(N)的数列(A_1, A_2, A_3, ldots, A_N)。
请你求出(sum_{i=1}^{N}sum_{j=i+1}^{N}mathrm{lcm}(A_i,A_j))的值模(998244353)的结果。
(1leq N leq 2 imes 10^5,1 leq A_i leq 10^6)。
题解
(sum_{i = 1} ^ {N}sum_{j = i + 1} ^ {N}mathrm{lcm}(A_i, A_j))
(= frac{sum_{i = 1} ^ {N} sum_{j = 1} ^ {N}mathrm{lcm}(A_i, A_j) - sum_{i = 1} ^ {N}A_i}{2})
所以我们只要维护(sum_{i = 1} ^ {N}sum_{j = 1} ^ {N}mathrm{lcm}(A_i, A_j))就很容易得到答案。
(sum_{i = 1} ^ {N}sum_{j = 1} ^ {N}mathrm{lcm}(A_i, A_j))
(=sum_{d = 1} ^ {Max} frac{1}{d} sum_{i = 1} ^ {N}A_i sum_{j = 1} ^ {N}A_j [mathrm{gcd}(A_i, A_j) == d])
令(F(d) = sum_{i = 1} ^ {N}A_i sum_{j = 1} ^ {N}A_j [d | mathrm{gcd}(A_i, A_j)], f(d) = sum_{i = 1} ^ {N}A_i sum_{j = 1} ^ {N}A_j [mathrm{gcd}(A_i, A_j) == d])
则(F(d) = sum f(e) [d|e])
关于(F), 我们可以计算每个数的所有倍数之和并让他们平方得到(两两组合均合法)
由莫比乌斯反演可得,(f(d) = sum F(e) * mu(frac{e}{d}))
以上各种倍数的枚举均可以(O(n * ln_n))得出,再加上线性处理逆元即可。
#include <iostream>
#include <cstdio>
#define ll long long
#define int long long
using namespace std;
const int N = 2e5 + 5, M = 1e6 + 5;
int n, a[N], mx, v[M], prime[M], tot, mu[M];
ll ans, t[M], inv[M], f[M], sum, F[M];
const int mod = 998244353;
inline int read()
{
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') {x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
void init(int n)
{
mu[1] = 1;
for(int i = 2; i <= n; i ++)
{
if(!v[i]) {prime[++ tot] = i; mu[i] = -1;}
for(int j = 1; j <= tot && prime[j] * i <= n; j ++)
{
v[i * prime[j]] = 1;
if(i % prime[j] == 0)
{
mu[i * prime[j]] = 0;
break;
}
mu[i * prime[j]] = - mu[i];
}
}
inv[0] = inv[1] = 1;
for(int i = 2; i <= n; i ++) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
void work()
{
n = read();
for(int i = 1; i <= n; i ++) a[i] = read(), t[a[i]] ++, mx = max(mx, a[i]), sum = (sum + a[i]) % mod;
init(mx);
for(int i = 1; i <= mx; i ++)
{
for(int j = i; j <= mx; j += i) F[i] = (F[i] + t[j] * j) % mod;
F[i] = (F[i] * F[i]) % mod;
}
for(int i = 1; i <= mx; i ++) for(int j = i; j <= mx; j += i) f[i] = (f[i] + (F[j] * mu[j / i] % mod + mod)) % mod;
for(int d = 1; d <= mx; d ++) ans = (ans + inv[d] * f[d] % mod) % mod;
ans = (ans - sum + mod) % mod * inv[2] % mod;
printf("%lld
", (ans + mod) % mod);
}
signed main() {return work(), 0;}