题意:
对所有满足 s[i..j] 为回文串且 s[j+1..k] 也为回文串的三元组 (i, j, k)(下标从 1 开始),求 i*k 之和,结果对 1e9+7 取模。
题解:
设 sum1[i] 为(正着插入时)所有以 i 结尾的回文串的起始位置之和,sum2[i] 为(反着插入时)所有以 i 开头的回文串的结束位置之和。
枚举分界点 j,则答案为 $ans=\sum_{j=1}^{n-1} sum1[j]\cdot sum2[j+1]$。
这个式子很容易让我们想到:用回文树把字符串分别正着、反着插入一遍。
回文树中 num[x] 表示节点 x 对应回文串的回文后缀个数(含自身,num[x] = num[fail[x]] + 1),因此插入第 i 个字符后,num[last] 就是以 i 结尾的回文串个数。
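再用 sum[x] 记录节点 x 沿 fail 链上所有回文串的长度之和(sum[x] = sum[fail[x]] + len[x]),插入第 i 个字符后,sum[last] 就是所有以 i 结尾的回文串的长度之和。长度为 L 的回文串起始位置为 i - L + 1,于是: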
sum1[i] = (1LL * (i + 1) * pam.num[pam.last] % mod - pam.sum[pam.last] + mod) % mod;
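由于卡内存,我们不单独记录 sum2[]:反着插入到第 i 个字符时,以 i 开头、长度为 L 的回文串结束位置为 i + L - 1,所以直接用 (i - 1) * num[last] + sum[last] 得到 sum2[i],再与 sum1[i-1] 相乘累加进答案。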
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <cmath>
#include <ctime>
#include <cstdio>
#include <string>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_map>

// For each string read from stdin, prints (on its own line)
//   sum of i * k over all (i, j, k), 1-based, such that s[i..j] and s[j+1..k]
//   are both palindromes,
// taken modulo 1e9+7. Uses two passes of a palindromic tree (eertree).

using LL = long long;

constexpr int maxn = 1000000 + 3;  // sized for strings of up to about 1e6 chars
constexpr int mod = 1000000007;

char s[maxn];    // current input, stored 1-based in s[1..n]
int sum1[maxn];  // sum1[i] = (sum of start positions of palindromes ending at i) mod `mod`

// Palindromic tree. Node 0 is the even root (len 0), node 1 the odd root (len -1).
// For every other node x:
//   num[x] = number of non-empty palindromic suffixes of x's palindrome (x included),
//   sum[x] = sum of their lengths, mod `mod`.
// After add(c), `last` is the longest palindromic suffix of the text inserted so
// far, so num[last] / sum[last] describe every palindrome ending at the newest char.
// Arrays are kept as fixed-size members because the data set is memory-tight.
struct Palindrome_Automaton {
    int len[maxn], next[maxn][26], fail[maxn], sum[maxn];
    int num[maxn], S[maxn], sz, n, last;

    // Allocates a fresh node of palindrome length l with no outgoing edges.
    int newnode(int l) {
        for (int i = 0; i < 26; ++i) next[sz][i] = 0;
        num[sz] = 0;
        sum[sz] = 0;
        fail[sz] = 0;
        len[sz] = l;
        return sz++;
    }

    // Resets the tree to just the two roots; must be called before each pass.
    void init() {
        sz = n = last = 0;
        newnode(0);    // node 0: even root
        newnode(-1);   // node 1: odd root
        S[0] = -1;     // sentinel that never equals a letter index
        fail[0] = 1;
    }

    // Walks fail links from x until the palindrome at x can be extended by S[n].
    // Terminates at the odd root at the latest, since len[1] == -1 compares S[n] with itself.
    int get_fail(int x) {
        while (S[n - len[x] - 1] != S[n]) x = fail[x];
        return x;
    }

    // Appends character c. Assumes c is in 'a'..'z' (TODO: confirm against input spec).
    void add(int c) {
        c -= 'a';
        S[++n] = c;
        int cur = get_fail(last);
        if (!next[cur][c]) {
            int now = newnode(len[cur] + 2);
            fail[now] = next[get_fail(fail[cur])][c];
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
            sum[now] = (sum[fail[now]] + len[now]) % mod;
        }
        last = next[cur][c];
    }
} pam;

// Computes the answer for str[1..n] (1-based). Uses the global `pam` and `sum1`.
static LL solve(const char* str, int n) {
    // Pass 1 (left to right): a palindrome ending at i with length L starts at
    // i - L + 1, so the sum of start positions is (i + 1) * count - (sum of lengths).
    pam.init();
    sum1[0] = 0;
    for (int i = 1; i <= n; ++i) {
        pam.add(str[i]);
        sum1[i] = (1LL * (i + 1) * pam.num[pam.last] % mod - pam.sum[pam.last] + mod) % mod;
    }

    // Pass 2 (right to left): after inserting str[n..i], num[last] / sum[last]
    // describe palindromes starting at i. One of length L ends at i + L - 1, so the
    // sum of end positions is (i - 1) * count + (sum of lengths). Pair it with every
    // palindrome ending at i - 1 (sum1[i - 1]; sum1[0] == 0 handles i == 1).
    LL ans = 0;
    pam.init();
    for (int i = n; i >= 1; --i) {
        pam.add(str[i]);
        const LL ends = (1LL * (i - 1) * pam.num[pam.last] % mod + pam.sum[pam.last]) % mod;
        ans = (ans + sum1[i - 1] * ends % mod) % mod;
    }
    return ans;
}

int main() {
    // One whitespace-delimited string per test case, until EOF.
    while (std::scanf("%s", s + 1) == 1) {
        const int n = static_cast<int>(std::strlen(s + 1));
        std::printf("%lld\n", solve(s, n));
    }
    return 0;
}