https://ac.nowcoder.com/acm/contest/885/C
这题跟平时那种求离散对数的 BSGS 不一样。展开递推式 $x_i = (a x_{i-1} + b) \bmod p$ 并合并含 $a$ 的项,会发现它本质上就是离散对数问题;但如果真的转化成离散对数去套模板来做,会超时(TLE)。
#include<cstdio>
#include<algorithm>
// Baby-step table size: solve() tabulates at most AA terms x_0 .. x_{AA-1}.
constexpr int AA = 1000000;
// Scratch buffer of (value, index) pairs for the baby steps; refilled and sorted per test case.
std::pair<int, int> d[AA]{};
// Computes x^y mod p by square-and-multiply.
// x is expected to be already reduced mod p (callers pass values < p), so every
// product of two residues fits in a long long for p around 1e9.
int mypow(long long x, int y, int p) {
    long long acc = 1;
    for(long long base = x; y != 0; y >>= 1) {
        if(y & 1)
            acc = acc * base % p;
        base = base * base % p;
    }
    return acc;
}
// Modular inverse of x modulo p via Fermat's little theorem: x^(p-2) = x^(-1) (mod p).
// Only valid when p is prime and x is not a multiple of p.
int inv(int x, int p) {
    const int fermat_exponent = p - 2;
    return mypow(x, fermat_exponent, p);
}
// Baby-step lookup table built by solve() for the current test case, after
// de-duplication: val[0..m) holds the distinct values of x_0 .. x_{m-1} in
// ascending order and pos[k] is the smallest index i with x_i == val[k].
int val[AA], pos[AA];

