这题挺神仙的
注意一些比较显然的性质
1.gcd(a, n) = 1, 所以 a * i % n 的结果互不相同,
所以 (a * i % n + b) % n 的结果也互不相同
2. (a * i + b) % n 的结果是每个差了 + a 再 % n 的
对于每个位置,它的值为 0 / 1 都会分别对应一个方程
设某个匹配位置的起始点为 bgn,该点的值为 val
那么它后面的值就是 (val + a) % n
就是说当起点固定时,后面每个位置上的值就都确定了
其实后边的点也是可以推起点的,
例如当 b[i] = 0 时,设起点的值为 val
那么 b[i] = 0 的条件就是 0 ≤ (val + (i - 1) * a) % n < p
这样可以给 val 求出一个值域,那么对每一个位置求出的值域求交就是 val 合法的值
求交不太好求?
可以补集转化,求补集的并集
代码:
#include <algorithm> #include <iostream> #include <cstring> #include <cstdlib> #include <cctype> #include <cstdio> #include <locale> using namespace std; typedef long long ll; const int MAX_M = 1000005; struct INTERVAL { int l, r; explicit INTERVAL(int L = 0, int R = 0) {l = L; r = R;} bool operator < (const INTERVAL& b) const { return l < b.l; } }line[MAX_M * 3]; int n, a, b, p, m, tot_line; int seq[MAX_M]; inline int rd() { register int x = 0; register int c = getchar(); while (!isdigit(c)) c = getchar(); while (isdigit(c)) { x = x * 10 + (c ^ 48); c = getchar(); } return x; } inline void get_str(int *str) { register int c = getchar(), len = 0; while (!isdigit(c)) c = getchar(); while (isdigit(c) && len < m) { str[++len] = (c ^ 48); c = getchar(); } } inline void add(int l, int r) { if (l <= r) line[++tot_line] = INTERVAL(l, r); else { line[++tot_line] = INTERVAL(l, n - 1); line[++tot_line] = INTERVAL(0, r); } } int main() { n = rd(); a = rd(); b = rd(); p = rd(); m = rd(); get_str(seq); register int lef, rig; for (int i = 1; i <= m; ++i) { if (seq[i] == 0) { lef = int((p - ((i - 1ll) * a) % n + n) % n); rig = int((n - 1ll - ((i - 1ll) * a) % n) % n); } else { lef = int((0 - ((i - 1ll) * a % n) + n) % n); rig = int((p - 1ll - ((i - 1ll) * a % n) + n) % n); } add(lef, rig); } register int tmp = 0; for (int i = n - m + 1; i < n; ++i) { tmp = int((1ll * a * i % n + b) % n); add(tmp, tmp); } sort(line + 1, line + tot_line + 1); register ll ans = 0ll; register int max_rig = -1ll; for (int i = 1; i <= tot_line; ++i) { if (line[i].l > max_rig) ans += line[i].l - max_rig - 1ll; max_rig = max(max_rig, line[i].r); } printf("%lld ", ans + (n - max_rig - 1ll)); return 0; }