bsgs algorithm
ax≡b(mod n)
大步小步算法,这个算法有一定的局限性,只有当gcd(a,m)=1时才可以用
原理
此处讨论n为素数的时候。
ax≡b(mod n)(n为素数)
由费马小定理可知,只需要验证0,1,2...n-1是不是解即可,因为an-1 = 1mod(n)
算法过程
1、首先求出a0,a1,a2,...,am-1 模上n的值是否为b,存储在e[i]中,求出am的逆a-m
2、下面考虑am,am+1,...,a2m-1 模上n的值是否为b
此时不用一一检查,如果当中有解,相当于存在e[i],使得e[i] * am = b mod(n)
两边乘上a-m,e[i] = b * a-m mod(n),只需要检查存不存在这样的e[i]即可
3、同理,可以递推检查出a2m - a3m-1中解的情况
为了方便,把e[i]存储在map<int, int>x中,x[j]表示满足ei =j 的最小下标(因为可能有多个值相同)
1 map<ll, int>x; 2 ll log_mod(ll a, ll b, ll n)//n为素数 3 { 4 //if(b >= n)return -1; 5 a %= n;b %= n;//注意题目,如果b >= n是不存在解的 6 if(a == 0) 7 { 8 if(b == 0)return 1;if(b == 1)return 0; 9 else return -1; 10 } 11 ll m = ceil(sqrt(n + 0.5)); 12 ll v = pow(a, n - 1 - m, n); 13 x.clear(); 14 x[1] = m; 15 ll e = 1; 16 for(int i = 1; i < m; i++)//计算a的i次方mod n,并存下来 17 { 18 e = e * a % n; 19 if(!x[e])x[e] = i; 20 } 21 for(int i = 0; i < m; i++)//计算a^(im), a^(im+1),...,a^(im+m-1) 22 { 23 int num = x[b]; 24 if(num)return i * m + (num == m ? 0 : num); 25 b = b * v % n; 26 } 27 return -1; 28 }
扩展:
当n不为素数的时候,如何求解ax≡b(mod n) ?
转化成gcd(a, n) = 1即可
如何转化呢?
利用公式:ax≡b(mod n)⇔ax/d ≡ b/d (mod n/d),d=gcd(a,n)(这里是a乘上x,上述求解方程为ax)
每次用一个a和m和b消去gcd(a, n),消δ次,每次n /= g, b /= g,迭代更新下一个g,直到g=1
最后ax变成了ax - δ *a', 方程变成:ax - δ *a' = b' (mod n') 这里a' b' n'都是消因子之后的值
满足:a' * g = aδ b' * g = b n' * g = n g = gcd(aδ, n)
如果某一步g不整除b',直接返回-1
如果某一步b' = a',假设此时消因子次数达到cnt次,那就返回cnt
因为消因子次数cnt次等价于
原方程:ax-cntacnt = b (mod n)
消因子后:ax-cnta' = b' (mod n')
若此时b' = a' 那么等式就可以直接消去a' b',得到:ax-cnt = 1 (mod n),解就是cnt
下面求解:ax - δ *a' = b' (mod n')
可以求出a'的逆元,然后乘过去,用上述的大步小步算法求解
下面介绍一个技巧不用求逆元。
先把上述式子写成ax * tmp = b (mod n) 求出x之后加上δ就是解。
设m = sqrt(n +0.5), x = k * m - q (1 <= k <= m, q <= m)
上式可写成tmp * a k*m-q = b (mod n)
等价于 tmp * a k*m = b * aq(mod n)
可以从1-m枚举k,每次求出tmp * a k*m 判断是不是存在b*aq
最开始的大步小步算法也可以这样写
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 ll pow(ll a, ll b, ll m) 5 { 6 ll ans = 1; 7 a %= m; 8 while(b) 9 { 10 if(b & 1)ans = (ans % m) * (a % m) % m; 11 b /= 2; 12 a = (a % m) * (a % m) % m; 13 } 14 ans %= m; 15 return ans; 16 } 17 ll ext_log_mod(ll a, ll b, ll n) 18 { 19 if(b >= n)return -1;//一些特殊情况的判断 20 a %= n; 21 if(b == 1)return 0; 22 //if(n == 1)return -1; 23 ll cnt = 0;//记录消因子次数 24 ll tmp = 1;//存当前a'的值 25 for(ll g = __gcd(a, n); g != 1; g = __gcd(a, n)) 26 { 27 if(b % g)return -1;//不能整除 28 b /= g; n /= g; tmp = tmp * a / g % n; 29 cnt++; 30 if(b == tmp)return cnt; 31 } 32 33 ll m = sqrt(n + 0.5); 34 ll t = b; 35 map<ll, ll>Map;//记录b * a ^ i, i 36 Map[b] = 0; 37 for(int i = 1; i <= m; i++) 38 { 39 b = b * a % n; 40 Map[b] = i; 41 } 42 a = pow(a, m, n); 43 for(int k = 1; k <= m; k++)//枚举k 44 { 45 tmp = tmp * a % n;//求出tmp*a^(k*m) 46 if(Map.count(tmp))return k * m - Map[tmp] + cnt; 47 } 48 return -1; 49 } 50 int main() 51 { 52 ll a, b, p; 53 while(scanf("%lld%lld%lld", &a, &p, &b) != EOF) 54 { 55 ll ans = ext_log_mod(a, b, p); 56 if(ans == -1)printf("Orz,I can’t find D! "); 57 else printf("%lld ", ans); 58 } 59 return 0; 60 }