// Handles one test case: reads n, x0, a, b, p and Q, then answers Q queries v
// with the smallest i in [0, n) such that x_i == v (or -1 if there is none),
// where x_i = (a * x_{i-1} + b) mod p.
//
// Baby-step giant-step on the map f(x) = a*x + b (mod p):
//   * baby steps: tabulate x_0 .. x_{m-1}, m = min(AA, n);
//   * giant steps: repeatedly apply f^{-AA}(y) = aa*y + bb to v; if after i
//     giant steps the value equals x_j, then v == x_{i*AA + j}.
// Assumes p is prime (inv() relies on Fermat's little theorem) and
// 0 <= x0, a, b < p -- NOTE(review): confirm against the problem statement.
void solve() {
    long long n, x0, a, b, p;
    int Q;
    scanf("%lld%lld%lld%lld%lld%d", &n, &x0, &a, &b, &p, &Q);
    if(a == 0) {
        // Degenerate recurrence: the sequence is x0, b, b, ...
        // x_1 only exists when n > 1, so v == b is not a hit otherwise.
        while(Q--) {
            int v;
            scanf("%d", &v);
            if(v == x0)
                puts("0");
            else if(n > 1 && v == b)
                puts("1");
            else
                puts("-1");
        }
        return;
    }
    // Baby steps: store (x_i, i) for i < m.
    long long now = x0;
    int m = std::min((long long)AA, n);
    for(int i = 0; i < m; i++) {
        d[i] = {static_cast<int>(now), i};
        now = (now * a + b) % p;
    }
    // Pairs sort by value, then by index, so the first pair of each run of
    // equal values carries the smallest index for that value.
    std::sort(d, d + m);
    {
        int new_m = 0;
        for(int i = 0; i < m; i++) {
            val[new_m] = d[i].first;
            pos[new_m++] = d[i].second;
            // Bound by m, not AA: d[m..AA) holds stale entries from earlier test cases.
            while(i + 1 < m && d[i + 1].first == d[i].first)
                i++;
        }
        m = new_m;
    }
    // Enough giant steps to cover every index below p + 3*AA; the sequence
    // takes at most p distinct values, so a full period is covered.
    int BB = p / AA + 3;
    // f^{-1}(y) = inv_a * y + inv_b, with inv_b = -b * inv_a (mod p).
    int inv_a = inv(a, p);
    int inv_b = (p - b) % p * inv_a % p;
    // Compose f^{-1} with itself AA times: f^{-AA}(y) = aa * y + bb (mod p).
    long long aa = 1, bb = 0;
    for(int i = 0; i < AA; i++) {
        aa = aa * inv_a % p;
        bb = (bb * inv_a + inv_b) % p;
    }
    while(Q--) {
        int v;
        scanf("%d", &v);
        int it = std::lower_bound(val, val + m, v) - val;
        if(it < m && val[it] == v) {
            printf("%d\n", pos[it]);
            continue;
        }
        // With n < AA every existing term is already in the table.
        if(n < AA) {
            puts("-1");
            continue;
        }
        bool suc = false;
        for(int i = 1; i <= BB; i++) {
            v = (aa * v + bb) % p;
            it = std::lower_bound(val, val + m, v) - val;
            if(it < m && val[it] == v) {
                suc = true;
                // The first hit is the smallest index; if it is out of range,
                // every later hit is too.
                long long res = 1LL * i * AA + pos[it];
                if(res >= n)
                    res = -1;
                printf("%lld\n", res);
                break;
            }
        }
        if(!suc)
            puts("-1");
    }
}
// Entry point: reads the number of test cases T, then runs solve() once per case.
int main() {
    int T = 0;
    // If the count cannot be read, stop rather than loop on an unread T.
    if(scanf("%d", &T) != 1)
        return 0;
    while(T--)
        solve();
    return 0;
}
但这里要学的不是直接套离散对数的模板,而是借鉴 BSGS 算法的思路来改写。另外,在这题里用 unordered_map 做查找表的表现似乎不如"排序 + 二分查找"好。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// Size of the stdin read-ahead buffer used by gc().
const int MAXBUF = 1 << 20;
// buf holds the current chunk of stdin; [fh, ft) is the part not yet consumed.
char buf[MAXBUF];
char *fh, *ft;
// Returns the next byte of stdin, refilling buf in MAXBUF-sized chunks via fread.
// At end of input it returns '\0' and leaves fh == ft, so every later call also
// returns '\0'. Previously it returned a stale buf byte and advanced fh past ft.
// A final number with no trailing newline could then absorb garbage digits.
// NOTE(review): read()/readll() still spin forever if asked for a number after EOF.
inline char gc() {
    if(fh == ft) {
        size_t l = fread(buf, 1, MAXBUF, stdin);
        ft = (fh = buf) + l;
        if(l == 0)
            return '\0';
    }
    return *fh++;
}
inline int read() {
int x = 0;
char c = gc();
for(; c < '0' || c > '9'; c = gc())
;
for(; c >= '0' && c <= '9'; c = gc())
x = (x << 3) + (x << 1) + c - '0';
return x ;
}
// Parses the next non-negative decimal integer from the gc() byte stream into a
// 64-bit ll (used for n). Leading non-digit bytes, including '-', are skipped.
inline ll readll() {
    char ch = gc();
    while(ch < '0' || ch > '9')
        ch = gc();
    ll value = 0;
    while(ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = gc();
    }
    return value;
}
// Parameters of the current test case: x_i = (a * x_{i-1} + b) mod p for i in [0, n).
ll n, x0, a, b;
// The value of the query currently being answered.
ll v;
// Giant-step coefficients set by bs(); one giant step maps s to s * A + B (mod p).
ll A, B;
// Query count of the current test case (counted down while answering), and the modulus.
int Q, p;
// Returns a^n mod p by square-and-multiply. A % is applied only when a value
// reaches p. Callers are expected to pass a already reduced mod p (inv() does),
// so each product of two residues fits in ll.
inline int qpow(ll a, int n, int p) {
    ll result = 1;
    for(; n; n >>= 1) {
        if(n & 1) {
            result *= a;
            if(result >= p)
                result %= p;
        }
        a *= a;
        if(a >= p)
            a %= p;
    }
    return result;
}
// Modular inverse of a modulo the prime p via Fermat's little theorem: a^(p-2) mod p.
// a is reduced mod p first; the result is meaningless if p divides a.
inline int inv(ll a, int p) {
    const ll reduced = a % p;
    return qpow(reduced, p - 2, p);
}
// Baby-step count: bs() inserts exactly BSCEIL (value, index) pairs per test case.
const int BSCEIL = 1000000;
// Number of giant steps gs() tries; bs() sets it to p / BSCEIL + 3.
int GSCEIL;
// Backing storage for BinaryMap: the raw inserted pairs, then the de-duplicated
// sorted keys (fi) and, in parallel, the smallest index stored for each key (se).
std::pair<int, int> pii[BSCEIL + 5];
int fi[BSCEIL + 5];
int se[BSCEIL + 5];

