给出一个由 01 组成的字符串,问该字符串有多少不同的子序列满足:
- 子序列是一个回文序列
- 子序列不连续,即这个回文串不可以是原字符串上连续的子串
回文序列不仅要求值回文, 且要求位置回文.
设前者为 , 后者为 , 其中 可以通过 得出, 没学过的可以看 这里 .
所以现在只需考虑 怎么计算 .
- 以整数位置 为对称轴, 设满足 的字母对数为 , 算上 总共产生了 种回文子序列, 那个 是减去 全部都不选 的情况 .
- 以小数位置 为对称轴, 设满足 的字母对数为 , 则总共产生了 种回文子序列 .
构造多项式 , 则
于是 就表示以 位置为对称轴, 的对数
于是 就表示以 位置为对称轴, 的对数
同理设 , 就表示以 为对称轴, 的对数 .
然后 就可以得到以 为对称轴, 的对数 .
纵使 为小数也成立, 所以可以完美覆盖上方情况 .
其中多项式乘法可以使用 实现, 没学过的可以看 这里 .
#include<bits/stdc++.h>
typedef long long ll;
#define reg register
const int maxn = 200005;
const int mod = 1e9 + 7;
const double Pi = acos(-1);
int N;
int FFT_Len;
int pw[maxn];
int rev[maxn<<2];
int hw[maxn<<1];
char S[maxn<<1];
char t[maxn<<1];
struct complex{
double x, y;
complex(double x=0, double y=0):x(x), y(y) {}
} A[maxn<<2], B[maxn<<2];
complex operator + (complex a, complex b){ return complex(a.x+b.x, a.y+b.y); }
complex operator - (complex a, complex b){ return complex(a.x-b.x, a.y-b.y); }
complex operator * (complex a, complex b){ return complex(a.x*b.x-a.y*b.y, a.x*b.y+a.y*b.x); }
int Ksm(int a, ll b){
int s = 1;
while(b){
if(b & 1) s = 1ll*s*a % mod;
a = 1ll*a*a % mod; b >>= 1;
}
return s;
}
int Manacher(){
int res = 0;
t[0] = '#';
for(reg int i = 1; i <= N; i ++) t[i*2-1] = S[i], t[i*2] = '#';
t[N*2+1] = '#';
int Max_r = 0, mid = 0;
for(reg int i = 1; i <= N<<1; i ++){
if(i <= Max_r) hw[i] = std::min(hw[(mid<<1)-i], Max_r-i+1);
while(i-hw[i] >= 0 && i+hw[i] <= (N<<1)+1 && t[i-hw[i]] == t[i+hw[i]]) hw[i] ++;
if(i+hw[i]-1 > Max_r) Max_r = i+hw[i]-1, mid = i;
res += hw[i]/2; // !
if(res >= mod) res -= mod;
}
return res;
}
void FFT(complex *F, int opt){
for(reg int i = 0; i < FFT_Len; i ++)
if(i < rev[i]) std::swap(F[i], F[rev[i]]);
for(reg int p = 2; p <= FFT_Len; p <<= 1){
int half = p >> 1;
complex t = complex(cos(Pi/half), opt*sin(Pi/half));
for(reg int i = 0; i < FFT_Len; i += p){
complex buf = complex(1, 0);
for(reg int k = i; k < i+half; k ++){
complex Tmp = buf * F[k + half];
F[k + half] = F[k] - Tmp;
F[k] = F[k] + Tmp;
buf = buf * t;
}
}
}
}
int Calc(){
int res = 0;
for(reg int i = 1; i <= N; i ++) A[i].x = S[i]=='a', B[i].x = S[i]=='b';
FFT_Len = 1; int bit_n = 0;
while(FFT_Len <= (N<<1)) bit_n ++, FFT_Len <<= 1;
for(reg int i = 0; i < FFT_Len; i ++) rev[i] = (rev[i>>1]>>1) | ((i&1) << bit_n-1);
FFT(A, 1), FFT(B, 1);
for(reg int i = 0; i < FFT_Len; i ++) A[i] = A[i]*A[i] + B[i]*B[i];
FFT(A, -1);
for(reg int i = 0; i < FFT_Len; i ++) A[i].x = (A[i].x + 0.5)/FFT_Len;
pw[0] = 1;
for(reg int i = 1; i <= N; i ++) pw[i] = 2ll*pw[i-1] % mod;
for(reg int i = 1; i <= (N<<1)+1; i ++){
ll t1 = (A[i].x + 1)/2;
res += pw[t1] - 1;
if(res >= mod) res -= mod;
}
return res;
}
int main(){
scanf("%s", S+1);
N = strlen(S+1);
int p = Manacher();
printf("%d
", (1ll*Calc()-p+mod)%mod);
return 0;
}