Description
Solution
好名字。
像我这种懒狗肯定是直接看简化版题意啊!
题意让我们求不连续的回文子序列个数,我们发现并不是很好求,所以转换一下思路,求:
回文子序列个数(可连续) - 回文子串个数
回文子串个数很容易求出来,就是 \(manacher\) 板子,下面我们来考虑如何求出回文子序列个数。
设点 \(id\) 为一个子序列的对称中心,若 \(S_{i - j} = S_{i + j}\) ,我们的答案就会多一种选择,假设对于对称中心 \(i\) 有 \(f_i = x\) 种选择,那么:
- \(i\) 是一个字符,答案就是 \(2^{x + 1} - 1\)。包括对称中心 \(i\) 在内有 \(x + 1\) 对选或不选的字符,所以一共有 \(2^{x + 1}\) 种情况,减去空集的 1 种情况就是答案。
- \(i\) 在两个字符中间,答案就是 \(2^x - 1\),即不能包括对称中心。
通过上面的讨论可以发现这个对称中心在两个字符中间,所以我们按照 \(manacher\) 的做法把原字符串翻倍,这样一来这个中心就一定是整点了。
那么这个 \(f_i\) 如何求呢?
构造出生成函数:
\[F(x) = \sum\limits_{i = 0}^{|s| - 1}a_ix^i
\]
然后令 \(F(x)\) 自己卷自己,即 \(F(x)^2\), 这一步用 NTT 或 FFT 都行。
不难发现,\(F(x)^2\) 中 \(x^i\) 的系数就是在原字符串中关于 \(\frac{i}{2}\) 对称的字符对的个数。
又观察到输入的字符串中只有 a
和 b
,而且这两个字符之间没有关联,所以分开考虑:
- 对于
a
,若 \(S_i = a\),那么 \(A_i = 1\),反之 \(A_i = 0\)。 - 对于
b
同理,若 \(S_i = b\),那么 \(B_i = 1\),反之 \(B_i = 0\)。
得出 \(A\) 和 \(B\) 之后我们就可以计算 \(f_i\) 啦。
\[f_i = A^2 + B^2
\]
最后统计一下答案就完啦。
Code
我这里用的 NTT 做的,需要两个模数,注意写的时候不要弄混了 QwQ
#include <bits/stdc++.h>
#define ll long long
#define cl const ll
using namespace std;
namespace IO{
inline ll read(){
ll x = 0;
char ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}
template <typename T> inline void write(T x){
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace IO;
const ll N = 3e5 + 10;
const ll mod = 998244353, Mod = 1e9 + 7;
const ll G = 3, Gi = 332748118;
ll n, m, ans1, ans2;
ll a[N], b[N], c[N], d[N], rev[N];
int p[N];
char s[N], t[N];
namespace NTT{
ll lim, len;
inline ll qpow(ll a, ll b, ll mod){
ll res = 1;
while(b){
if(b & 1) res = res * a % mod;
a = a * a % mod, b >>= 1;
}
return res;
}
inline void get_rev(cl n){
lim = 1, len = 0;
while(lim < n) lim <<= 1, ++len;
for(int i = 0; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
}
inline void ntt(ll A[], cl lim, cl type){
for(int i = 0; i < lim; ++i)
if(i < rev[i]) swap(A[i], A[rev[i]]);
for(int mid = 1; mid < lim; mid <<= 1){
ll Wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1), mod);
for(int i = 0; i < lim; i += (mid << 1)){
ll w = 1;
for(int j = 0; j < mid; ++j, w = w * Wn % mod){
ll x = A[i + j], y = w * A[i + j + mid] % mod;
A[i + j] = (x + y) % mod;
A[i + j + mid] = (x - y + mod) % mod;
}
}
}
if(type == 1) return;
ll inv = qpow(lim, mod - 2, mod);
for(int i = 0; i < lim; ++i) A[i] = A[i] * inv % mod;
}
inline void Mul(cl n, cl m, ll a[], ll b[]){
get_rev(n + m);
ntt(a, lim, 1), ntt(b, lim, 1);
for(int i = 0; i < lim; ++i) a[i] = a[i] * b[i] % mod;
ntt(a, lim, -1);
}
}
using namespace NTT;
inline void manacher(){
n = strlen(t + 1);
s[0] = '#', s[(n << 1) + 1] = '*';
for(int i = 1; i <= n; ++i)
s[(i << 1) - 1] = '*', s[i << 1] = t[i];
m = (n << 1) + 1;
int mx = 0, id = 0;
for(int i = 1; i <= m; ++i){
if(i < mx) p[i] = min(mx - i, p[(id << 1) - i]);
else p[i] = 1;
while(i - p[i] >= 1 && i + p[i] <= m && s[i - p[i]] == s[i + p[i]]) ++p[i];
if(i + p[i] > mx) id = i, mx = i + p[i];
}
for(int i = 1; i <= (n << 1) + 1; ++i) ans1 = (ans1 + (p[i] >> 1)) % Mod;
}
inline void solve(){
for(int i = 1; i <= n; ++i){
if(t[i] == 'a') a[i - 1] = c[i - 1] = 1;
else b[i - 1] = d[i - 1] = 1;
}
Mul(n, n, a, c), Mul(n, n, b, d);
for(int i = 0; i <= (n << 1); ++i) c[i] = (a[i] + b[i]) % mod;
for(int i = 0; i <= (n << 1) - 2; ++i){
if(i & 1) ans2 = (ans2 + qpow(2, (c[i] >> 1), Mod) - 1 + Mod) % Mod;
else ans2 = (ans2 + qpow(2, (c[i] >> 1) + 1, Mod) - 1 + Mod) % Mod;
}
write((ans2 - ans1 + Mod) % Mod), puts("");
}
int main(){
scanf("%s", t + 1);
manacher();
solve();
return 0;
}
\[\_EOF\_
\]