整理题目转化为数学语言
题目要我们求:
[sum_{i=1}^nsum_{i=1}^m[gcd(i,j)=p]
]
其中
[pin ext{质数集合}
]
这样表示显然不是很好,所以我们需要更加数学一点:
[sum_{k=1}^{n}sum_{i=1}^nsum_{j=1}^m[gcd(i,j)=k] (kin ext{素数集合})
]
按照套路我们转化为:
[sum_{k=1}^{n}sum_{i=1}^{lfloorfrac{n}{k}
floor}sum_{j=1}^{lfloorfrac{m}{k}
floor}[gcd(i,j)=1] (kin ext{素数集合})
]
下面正式开始莫比乌斯反演
根据
[[gcd(i,j)=1]=sum_{d|gcd(i,j)}mu(d)
]
[sum_{k=1}^{n}sum_{i=1}^{lfloorfrac{n}{k}
floor}sum_{j=1}^{lfloorfrac{m}{k}
floor}sum_{d|gcd(i,j)}mu(d) (kin ext{素数集合})
]
继续套路,我们枚举d取代多余的(sum)
简单来说用枚举(d)来取代枚举(i,j)
于是式子变成了
[sum_{k=1}^{n}sum_{i=d}^{lfloorfrac{n}{k}
floor}mu(d) imeslfloorfrac{n}{kd}
floor imeslfloorfrac{m}{kd}
floor (kin ext{素数集合})
]
但是现在我们的效率还是不够高
于是我们再应用一个套路(套路2)
我们使用( heta)替换(k imes d),则
[sum_{k=1}^{n}sum_{i=d}^{lfloorfrac{n}{k}
floor}mu(d) imeslfloorfrac{n}{ heta}
floor imeslfloorfrac{m}{ heta}
floor (kin ext{素数集合})
]
然后我们在使用一遍套路3,枚举一下( heta)以替代枚举(k,d)。
式子转化为:
[sum_{T=1}^nlfloorfrac{n}{ heta}
floor imeslfloorfrac{m}{ heta}
floorsum_{k| heta,kin ext{整数集合}} mu(frac{ heta}{k})
]
我们发现
[sum_{k| heta,kin ext{整数集合}} mu(frac{ heta}{k})
]
可以预处理
具体的,对于每一个质数(p),(p)的倍数( heta)加上(mu(frac{ heta}{k}))。
最后再加一个整除分块优化即可AC
代码
#include <cstdio>
#include <algorithm>
#define ll long long
#define MAXNUM 10000005
int mu[MAXNUM], is_not_prime[MAXNUM], primes[MAXNUM / 10], prime_num;
// prefix
ll f[MAXNUM], qzh[MAXNUM];
int read(){
int x = 0; int zf = 1; char ch = ' ';
while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
if (ch == '-') zf = -1, ch = getchar();
while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); return x * zf;
}
void init(){
mu[1] = 1; is_not_prime[0] = is_not_prime[1] = 1;
for (int i = 2; i <= MAXNUM; ++i){
if (!is_not_prime[i]) mu[primes[++prime_num] = i] = -1;
for (int j = 1; j <= prime_num && primes[j] * i <= MAXNUM; ++j){
is_not_prime[i * primes[j]] = 1;
if (!(i % primes[j])) break;
else
mu[primes[j] * i] = -mu[i];
}
}
for (int i = 1; i <= prime_num; ++i)
for (int j = 1; primes[i] * j <= MAXNUM; ++j)
f[primes[i] * j] += mu[j];
for (int i = 1; i <= MAXNUM; ++i)
qzh[i] = qzh[i - 1] + f[i];
}
int main(){
init();
int T = read(), n, m; ll ans;
while (T--){
n = read(), m = read();
if (n > m) n ^= m ^= n ^= m; ans = 0;
for (int l = 1, r; l <= n; l = r + 1){
r = std::min(n / (n / l), m / (m / l));
ans += (ll)(qzh[r] - qzh[l - 1]) * (n / l) * (m / l);
}
printf("%lld
", ans);
}
return 0;
}