// Static value -> smallest-index map backed by a sorted array and binary search.
// Protocol: clear(), any number of insert(), build(), then query() as often as needed.
struct BinaryMap {
    // cnt: number of distinct keys after build(); piitop: pairs inserted since clear().
    int cnt, piitop;

    inline void clear() {
        piitop = 0;
    }

    // Appends (s, i). There is no capacity check: at most BSCEIL + 5 pairs fit.
    inline void insert(int s, int i) {
        pii[piitop] = std::make_pair(s, i);
        ++piitop;
    }

    // Sorts the pairs and keeps one entry per distinct value. Pairs compare by
    // value first and index second, so the kept index is the smallest one.
    inline void build() {
        std::sort(pii, pii + piitop);
        cnt = 0;
        for(int k = 0; k < piitop; ++k) {
            const int key = pii[k].first;
            if(cnt > 0 && fi[cnt - 1] == key)
                continue;
            fi[cnt] = key;
            se[cnt] = pii[k].second;
            ++cnt;
        }
    }

    // Returns the smallest index inserted with value v, or -1 if v is absent.
    inline int query(int v) {
        const int* hit = std::lower_bound(fi, fi + cnt, v);
        if(hit != fi + cnt && *hit == v)
            return se[hit - fi];
        return -1;
    }
} M;
void bs() {
M.clear();
ll s = x0;
for(int i = 0; i < BSCEIL; ++i) {
M.insert(s, i);
s = s * a + b;
if(s >= p)
s %= p;
}
M.build();
int inva = inv(a, p), invb = (p - b) % p * inva % p;
A = 1, B = 0;
for(int i = 1; i <= BSCEIL; ++i) {
A *= inva ;
if(A >= p)
A %= p;
B = B * inva + invb;
if(B >= p)
B %= p;
}
GSCEIL = p / BSCEIL + 3;
}
// Answers the current query v: the smallest i in [0, n) with x_i == v, or -1.
// v is first looked up among the baby steps. Otherwise the giant-step map is
// applied up to GSCEIL times. If after k giant steps the value equals x_j, the
// answer is k * BSCEIL + j.
ll gs() {
    ll cur = v;
    const int direct = M.query(cur);
    if(direct >= n)
        return -1;
    if(direct != -1)
        return direct;
    if(n < BSCEIL)
        return -1;
    for(int k = 1; k <= GSCEIL; ++k) {
        cur = cur * A + B;
        if(cur >= p)
            cur %= p;
        const int j = M.query(cur);
        if(j == -1)
            continue;
        const ll idx = 1ll * BSCEIL * k + j;
        return idx >= n ? -1 : idx;
    }
    return -1;
}
// Reads T test cases. Each test case gives the recurrence parameters and Q
// queries. Every query v is answered with the smallest i in [0, n) such that
// x_i == v, or -1.
int main() {
#ifdef Yinku
    freopen("Yinku.in", "r", stdin);
#endif // Yinku
    int T = read();
    while(T--) {
        n = readll(), x0 = read(), a = read(), b = read(), p = read(), Q = read();
        if(a == 0) {
            // Degenerate recurrence: the sequence is x0, b, b, ...
            // x_1 only exists when n > 1, so v == b is not a hit otherwise.
            while(Q--) {
                v = read();
                if(v == x0)
                    puts("0");
                else if(n > 1 && v == b)
                    puts("1");
                else
                    puts("-1");
            }
        } else {
            bs();
            while(Q--) {
                v = read();
                ll ans = gs();
                if(ans == -1)
                    puts("-1");
                else
                    printf("%lld\n", ans);
            }
        }
    }
}