You are given a string s. Each pair of numbers l and r satisfying 1 ≤ l ≤ r ≤ |s| corresponds to the substring of s that starts at position l and ends at position r (inclusive).
Let's define the function F(x, y) of two strings as follows. Find the list of all pairs of numbers (l, r) for which the corresponding substring of x is equal to y, and sort this list in increasing order of the first number of each pair. The value of F(x, y) equals the number of non-empty continuous sequences in this list.
For example: F(babbabbababbab, babb) = 6. The list of pairs is as follows:
(1, 4), (4, 7), (9, 12)
Its continuous sequences are:
- (1, 4)
- (4, 7)
- (9, 12)
- (1, 4), (4, 7)
- (4, 7), (9, 12)
- (1, 4), (4, 7), (9, 12)
Your task is to calculate, for the given string s, the sum of F(s, x) over all x in the set of all substrings of s (each distinct substring is counted once).
The only line contains the given string s, consisting only of lowercase Latin letters (1 ≤ |s| ≤ 10^5).
Print the single number — the sought sum.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin and cout streams or the %I64d specifier.
aaaa
20
abcdef
21
abacabadabacaba
188
In the first sample, the function values for x equal to "a", "aa", "aaa" and "aaaa" are 10, 6, 3 and 1, respectively.
In the second sample, the function value is 1 for every substring x.
题意:如果某一种子串s在原串中出现了k次,根据题目定义的函数,它产生的贡献是(k+1)*k/2
这个定义不太直观,我们尝试转化模型:把这 k 个 s 串按出现位置排成一排,每个串与它自己以及排在它后面的每个串各匹配一次,总匹配次数 k(k+1)/2 就是题目要求的函数值。
于是我们可以使用后缀数组和高度数组(LCP 数组):对每一个后缀,与(按字典序)排在它后面的每一个后缀求一次最长公共前缀。两个后缀的 LCP 长度恰好等于它们共同拥有的前缀(子串)个数,也就是这一对位置贡献的匹配次数,按长度累加即可统计答案。
任意两个后缀的 LCP 等于它们在后缀数组中所夹高度数组区间的最小值,因此所有后缀对的 LCP 之和可以用单调栈求出。最后,每个子串出现还要和自己匹配一次:长度为 L 的后缀与自身共享 L 个前缀,所以若读入的串长度为 n,还需 ans += (n+1)*n/2。
代码:
// Sum of F(s, x) over all distinct substrings x of s.
//
// If a substring x occurs k times in s, F(s, x) = k(k+1)/2: every occurrence
// is paired with itself and with every later occurrence. Summed over all x:
//   n(n+1)/2                          -- each occurrence paired with itself
// + sum over pairs of suffixes {i, j} of LCP(suffix i, suffix j)
//                                     -- two suffixes share exactly LCP common
//                                        prefixes, i.e. LCP pairs of occurrences.
// The suffix array (prefix doubling) plus Kasai's LCP array turn the pair sum
// into "sum of range minima over all subarrays" of the LCP array, which is
// evaluated with a monotonic stack.
//
// Input: one line containing s. Output: the sum, written with iostream
// (the statement asks not to use %lld).
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

namespace {

// Strict weak ordering of suffix start positions by the key pair
// (rank[i], rank[i + k]). A second key that would lie past position n is
// treated as -1, which sorts before the rank of every non-empty suffix.
// Holds a pointer (not a reference) so the comparator stays copy-assignable.
struct SuffixLess {
    const int* rank;
    int n;
    int k;

    SuffixLess(const int* rank_data, int length, int step)
        : rank(rank_data), n(length), k(step) {}

    bool operator()(int i, int j) const {
        if (rank[i] != rank[j]) return rank[i] < rank[j];
        const int ri = i + k <= n ? rank[i + k] : -1;
        const int rj = j + k <= n ? rank[j + k] : -1;
        return ri < rj;
    }
};

// Returns the suffix array of s *including the empty suffix*: n + 1 start
// positions sorted by their suffixes, so result[0] == n (the empty suffix).
// Prefix doubling with std::sort: O(log n) rounds of O(n log n) each.
std::vector<int> build_suffix_array(const std::string& s) {
    const int n = static_cast<int>(s.size());
    std::vector<int> sa(n + 1);
    std::vector<int> rank(n + 1);
    std::vector<int> next_rank(n + 1);
    for (int i = 0; i <= n; ++i) {
        sa[i] = i;
        // unsigned char keeps every byte non-negative, so no character can
        // collide with the empty suffix's -1.
        rank[i] = i < n ? static_cast<int>(static_cast<unsigned char>(s[i])) : -1;
    }
    for (int k = 1; k <= n; k *= 2) {
        const SuffixLess by_rank(&rank[0], n, k);
        std::sort(sa.begin(), sa.end(), by_rank);
        next_rank[sa[0]] = 0;
        for (int i = 1; i <= n; ++i) {
            next_rank[sa[i]] = next_rank[sa[i - 1]] + (by_rank(sa[i - 1], sa[i]) ? 1 : 0);
        }
        rank.swap(next_rank);
        // Once all n + 1 ranks are distinct, further rounds cannot change the order.
        if (rank[sa[n]] == n) break;
    }
    return sa;
}

// Kasai's algorithm. Returns a vector of size n where lcp[p] is the length of
// the longest common prefix of the suffixes starting at sa[p] and sa[p + 1].
// Expects sa as produced by build_suffix_array (n + 1 entries, sa[0] == n).
std::vector<int> build_lcp(const std::string& s, const std::vector<int>& sa) {
    const int n = static_cast<int>(s.size());
    std::vector<int> rank(n + 1);
    for (int p = 0; p <= n; ++p) rank[sa[p]] = p;
    std::vector<int> lcp(n, 0);
    int h = 0;
    for (int i = 0; i < n; ++i) {
        // rank[i] >= 1 for every non-empty suffix, since sa[0] is the empty one.
        const int j = sa[rank[i] - 1];
        // Moving from suffix i-1 to suffix i loses at most one matched character.
        if (h > 0) --h;
        while (i + h < n && j + h < n && s[i + h] == s[j + h]) ++h;
        lcp[rank[i] - 1] = h;
    }
    return lcp;
}

// Returns the sum of F(s, x) over all distinct substrings x of s.
long long sum_of_f(const std::string& s) {
    const int n = static_cast<int>(s.size());
    const std::vector<int> sa = build_suffix_array(s);
    const std::vector<int> lcp = build_lcp(s, sa);

    // Every one of the n(n+1)/2 substring occurrences paired with itself.
    long long ans = static_cast<long long>(n) * (n + 1) / 2;

    // After handling lcp[i], `window` equals the sum over p <= i of
    // min(lcp[p..i]), i.e. the sum over p <= i of LCP(sa[p], sa[i + 1]).
    // `runs` stores (minimum value, how many p share it), with values
    // non-decreasing from bottom to top.
    std::vector<std::pair<int, int> > runs;
    long long window = 0;
    for (int i = 0; i < n; ++i) {
        std::pair<int, int> run(lcp[i], 1);
        while (!runs.empty() && runs.back().first > run.first) {
            window -= static_cast<long long>(runs.back().first) * runs.back().second;
            run.second += runs.back().second;
            runs.pop_back();
        }
        window += static_cast<long long>(run.first) * run.second;
        runs.push_back(run);
        ans += window;
    }
    return ans;
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(0);
    std::string s;
    std::cin >> s;
    std::cout << sum_of_f(s) << '\n';
    return 0;
}