传送门
为书写方便，下文把 $f^k(i)$ 直接记作 $f(i)$。
莫比乌斯反演得到
$$ans=\sum_{i=1}^{N}\left\lfloor\frac{N}{i}\right\rfloor^2\sum_{d\mid i}f(d)\,\mu\left(\frac{i}{d}\right)$$
令 $g(N)=\sum_{i=1}^{N}(f\ast\mu)(i)$（$\ast$ 表示狄利克雷卷积）
而 $(f\ast\mu)\ast 1=f\ast(\mu\ast 1)=f\ast\varepsilon=f$（因为 $\mu\ast 1=\varepsilon$）
所以
$$\sum_{i=1}^{N}f(i)=\sum_{i=1}^{N}(f\ast\mu\ast 1)(i)=\sum_{i=1}^{N}g\left(\left\lfloor\frac{N}{i}\right\rfloor\right)$$
$$g(N)=\sum_{i=1}^{N}f(i)-\sum_{i=2}^{N}g\left(\left\lfloor\frac{N}{i}\right\rfloor\right)$$
与 UOJ188「Sanrd」类似，筛出 $f$ 的前缀和即可（代码中 $f(1)=0$，质数 $p$ 处 $f(p)=1$，合数处 $f$ 为其次大质因子（计重数）的 $k$ 次方）。
# include <bits/stdc++.h>
using namespace std;
// 64-bit integer, used below to evaluate products such as p * p without overflow.
typedef long long ll;
// unsigned int (32-bit on the usual targets). Every sum is kept in this type, so results
// wrap around modulo 2^32. NOTE(review): presumably the answer is required mod 2^32.
typedef unsigned int uint;
// Capacity of every global table (prime list, index maps, per-value arrays).
const int maxn(1e6 + 5);
// Binary exponentiation: returns x^y computed entirely in unsigned 32-bit arithmetic,
// i.e. x^y modulo 2^32 (unsigned overflow wraps). A non-positive exponent yields 1;
// the previous `for (; y; ...)` form never terminated for negative y because the
// arithmetic right shift of a negative int never reaches 0.
// (unsigned int is the same type as the file's `uint` typedef.)
inline unsigned int Pow(unsigned int x, int y) {
    unsigned int ret = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) ret *= x;
        x *= x;
    }
    return ret;
}
// pr[1..tot]: primes found by Sieve(); prk[t] = pr[t]^k (mod 2^32).
// id1[v] (v <= d) / id2[n / v] (v > d): position of the value v = n / i inside val[].
// d = floor(sqrt(n)) (set by Init); cnt = number of distinct values n / i; k = exponent read in main.
int pr[maxn], tot, id1[maxn], id2[maxn], d, cnt, k;
// ispr[x] == 1 means x is NOT prime (1 and every composite are marked) -- despite the name.
bitset <maxn> ispr;
// n: the input bound. val[1..cnt]: the distinct values n / i in strictly decreasing order.
// f[j]: after Init(), the number of primes <= val[j].
// g[j]: memoised Solve(val[j]); 0xFFFFFFFF (set by memset in main) means "not computed yet".
uint n, f[maxn], val[maxn], prk[maxn], g[maxn];
// Linear (Euler) sieve over [2, mx].
//   * Appends every prime <= mx to pr[1..tot] and stores prk[tot] = prime^k (mod 2^32).
//   * Marks every non-prime in ispr (ispr[x] == 1 means x is NOT prime; 1 is marked too).
// Each composite is crossed out exactly once, by its smallest prime factor.
// (`register` was dropped: that storage class is ill-formed since C++17.)
inline void Sieve(int mx) {
    ispr[1] = 1;
    for (int i = 2; i <= mx; ++i) {
        if (!ispr[i]) {
            pr[++tot] = i;
            prk[tot] = Pow(i, k);
        }
        for (int j = 1; j <= tot && pr[j] * i <= mx; ++j) {
            ispr[pr[j] * i] = 1;
            // pr[j] divides i, so pr[j] is i's smallest prime factor; any larger
            // prime times i has a smaller prime factor and is crossed out later.
            if (i % pr[j] == 0) break;
        }
    }
}
// Maps a value v of the form n / i to its position in val[]: small values (v <= d) are
// indexed directly through id1[v], large ones through id2[n / v]. The whole expansion is
// parenthesised (and every use of x too) so that e.g. `ID(v) + 1` adds to the selected
// index instead of binding only to the second branch of the conditional.
# define ID(x) ((x) <= d ? id1[(x)] : id2[n / (x)])
// Returns, modulo 2^32, the sum of h(v) over every integer v <= x that
//   * has all of its prime factors >= pr[m], and
//   * has at least two prime factors counted with multiplicity (i.e. v is composite).
// Here h(v) = (second-largest prime factor of v, counted with multiplicity)^k.
// Preconditions: Init() has run (f[ID(y)] = number of primes <= y), x is one of the values
// n / i so that ID(x / t) is valid, and pr[] holds every prime <= sqrt(x).
// (`register` was dropped: that storage class is ill-formed since C++17.)
uint Calc(uint x, int m) {
    // pr[m] reads as 0 once m passes tot (zero-initialised global); the loop below is then empty.
    if (x <= 1 || pr[m] > x) return 0;
    uint ret = 0;
    for (uint i = m; i <= tot && (ll)pr[i] * pr[i] <= x; ++i) {
        const uint p = pr[i];
        // t = p^e for e = 1, 2, ... while p * t <= x, so x / t >= p.
        for (uint t = p; (ll)p * t <= x; t *= p) {
            // v = t * w, w composite with all primes > p: h(v) = h(w), handled recursively.
            // v = t * q, q prime with p <= q <= x / t: the second-largest factor is p.
            //   The number of such q is pi(x / t) - (i - 1) (primes below p are pr[1..i-1]).
            ret += Calc(x / t, i + 1) + (f[ID(x / t)] - i + 1) * prk[i];
        }
    }
    return ret;
}
inline void Init(uint _n) {
register uint i, j;
for (cnt = 0, d = sqrt(n = _n), i = 1; i <= n; i = j + 1) {
j = n / (n / i), val[++cnt] = n / i;
val[cnt] <= d ? id1[val[cnt]] = cnt : id2[n / val[cnt]] = cnt;
f[cnt] = val[cnt] - 1;
}
for (i = 1; i <= tot && (ll)pr[i] * pr[i] <= n; ++i)
for (j = 1; j <= cnt && (ll)pr[i] * pr[i] <= val[j]; ++j)
f[j] -= f[ID(val[j] / pr[i])] - i + 1;
}
// Returns g(r) = sum_{i=1}^{r} (f * mu)(i) (Dirichlet convolution, modulo 2^32) via
//   g(r) = F(r) - sum_{i=2}^{r} g(floor(r / i)),   F(r) = sum_{i=1}^{r} f(i) = Calc(r, 1) + pi(r),
// where pi(r) = f[ID(r)] contributes f(p) = 1 for each prime p and Calc covers the composites.
// r must be one of the values n / i so that ID(r) is valid. Results are memoised in g[];
// 0xFFFFFFFF means "not computed yet". NOTE(review): a true result equal to 0xFFFFFFFF is
// simply recomputed on each call (same value, extra work).
// (`register` and the no-op `j` in the old for-init were dropped; the loop stops at j == r
// so `j + 1` cannot wrap when r == UINT_MAX.)
inline uint Solve(uint r) {
    const int idx = ID(r);
    if (~g[idx]) return g[idx];
    uint ret = Calc(r, 1) + f[idx];
    for (uint i = 2; i <= r;) {
        const uint q = r / i;
        const uint j = r / q;  // every i' in [i, j] has r / i' == q
        ret -= Solve(q) * (j - i + 1);
        if (j == r) break;
        i = j + 1;
    }
    return g[idx] = ret;
}
// Reads n and k, then evaluates, modulo 2^32,
//   ans = sum_{i=1}^{n} floor(n / i)^2 * (f * mu)(i)
// block by block. On each block [i, j] where q = floor(n / i) is constant, the inner sum
// telescopes to g(j) - g(i - 1), and g(i - 1) is the previous block's Solve value.
// Fixes: the output format string literal had been split across two lines (a compile error);
// `register` (ill-formed since C++17) was dropped; malformed input now exits with status 1
// instead of continuing with unread values; the block loop stops at j == n so `j + 1`
// cannot wrap when n == UINT_MAX.
int main() {
    memset(g, -1, sizeof(g));  // every g[] entry = 0xFFFFFFFF: "Solve value not computed yet"
    if (scanf("%u%d", &n, &k) != 2) return 1;
    Sieve(static_cast<int>(std::sqrt(static_cast<double>(n))));
    Init(n);
    uint ans = 0, lst = 0;
    for (uint i = 1; i <= n;) {
        const uint q = n / i;
        const uint j = n / q;
        const uint cur = Solve(j);
        ans += q * q * (cur - lst);
        lst = cur;
        if (j == n) break;
        i = j + 1;
    }
    printf("%u\n", ans);
    return 0;
}