sutoringu
题意:
询问有多少一个字符串内有多少个个子区间,满足可以分成k个相同的串。
分析:
首先可以枚举一个长度len,表示分成的k个长为len的串。然后从1开始,每len的长度分成一块,分成(n-1)/k+1块,首先可以求出连续的k块的是否是合法。
此时只求了起点是1+len*i的串,还有些起点在块内的没有求。
枚举k-1个相同的块,设这些块为i...j,j-i+1=k。然后与求一下第i块和第i-1块最长后缀,设为a,求一下第j块和第j+1块的最长前缀,设为b。说明如果起点在第i-1块的串,必须是后面a个字符,这些串的终点必须是第j+1块的前b个字符。于是计算一下。
如何求连续的k块是否是一样的?可以求出这连续k块在的rank,然后取一个最大的rank和一个最小的rank,然后求之间的height最小值即可。
复杂度$nlog^2n$。
代码:
#include<cstdio> #include<algorithm> #include<iostream> #include<cstring> #include<cmath> #include<cctype> #include<set> #include<queue> #include<vector> #include<map> #include<bitset> using namespace std; typedef long long LL; inline int read() { int x=0,f=1;char ch=getchar();for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1; for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';return x*f; } const int N = 600005; char s[N]; int t1[N], t2[N], c[N], sa[N], rnk[N], ht[N], f[N][21], Log[N]; void getsa(int n) { int m = 130, i, *x = t1, *y = t2; for (i = 1; i <= m; ++i) c[i] = 0; for (i = 1; i <= n; ++i) x[i] = s[i], c[x[i]] ++; for (i = 1; i <= m; ++i) c[i] += c[i - 1]; for (i = n; i >= 1; --i) sa[c[x[i]]--] = i; for (int k = 1; k <= n; k <<= 1) { int p = 0; for (i = n - k + 1; i <= n; ++i) y[++p] = i; for (i = 1; i <= n; ++i) if (sa[i] > k) y[++p] = sa[i] - k; for (i = 1; i <= m; ++i) c[i] = 0; for (i = 1; i <= n; ++i) c[x[y[i]]] ++; for (i = 1; i <= m; ++i) c[i] += c[i - 1]; for (i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i]; swap(x, y); p = 2; x[sa[1]] = 1; for (i = 2; i <= n; ++i) x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? p - 1 : p ++; if (p > n) break; m = p; } for (i = 1; i <= n; ++i) rnk[sa[i]] = i; ht[1] = 0; int k = 0; for (i = 1; i <= n; ++i) { if (rnk[i] == 1) continue; if (k) k --; int j = sa[rnk[i] - 1]; while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++; ht[rnk[i]] = k; } for (i = 1; i <= n; ++i) f[i][0] = ht[i]; for (i = 2; i <= n; ++i) Log[i] = Log[i >> 1] + 1; for (int j = 1; j <= Log[n]; ++j) for (i = 1; i + (1 << j) - 1 <= n; ++i) f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]); } int LCP(int i,int j) { i = rnk[i], j = rnk[j]; if (i > j) swap(i, j); i ++; int k = Log[j - i + 1]; return min(f[i][k], f[j - (1 << k) + 1][k]); } int LCP2(int i,int j) { i ++; int k = Log[j - i + 1]; return min(f[i][k], f[j - (1 << k) + 1][k]); } set<int> sk; int n, k, rev[N]; bool check(int len) { int l = *sk.begin(); set<int>::iterator it = sk.end(); it --; int r = *it; return LCP2(l, r) >= len; } int check2(int i,int j,int len) { if (sk.size() >= 2 && !check(len)) return 0; int a = min(len - 1, LCP(rev[i - 1], rev[i - 1 + len])); if (j + len > n) return 0; int b = min(len - 1, LCP(j, j + len)); return max(0, b - (len - a) + 1); } int main() { freopen("sutoringu.in", "r", stdin); freopen("sutoringu.out", "w", stdout); n = read(), k = read(); scanf("%s", s + 1); s[n + 1] = '#'; for (int i = 1; i <= n; ++i) s[i + n + 1] = s[n - i + 1], rev[n - i + 1] = i + n + 1; getsa(n + n + 1); LL ans = 0; for (int len = 1; len <= n; ++len) { sk.clear(); for (int i = 1; i <= n; i += len) { sk.insert(rnk[i]); if (sk.size() > k) sk.erase(rnk[i - len * k]); if (sk.size() == k) ans += check(len); } if (len == 1) continue; sk.clear(); for (int i = len + 1; i <= n; i += len) { sk.insert(rnk[i]); if (sk.size() > k - 1) sk.erase(rnk[i - len * (k - 1)]); if (sk.size() == k - 1) ans += check2(i - (k - 2) * len, i, len); } } cout << ans; return 0; }