题解:
很容易想到将第一个串反过来,然后对于s串的每个位置可以求出t的前缀和它匹配了多少个(EXKMP 或者 二分+hash)。
然后剩下的就是要处理以某个位置为结束的回文串有多少个(manacher + 差分),因为要求s串选取的要多一点。
这道题是个痛啊。。。当时的金牌题,不会EXKMP可以用二分+字符串hash啊,比赛前的暑假还写过,比赛时就没想到,还以为KMP可以搞出这个东西,
然后就三个人一起自闭地调KMP,说到底还是菜呀。
代码:
#pragma GCC optimize(2) #pragma GCC optimize(3) #pragma GCC optimize(4) #include<bits/stdc++.h> using namespace std; #define y1 y11 #define fi first #define se second #define pi acos(-1.0) #define LL long long //#define mp make_pair #define pb push_back #define ls rt<<1, l, m #define rs rt<<1|1, m+1, r #define ULL unsigned LL #define pll pair<LL, LL> #define pli pair<LL, int> #define pii pair<int, int> #define piii pair<pii, int> #define pdd pair<double, double> #define mem(a, b) memset(a, b, sizeof(a)) #define debug(x) cerr << #x << " = " << x << " "; #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0); //head const int N = 1e6 + 10; int p[N*2], cnt[N]; char s[N], t[N]; int nxt[N], ex[N]; void GETNEXT(char *str) { int i = 0, j, po, len=strlen(str); nxt[0] = len; while(str[i] == str[i+1] && i+1 < len) i++; nxt[1] = i; po = 1; for(i = 2; i < len; i++) { if(nxt[i-po] + i < nxt[po] + po) nxt[i] = nxt[i-po]; else { j=nxt[po] + po - i; if(j < 0) j = 0; while(i + j < len && str[j] == str[j+i]) j++; nxt[i] = j; po = i; } } } void EXKMP(char *s1,char *s2) { int i = 0, j, po, len = strlen(s1), l2=strlen(s2); GETNEXT(s2); while(s1[i] == s2[i] && i < l2 && i < len) i++; ex[0] = i; po = 0; for(i = 1; i < len; i++) { if(nxt[i-po] + i < ex[po] + po) ex[i]=nxt[i-po]; else { j = ex[po] + po - i; if(j < 0) j = 0; while(i + j < len && j < l2 && s1[j+i] == s2[j]) j++; ex[i] = j; po = i; } } } void manacher(char *s) { string t = "$#"; int n = strlen(s); for (int i = 0; i < n; ++i) { t += s[i]; t += '#'; } int mx = 0, id = 0, resl = 0, resc = 0; for (int i = 1; i < t.size(); ++i) { p[i] = mx > i ? min(p[2*id-i], mx-i) : 1; while(t[i+p[i]] == t[i-p[i]]) ++p[i]; if(mx < i+p[i]) mx = i+p[i], id = i; if(resl < p[i]) resl = p[i], resc = i; } for (int i = 1; i < t.size(); ++i) { if(p[i] == 1 && t[i] == '#') continue; int l, r; if(p[i]&1) { l = (i-1)/2; int d = (p[i]-1)/2; r = l+d; } else { l = (i-2)/2; int d = p[i]/2; r = l+d; } cnt[l]++, cnt[r]--; } for (int i = 1; i < n; ++i) cnt[i] += cnt[i-1]; } int main() { scanf("%s", s); scanf("%s", t); int n = strlen(s); for (int i = 0, j = n-1; i < j; ++i, --j) { swap(s[i], s[j]); } manacher(s); EXKMP(s, t); LL ans = 0; for (int i = 1; i < n; ++i) { ans += 1LL * ex[i] * cnt[i-1]; } printf("%lld ", ans); return 0; }