Description
给定小写字母组成的字符串 (S)。定义 (AB) 表示字符串 (A, B) 拼接,(A^n=A^{n-1}A) 表示 (A) 复制 (n) 遍。求三元组 ((A, B, C)) 的个数,满足 (S) 可以写成 ((AB)^i C) 的形式。共 (T) 组数据。
Constraints
(1le |S| le 2^{20}, 1le Tle 5)。
Solution
首先是一个比较大众也比较好想的做法。记 (pre(i)) 表示前缀 (i) 中出现奇数次的字符个数,同理对后缀定义 (suf(i))。字符集用 (Sigma) 表示。
考虑枚举 (AB) 的长 (x),那么前缀 (S[1:x]) 就是 (AB)。考虑 (S) 将会由 (AB) 循环若干次构成,剩下的就是 (C),那么考虑找到一个最大的循环次数 (k),哈希即可。找到 (k) 之后,就能对于每个循环次数 (iin [1, k]),求出 (pre) 上 ([1, ix)) 中有多少个 (le suf(ix+1)) 就是对答案的贡献。考虑到值域是 ([0, 26]),树状数组维护单次操作是 (O(log |Sigma|))。直接做是 (O(Tsum_{i=1}^n frac n i) = O(Tnlog nlog |Sigma|)) 的,想要通过比较困难。
优化其实并不难,考虑一下两个要点:
- (k) 的合法性是单调的;
- 对于一个 (k),不需要枚举 (i),奇数偶数分开算,同为奇数或偶数贡献是一样的。
第二个比较简单。第一个我们可以考虑二分或者倍增找到 (k)。这样的话复杂度大概是 (O(sum_{i=1}^nlog ( frac n i))approx O(n))。参考 这里。
二分我不太会保证复杂度,这里介绍一种倍增方法。
设 (X=S[1:x]),其哈希值为 (H(X))。那么我们可以得到复制 (t) 倍的串的哈希值:
其中 (b) 为哈希的基数。计算一个哈希值,如果使用快速幂的话,需要 (O(log tx)) 的时间。不过如果是倍增的话,我们只需要计算 (lfloorlog_2 frac n x floor) 个 (b^{tx}) 值即可,每一项等于前一项的平方。逆元直接算是 (O(log mod)) 的,尽管每个 (x) 都只算一次也是不可接受的。那么只好用一个 离线求逆元的 trick,预处理所有 (x) 的 ((b^x-1)^{-1})。这样复杂度就只有 (O(Tnlog |Sigma|)) 了。不会二分是因为倍增可以预处理 (lfloorlog_2 frac n x floor) 个 (b^{tx}) 而二分我就不知道了。
Code
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
const int N = 1 << 20 | 5;
typedef unsigned long long ull;
const ull base = 19260817;
const ull mod = 1e9 + 7;
ull pw[N];
int n;
long long ans;
char s[N];
int pre[N], suf[N];
ull hs[N];
inline ull fastpow(ull a, ull b) {
ull r = 1llu;
for (; b; b >>= 1, (a *= a) %= mod)
if (b & 1) (r *= a) %= mod;
return r;
}
ull buf[N], inv[N];
namespace inversion {
ull pre[N], suf[N];
void process(int n) {
memset(pre, 0, sizeof(pre));
memset(suf, 0, sizeof(suf));
memset(inv, 0, sizeof(inv));
pre[0] = suf[n + 1] = 1llu;
for (int i = 1; i <= n; i++)
pre[i] = pre[i - 1] * buf[i] % mod;
for (int i = n; i >= 1; i--)
suf[i] = suf[i + 1] * buf[i] % mod;
ull all = fastpow(pre[n], mod - 2);
for (int i = 1; i <= n; i++)
inv[i] = pre[i - 1] * suf[i + 1] % mod * all % mod;
}
}
struct bit {
int t[28];
inline int get(int x) {
int r = 0;
for (++x; x; x -= x & -x) r += t[x];
return r;
}
inline void add(int x) {
for (++x; x <= 27; x += x & -x) ++t[x];
}
inline void reset() {
memset(t, 0, sizeof(t));
}
} tr;
signed main() {
pw[0] = 1llu;
for (int i = 1; i < N; i++)
pw[i] = pw[i - 1] * base % mod;
int T;
scanf("%d", &T);
while (T--) {
scanf("%s", s + 1);
n = strlen(s + 1);
ans = 0;
tr.reset();
memset(hs, 0, sizeof(hs));
memset(pre, 0, sizeof(pre));
memset(suf, 0, sizeof(suf));
memset(buf, 0, sizeof(buf));
for (int i = 1; i <= n; i++)
hs[i] = (hs[i - 1] * base + s[i]) % mod;
pre[0] = suf[n + 1] = 0;
for (int i = 1, v = 0; i <= n; i++) {
int nv = v ^ (1 << (s[i] - 'a'));
if (nv > v) pre[i] = pre[i - 1] + 1;
else pre[i] = pre[i - 1] - 1;
v = nv;
}
for (int i = n, v = 0; i >= 1; i--) {
int nv = v ^ (1 << (s[i] - 'a'));
if (nv > v) suf[i] = suf[i + 1] + 1;
else suf[i] = suf[i + 1] - 1;
v = nv;
}
for (int x = 2; x < n; x++)
buf[x - 1] = pw[x] - 1;
inversion::process(n - 2);
tr.add(pre[1]);
for (int x = 2; x < n; x++) {
int k = 0, maxb = log2(n / x);
ull cst = inv[x - 1] * hs[x] % mod;
ull fix = 0;
ull tpw[maxb + 1];
tpw[0] = pw[x];
for (int j = 1; j <= maxb; j++)
tpw[j] = tpw[j - 1] * tpw[j - 1] % mod;
for (int j = maxb; j >= 0; j--) {
ull cur = ((tpw[j] - 1) * cst % mod * pw[k * x] % mod + fix) % mod;
if (cur == hs[x * (k + (1 << j))])
k += (1 << j), fix = cur;
}
if (x * k == n) --k;
ans += tr.get(suf[x + 1]) * ((k + 1) / 2);
if (k > 1) ans += tr.get(suf[x * 2 + 1]) * (k / 2);
tr.add(pre[x]);
}
printf("%lld
", ans);
}
return 0;
}