题解参见罗穗骞的论文《后缀数组——处理字符串的有力工具》。
最近开始重拾一些重点算法，比如网络流、强连通分量，还有后缀数组（字符串和数论相关的内容一直都是弱项）。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 #include <vector> 5 #include <numeric> 6 7 using LL = long long; 8 9 const size_t maxN = 50000 + 10; 10 11 char str[maxN]; 12 int sa[maxN], rank[maxN], height[maxN]; 13 size_t len; 14 15 void input() 16 { 17 scanf("%s", str); 18 len = strlen(str); 19 } 20 21 void buildSA() 22 { 23 struct Node { 24 int k1, k2, id; 25 bool operator == (const Node& rhs) const { 26 return k1 == rhs.k1 && k2 == rhs.k2; 27 } 28 }; 29 30 std::vector<Node> node(len), temp(len); 31 std::vector<int> bucketSize; 32 33 for (int i = 0; i < len; i++) 34 rank[i] = (int)str[i]; 35 36 auto radixSort = [&] (size_t alphaSize) -> void 37 { 38 bucketSize.resize(alphaSize); 39 std::fill(bucketSize.begin(), bucketSize.end(), 0); 40 41 for (auto& cur: node) 42 bucketSize[cur.k2] += 1; 43 std::partial_sum(bucketSize.begin(), bucketSize.end(), bucketSize.begin()); 44 45 for (auto it = node.crbegin(); it != node.crend(); ++it) 46 temp[--bucketSize[it->k2]] = *it; 47 48 std::fill(bucketSize.begin(), bucketSize.end(), 0); 49 50 for (auto& cur: temp) 51 bucketSize[cur.k1] += 1; 52 std::partial_sum(bucketSize.begin(), bucketSize.end(), bucketSize.begin()); 53 54 for (auto it = temp.crbegin(); it != temp.crend(); ++it) 55 node[--bucketSize[it->k1]] = *it; 56 }; 57 58 for (size_t step = 0; step < len; step == 0 ? step = 1 : step <<= 1) 59 { 60 for (int i = 0; i < len; i++) 61 node[i] = {rank[i], i + step < len ? rank[i + step] : 0, i}; 62 63 radixSort(step == 0 ? 
size_t(128) : len); 64 65 int last = 1; 66 rank[node[0].id] = 1; 67 for (size_t i = 1; i < len; i++) 68 { 69 if (!(node[i] == node[i - 1])) 70 last += 1; 71 rank[node[i].id] = last; 72 } 73 74 if (last == len) 75 break; 76 } 77 78 for (int i = 0; i < len; i++) 79 sa[rank[i] -= 1] = i; 80 } 81 82 void buildHeight() 83 { 84 height[0] = 0; 85 int last = 0; 86 87 for (int i = 0; i < len; i++) 88 { 89 if (rank[i] == 0) 90 continue; 91 92 int lp = sa[rank[i] - 1], cp = i; 93 int cur = std::max(0, last - 1); 94 95 for (; str[lp + cur] == str[cp + cur]; cur++) {} 96 height[rank[i]] = last = cur; 97 } 98 } 99 100 //#include <fmt/format.h> 101 102 LL solve() 103 { 104 buildSA(); 105 buildHeight(); 106 107 // for (int i = 0; i < len; i++) 108 // fmt::print("sa[{i}] = {1}, rank[{i}] = {2}, height[{i}] = {3} ", 109 // fmt::arg("i", i), sa[i], rank[i], height[i]); 110 111 LL ans = len - sa[0]; 112 for (int i = 1; i < len; i++) 113 ans += (len - sa[i] - height[i]); 114 115 return ans; 116 } 117 118 int main() 119 { 120 int T; 121 for (scanf("%d", &T); T; T--) 122 { 123 input(); 124 printf("%lld ", solve()); 125 } 126 return 0; 127 }