恶心至极!!!!!!!!
思路
求 $\sum\limits_{i=1}^{n}(n \bmod i)\sum\limits_{j=1}^{m}(m \bmod j)[i \neq j]$，结果对 $19940417$ 取模。
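先假设没有 $i \neq j$ 的限制，此时两个求和互相独立，可以拆成乘积：
$\sum\limits_{i=1}^{n}(n \bmod i)\sum\limits_{j=1}^{m}(m \bmod j) = \left(\sum\limits_{i=1}^{n}(n \bmod i)\right)\left(\sum\limits_{j=1}^{m}(m \bmod j)\right)$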
只看左半部分:
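$\sum\limits_{i=1}^{n}(n \bmod i)$
$= \sum\limits_{i=1}^{n}\left(n - \left\lfloor\frac{n}{i}\right\rfloor \cdot i\right)$
$= \sum\limits_{i=1}^{n} n - \sum\limits_{i=1}^{n}\left\lfloor\frac{n}{i}\right\rfloor \cdot i = n^2 - \sum\limits_{i=1}^{n}\left\lfloor\frac{n}{i}\right\rfloor \cdot i$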
后一项显然可以用数论分块计算（在 $\left\lfloor\frac{n}{i}\right\rfloor$ 取值相同的一段 $i$ 上整体求和）。右半部分同理，都可以用数论分块做（与“余数求和”那道题的做法完全一样）。
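再看 $i = j$ 的情况，令 $k = \min(n, m)$，即求
$\sum\limits_{i=1}^{k}(n \bmod i)(m \bmod i)$
$= \sum\limits_{i=1}^{k}\left(n - \left\lfloor\frac{n}{i}\right\rfloor \cdot i\right)\left(m - \left\lfloor\frac{m}{i}\right\rfloor \cdot i\right)$
$= \sum\limits_{i=1}^{k}\left(nm - \left(\left\lfloor\frac{n}{i}\right\rfloor \cdot m + \left\lfloor\frac{m}{i}\right\rfloor \cdot n\right) \cdot i + \left\lfloor\frac{n}{i}\right\rfloor\left\lfloor\frac{m}{i}\right\rfloor \cdot i^2\right)$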
用数论分块分别求出上面两个式子，用不考虑限制的总和减去 $i = j$ 部分的和即为答案。注意取模意义下的除法要用逆元：模数 $19940417 = 7 \times 2848631$ 不是质数，但它与 $2$、$6$ 互质，所以 $2$ 和 $6$ 的逆元都存在。
PS：
能模就模，好事多模（中间结果及时取模，防止溢出）。
小知识点
$$\sum\limits_{i=l}^{r} i = \frac{(r - l + 1)(l + r)}{2}$$
$$\sum\limits_{i=1}^{n} i^2 = \frac{n(n + 1)(2n + 1)}{6}$$
时间复杂度 $O(\sqrt{n} + \sqrt{m})$
代码
/*
Author:loceaner
*/
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
// Redefines every `int` below as `long long`, so products such as n * n and
// n * m in this file are computed in 64 bits. This is also why the entry
// point is declared `signed main()` rather than `int main()`.
#define int long long
using namespace std;
// A, B and inf are not referenced in the visible code.
const int A = 5e5 + 11;
const int B = 1e6 + 11;
// Modular inverses used to divide under `mod`:
//   6 * 3323403 = 19940418 = mod + 1, and 2 * 9970209 = 19940418 = mod + 1.
// inv2 is not referenced in the visible code; sum() divides by 2 exactly instead.
const int inv6 = 3323403;
const int inv2 = 9970209;
// Modulus of the answer. 19940417 = 7 * 2848631 is not prime, but it is odd and
// not divisible by 3, so 2 and 6 are still invertible modulo it.
const int mod = 19940417;
const int inf = 0x3f3f3f3f;
// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; if a '-' appears
// among them, the result is negated. The character that terminates the number
// is consumed. Returns 0 when end of input is reached before any digit. The
// previous version stored getchar() in a `char` and never tested for EOF in
// the skip loop, so truncated input made it spin forever.
inline int read() {
	int c = getchar();  // keep the int result so EOF stays distinguishable
	int x = 0, f = 1;
	for ( ; c != EOF && !isdigit(c); c = getchar()) {
		if (c == '-') f = -1;
	}
	// isdigit(EOF) is false, so this loop also stops cleanly at end of input.
	for ( ; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
	return x * f;
}
// Global input values, assigned in main(). solve()'s parameter `n` shadows the
// global `n`. `k` is not referenced anywhere in the visible code.
int n, m, k;
// Sum of the consecutive integers l, l + 1, ..., r, reduced modulo `mod`.
// One of (r - l + 1) and (l + r) is always even, so the division by 2 is exact
// and is done before reducing. The unreduced product must fit in `int`
// (long long in this file).
inline int sum(int l, int r) {
	const int len = r - l + 1;
	const int endsTotal = l + r;
	return len * endsTotal / 2 % mod;
}
// Prefix sum of squares 1^2 + 2^2 + ... + r^2 = r(r + 1)(2r + 1) / 6, reduced
// modulo `mod`. Dividing by 6 is done by multiplying by its modular inverse inv6.
inline int sum1(int r) {
	int acc = r * (r + 1) % mod;
	acc = acc * (2 * r + 1) % mod;
	return acc * inv6 % mod;
}
// Sum of squares l^2 + ... + r^2 modulo `mod`, as a difference of sum1() prefixes.
// Each sum1() value is already reduced modulo `mod`, so the raw difference can be
// negative, and C++ `%` keeps the sign of its left operand. The result is
// therefore shifted into [0, mod) so callers need not re-normalize.
inline int sum2(int l, int r) {
	return ((sum1(r) - sum1(l - 1)) % mod + mod) % mod;
}
// Returns (sum_{i=1}^{n} (n mod i)) modulo `mod`, in [0, mod).
// Uses n mod i = n - floor(n / i) * i, so the sum equals
//   n * n - sum_{i=1}^{n} floor(n / i) * i.
// floor(n / i) is constant for i in [l, n / (n / l)], so the second sum is
// taken block by block (数论分块), using O(sqrt(n)) iterations.
inline int solve(int n) {
	// Reduce before squaring so the initial term stays small.
	int ans = n % mod * (n % mod) % mod;
	for (int l = 1, r; l <= n; l = r + 1) {
		// l <= n guarantees q >= 1, and n / q <= n. The old `n / l == 0` guard
		// and the `min(..., n)` clamp could never take effect.
		const int q = n / l;
		r = n / q;
		ans = (ans - q % mod * sum(l, r) % mod) % mod;
	}
	return (ans + mod) % mod;
}
signed main() {
n = read(), m = read();
int ans1 = solve(n) * solve(m) % mod, ans2 = 0;
for (int l = 1, r, now1, now2, now3, now4; l <= min(n, m); l = r + 1) {
r = min(n / (n / l), m / (m / l));
now1 = n * m % mod * (r - l + 1) % mod;
now2 = ((n / l) * (m / l) % mod * sum2(l, r) % mod + mod) % mod;
now3 = ((n / l) * m % mod + (m / l) * n % mod) * sum(l, r) % mod;
now4 = (now1 + now2 - now3 + mod) % mod;
ans2 = (ans2 + now4) % mod;
}
cout << ((ans1 - ans2) % mod + mod) % mod;
return 0;
}