题目大意是问在$S$串中找区间$[i,j]$,在$T$串中找位置$k$,使得$S[i,j]$和$T[1,k]$可以组成回文串,并且$j-i+1>k$,求这样的三元组$(i,j,k)$的个数。
一开始有点懵,但是仔细一想,因为$j-i+1>k$,所以$S[i,j]$中一定包含了回文串后半段的一部分,即$S[i,j]$中一定有后缀是回文串。
如果回文串是$S[x,j]$,则剩余的$S[i,x-1]$与$T[1,k]$应该也能组成回文串。如果将串$S$倒置,则串$S^{'}$上的原$S[i,x-1]$位置与$T[1,k]$应该相同。
所以解题方式应该比较明了,将串$S$倒置,然后求扩展$kmp$,得到串$S^{'}$每个后缀与串$T$的最长公共前缀。然后对串$S^{'}$构建回文自动机。
可以得到串$S^{'}$每个位置作为回文子串的结尾时的回文串个数。然后枚举串$S^{'}$每个位置$i$,以当前位置作为上文中的$x$,然后计算当前位置对答案的贡献。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int maxn = 1e6 + 100; 5 int Next[maxn]; 6 int Ex[maxn]; 7 void getN(char* s1) {//求子串与自身匹配 8 int i = 0, j, p, len = strlen(s1); 9 Next[0] = len; 10 while (i + 1 < len && s1[i] == s1[i + 1]) 11 i++; 12 Next[1] = i; 13 p = 1; 14 for (i = 2; i < len; i++) { 15 if (Next[i - p] + i < Next[p] + p) 16 Next[i] = Next[i - p]; 17 else { 18 j = Next[p] + p - i; 19 if (j < 0) 20 j = 0; 21 while (i + j < len && s1[j] == s1[i + j]) 22 j++; 23 Next[i] = j; 24 p = i; 25 } 26 } 27 } 28 void getE(char* s1, char* s2) {//求子串与主串匹配 29 int i = 0, j, p, len1 = strlen(s1), len2 = strlen(s2); 30 while (i < len1 && i < len2 && s1[i] == s2[i]) 31 i++; 32 Ex[0] = i; 33 p = 0; 34 for (i = 1; i < len1; i++) { 35 if (Next[i - p] + i < Ex[p] + p) 36 Ex[i] = Next[i - p]; 37 else { 38 j = Ex[p] + p - i; 39 if (j < 0) 40 j = 0; 41 while (i + j < len1 && j < len2 && s1[i + j] == s2[j]) 42 j++; 43 Ex[i] = j; 44 p = i; 45 } 46 } 47 } 48 struct Palindromic_Tree { 49 int next[maxn][26];//指向的串为当前串两端加上同一个字符构成 50 int fail[maxn];//fail指针,失配后跳转到fail指针指向的节点 51 int cnt[maxn]; //表示节点i表示的本质不同的串的个数,最后用count统计 52 int num[maxn]; //表示节点i表示的最长回文串的最右端点为回文串结尾的回文串个数 53 int len[maxn];//len[i]表示节点i表示的回文串的长度 54 int id[maxn];//表示数组下标i在自动机的哪个位置 55 int S[maxn]; 56 int last;//指向上一个字符所在的节点,方便下一次add 57 int n; int p; 58 int newnode(int x) { 59 for (int i = 0; i < 26; ++i) next[p][i] = 0; 60 cnt[p] = 0; num[p] = 0; len[p] = x; 61 return p++; 62 } 63 void init() {//初始化 64 p = 0; 65 newnode(0); newnode(-1); 66 last = 0; n = 0; 67 S[n] = -1; 68 fail[0] = 1; 69 } 70 int get_fail(int x) {//失配后找一个最长的 71 while (S[n - len[x] - 1] != S[n]) x = fail[x]; 72 return x; 73 } 74 void add(int x) { 75 S[++n] = x; 76 int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置 77 if (!next[cur][x]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串 78 int now = newnode(len[cur] + 2);//新建节点 79 id[n - 1] = now; 80 fail[now] = next[get_fail(fail[cur])][x];//建立fail指针,以便失配后跳转 81 next[cur][x] = now; 82 num[now] = num[fail[now]] + 1; 83 } 84 else 85 id[n - 1] = next[cur][x]; 86 last = next[cur][x]; 87 cnt[last]++; 88 } 89 void count() { 90 for (int i = p - 1; i >= 0; --i) cnt[fail[i]] += cnt[i]; 91 } 92 93 }a; 94 char s[maxn], s1[maxn], t[maxn]; 95 int main() { 96 scanf("%s%s", s, t); 97 int n = strlen(s), m = strlen(t); 98 for (int i = 0; i < n; i++) 99 s1[i] = s[n - i - 1]; 100 getN(t); 101 getE(s1, t); 102 a.init(); 103 for (int i = 0; i < n; i++) 104 a.add(s1[i] - 'a'); 105 a.count(); 106 ll ans = 0; 107 for (int i = n - 1; i >= 0; i--) { 108 int w = Ex[i]; 109 ans += 1LL * w * a.num[a.id[i - 1]]; 110 } 111 printf("%lld ", ans); 112 }