题目链接:https://ac.nowcoder.com/acm/contest/5666/A
想法:
我们可以发现以下的一些规律
对于任意后缀其B数组的第一个元素一定为0,并且B数组的开头一定为01111(1的个取决于开头有多少个连续的相同字符) 【当然也可以是 00 这种情况,比如 ab 】
例如 aaabba = 0111013,aaabb = 01101
如果两个字符串连续1的长度不同,那么更短的那个字典序更小,所以我们可以预处理出一个 dis 数组,代表i位置处的后缀的开头的相同字符长度 + 1【+1 是因为这样我们可以直接找到那一个不同的字符 】
现在我们再接着考虑如果两个B数组拥有相同的开头之后我们该怎么处理
对于比较两个字符串的字典序大小,我们肯定希望找到两个字符串第一个不同的位置进行比较,现在问题的关键就在于找到这样的一个位置
在找到位置前,我们还得发现B数组有这样的一个性质
我们可以发现如果长度发现【i,i+dis[i]】这个区间范围肯定是包含了 a 和 b 的(如果不包含那么dis[i]会更大),这个是有用的
根据B函数的定义我们可以发现如果对于一个字符串不断加入字符,如果a和b都出现了,前面再怎么加入字符对后面都没有影响。
比如abaaab=>102114在前两个都出现了,因此你在前面在怎么加入字符串,都不会改变2后面子串的函数值。
有了这样的一个性质,我们就可以对初始的S串求B数组,然后对于前缀(这里的前缀指开头的01序列)已经相同的两个字符串,直接在B数组中比较后半部分
但是,除去前缀的部分可能会有很长的相同部分,如果我们暴力进行比较,肯定会超时,所以此时,我们可以利用后缀数组,在O(1)的时间复杂度内求出两个后缀的LCP,然后比较LCP之后的位置就可以了
最后,我们就可以根据上面描述的规则去编写cmp函数进行排序了
如果【i + dis[i]】 一句大于等于 n 的时候,我们就取靠后面的就可以了
#include <algorithm> #include <string> #include <cstring> #include <vector> #include <map> #include <stack> #include <set> #include <queue> #include <cmath> #include <cstdio> #include <iomanip> #include <ctime> #include <bitset> #include <cmath> #include <sstream> #include <iostream> #include <unordered_map> #define ll long long #define ull unsigned long long #define ls nod<<1 #define rs (nod<<1)+1 #define pii pair<int,int> #define mp make_pair #define pb push_back #define INF 0x3f3f3f3f #define max(a, b) (a>b?a:b) #define min(a, b) (a<b?a:b) const double eps = 1e-2; const int maxn = 2e5 + 10; const ll MOD = 99999999999999; const int mlog=20; int sgn(double a) { return a < -eps ? -1 : a < eps ? 0 : 1; } using namespace std; struct Suffix_Array { int s[maxn],sa[maxn],rk[maxn],height[maxn]; int t[maxn],t2[maxn],c[maxn],n; void init() { memset(t, 0, sizeof(int) * (2 * n + 10)); memset(t2, 0, sizeof(int) * (2 * n + 10)); } void build_sa(int m=256) { int *x = t, *y = t2; for(int i=0;i<m;++i) c[i]=0; for(int i=0;i<n;++i) c[x[i]=s[i]]++; for(int i=1;i<m;++i) c[i]+=c[i-1]; for(int i=n-1;i>=0;--i) sa[--c[x[i]]]=i; for(int k=1;k<=n;k<<=1){ int p=0; for(int i=n-1;i>=n-k;--i) y[p++]=i; for(int i=0;i<n;++i) if(sa[i]>=k) y[p++]=sa[i]-k; for(int i=0;i<m;++i) c[i]=0; for(int i=0;i<n;++i) c[x[y[i]]]++; for(int i=1;i<m;++i) c[i] += c[i - 1]; for(int i=n-1;i>=0;--i) sa[--c[x[y[i]]]]=y[i]; swap(x,y); p=1; x[sa[0]]=0; for(int i=1;i<n;++i) x[sa[i]]=y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++; if(p>=n) break; m=p; } } void get_height() { int k=0; for(int i=0;i<n;++i) rk[sa[i]]=i; for(int i=0;i<n;++i){ if(rk[i]>0){ if(k) --k; int j=sa[rk[i]-1]; while(i+k<n&&j+k<n&&s[i+k]==s[j+k]) ++k; height[rk[i]]=k; } } } int d[maxn][mlog],Log[maxn]; void RMQ_init() { Log[0]=-1; for(int i=1;i<=n;++i) Log[i]=Log[i/2]+1; for(int i=0;i<n;++i) d[i][0]=height[i]; for(int j=1;j<=Log[n];++j){ for(int i=0;i+(1<<j)-1<n;++i){ d[i][j]=min(d[i][j-1],d[i+(1<<(j-1))][j-1]); } } } int lcp(int i,int j)//返回下标i开始的后缀与下标j开始的后缀的最长公共前缀。 { if(i==j) return n-i; if(rk[i]>rk[j]) swap(i,j); int x=rk[i]+1,y=rk[j]; int k=Log[y-x+1]; return min(d[x][k],d[y-(1<<k)+1][k]); } pair <int, int> Locate(int l, int r)//返回一个最长的区间[L, R]使得sa中下标从L到R的所有后缀都以s[l, r]为前缀。 { int pos=rk[l],length=r-l+1; int L=0,R=pos,M; while(L<R){ M=(L+R)>>1; if(lcp(l,sa[M])>=length) R=M; else L=M+1; } int tmp=L; L=pos,R=n-1; while(L<R){ M=(L+R+1)>>1; if(lcp(l,sa[M])>=length) L=M; else R=M-1; } return make_pair(tmp,L); } }SA; int n; int b[maxn],dis[maxn],ans[maxn]; char s[maxn]; bool cmp(int i,int j) { if (dis[i] != dis[j]) return dis[i] < dis[j]; if (i + dis[i] >= n && j + dis[j] >= n) return i > j; if (i + dis[i] >= n) return 1; if (j + dis[j] >= n) return 0; int lcp = SA.lcp(i+dis[i],j+dis[j]); return b[i + dis[i] + lcp] < b[j + dis[j] + lcp]; } int main() { while (~scanf("%d",&n)) { b[n] = 0; // 这里需要特别注意 (后缀数组最后一个都需要特殊处理) scanf("%s",s); int fla = -1,flb = -1; for (int i = 0;i < n;i++) { if (s[i] == 'a') { if (fla != -1) b[i] = i - fla; else b[i] = 0; fla = i; } else { if (flb != -1) b[i] = i - flb; else b[i] = 0; flb = i; } } SA.n = n; for (int i = 0;i < n;i++) SA.s[i] = b[i]; SA.init(); SA.build_sa(); SA.get_height(); SA.RMQ_init(); dis[n - 1] = 2; for (int i = n - 2;i >= 0;i--) { if (s[i] == s[i+1]) dis[i] = dis[i+1] + 1; else dis[i] = 2; } for (int i = 0;i < n;i++) ans[i] = i; sort(ans,ans+n,cmp); for (int i = 0;i < n;i++) printf("%d ",ans[i]+1); printf(" "); } return 0; }