在知乎上看到一个问题:求十亿内所有质数的和,怎么做最快?
记录一下第一回答
定义为到所有整数中,在普通筛法中外层循环筛完时仍然幸存的数的和。因此这些数要不本身是素数,要不其最小的素因子也大于。因此我们需要求的是,其中是十亿。
为了计算,先考虑几个特殊情况
- 。此时所有数都还没有被筛掉,所以
- 不是素数。因为筛法中早已被别的数筛掉,所以在这步什么都不会做,所以此时;
- 是素数,但是。因为每个合数都一定有一个不超过其平方根的素因子,如果筛到时还没筛掉一个数,那么筛到时这个数也还在。所以此时也有。
现在考虑最后一种稍微麻烦些的情况:是素数,且。
此时,我们要用素数去筛掉剩下的那些数中的倍数。注意到现在还剩下的合数都没有小于的素因子。因此有:
后面那项中提取公共因子,有:
因为整除,稍微变形一下,令,有:
注意到一开始提到的的定义(“这些数要不本身是素数,要不其最小的素因子也大于(注意!)”),此时后面这项可以用来表达:
再用替换素数和得到最终表达式:
我们最终的结果是。计算时可以使用记忆化,也可以直接自底向上动态规划。
至于算法的复杂度就留作练习,是低于以上任何一种暴力方法的。
1 import time
2 def P10(n):
3 r = int(n ** 0.5)
4 # assert r*r <= n and (r+1)**2 > n
5 V = [n // i for i in range(1, r + 1)]
6 # print V
7 V += list(range(V[-1] - 1, 0, -1))
8 #print V
9 S = {i: i * (i + 1) // 2 - 1 for i in V}
10 #print S
11 st = time.clock();
12 for p in range(2, r + 1):
13 if S[p] > S[p - 1]: # p is prime
14 sp = S[p - 1] # sum of primes smaller than p
15 p2 = p * p
16 for v in V:
17 if v < p2: break
18 S[v] -= p * (S[v // p] - sp)
19 end = time.clock();
20 print end - st
21 return S[n]
22
23 while(True):
24 N = input()
25
26 print P10(N)
1e9的数据能在1s能跑出来, 真是神一样的算法, 神一样的代码。
————————————————————————更新——————————————————————————————
之后用C++实现了一遍,发现速度还没 线性筛 快,比较循环部分, 发现python 的速度要比 C++快50+倍,如果去除 list 的影响,dict 与 map的效率可能相差两个数量级,去查找dict 与 map 的实现原理, 原来, dict 是用 hash 复杂度O(1) 而 map为了元素的有序性, 采用红黑树 复杂度为O(logn)。
1 const int maxn = 1e6;
2 map<ll, ll> mp;
3 vector<ll> vr;
4 ll arr[maxn];
5 ll solve(ll n){
6 ll r = sqrt(n);
7 vr.clear();
8 mp.clear();
9 for (int i = 1; i <= r; ++i){
10 vr.pk(n / i);
11 }
12 for (int i = vr.back() - 1; i > 0; --i){
13 vr.pk(i);
14 }
15 int len = vr.size();
16 for (int i = 0; i < len; ++i){
17 arr[i] = vr[i];
18 mp[vr[i]] = vr[i] * (vr[i] + 1) / 2 - 1;
19 }
20 ll sp;
21 int p2;
22 double st = clock();
23 for (int i = 2; i <= r; ++i){
24 if (mp[i] > mp[i - 1]){
25 sp = mp[i - 1];
26 p2 = i * i;
27 for (int v = 0; v < len; ++v){
28 if (arr[v] < p2) break;
29 mp[arr[v]] -= i * (mp[arr[v] / i] - sp);
30 }
31 }
32 }
33 double ed = clock();
34 cout << ed - st << endl;
35 return mp[n];
36 }
37
38 int main(){
39 //freopen("data.out", "w", stdout);
40 //freopen("data.in", "r", stdin);
41 //cin.sync_with_stdio(false);
42 int n;
43 while (cin >> n){
44 cout << solve(n) << endl;
45 }
46 return 0;
47 }