题目描述
给出字符串s1、s2、s3,找出一个字符串w,满足:
1、w是s1的子串;
2、w是s2的子串;
3、s3不是w的子串。
4、w的长度应尽可能大
求w的最大长度。
输入
输入有三行,第一行为一个字符串s1第二行为一个字符串s2,
第三行为一个字符串s3。输入仅含小写字母,字符中间不含空格。
输出
输出仅有一行,为w的最大可能长度,如w不存在,则输出0。
样例输入
abcdef
abcf
bc
样例输出
2
题解
Kmp+二分+Hash
先使用Kmp处理出s3在s1、s2中出现的所有位置,那么w的选择不能包含这些位置。
然后答案显然满足二分性质,因此二分答案,判断是否有s1和s2的公共长度为mid的子串。
将s1的所有长度为mid且不包含s3的子串的Hash值处理出来,放到哈希表中,然后将s2的所有长度为mid且不包含s3的子串的Hash值放到哈希表里查询即可。
其中判断是否包含s3的子串可以使用前缀后缀和:对于当前的[l,r],如果不合法,相当于在r前面出现过的右端点加上l后面出现过的左端点大于总数目。
Hash的过程可以直接使用自然溢出。
时间复杂度 $O(nlog n)$
#include <cstdio> #include <cstring> #include <algorithm> #define N 50010 #define M 30000000 using namespace std; typedef unsigned long long ull; ull base[N]; int n[3] , next[N] , sa[2][N] , sb[2][N]; char s[3][N]; struct data { int head[M] , next[N] , tot; ull v[N]; data() {tot = 0;} inline void insert(ull x) { if(!head[x % M]) head[x % M] = ++tot; else { int i; for(i = head[x % M] ; next[i] ; i = next[i]); next[i] = ++tot; } v[tot] = x; } inline bool count(ull x) { int i; for(i = head[x % M] ; i ; i = next[i]) if(v[i] == x) return 1; return 0; } inline void clear() { int i; for(i = 1 ; i <= tot ; i ++ ) v[i] = next[i] = head[v[i] % M] = 0; tot = 0; } }mp; void kmp(int p) { int i , j; for(i = j = 0 ; i < n[p] ; i ++ ) { base[i + 1] = base[i] * 233; while(~j && s[p][i] != s[2][j]) j = next[j]; if(++j == n[2]) sa[p][i - j + 1] ++ , sb[p][i] ++ , j = next[j]; } for(i = n[p] - 2 ; ~i ; i -- ) sa[p][i] += sa[p][i + 1]; for(i = 1 ; i < n[p] ; i ++ ) sb[p][i] += sb[p][i - 1]; } bool judge(int mid) { int i; ull v = 0; mp.clear(); for(i = 0 ; i < mid - 1 ; i ++ ) v = v * 233 + s[0][i]; for(i = mid - 1 ; i < n[0] ; i ++ ) { v = v * 233 + s[0][i]; if(sa[0][i - mid + 1] + sb[0][i] <= sa[0][0]) mp.insert(v); v -= s[0][i - mid + 1] * base[mid - 1]; } v = 0; for(i = 0 ; i < mid - 1 ; i ++ ) v = v * 233 + s[1][i]; for(i = mid - 1 ; i < n[1] ; i ++ ) { v = v * 233 + s[1][i]; if(sa[1][i - mid + 1] + sb[1][i] <= sa[1][0] && mp.count(v)) return 1; v -= s[1][i - mid + 1] * base[mid - 1]; } return 0; } int main() { int i , j , l , r , mid , ans = 0; for(i = 0 ; i < 3 ; i ++ ) scanf("%s" , s[i]) , n[i] = strlen(s[i]); next[0] = -1; for(i = 1 , j = -1 ; i <= n[2] ; i ++ ) { while(~j && s[2][j] != s[2][i - 1]) j = next[j]; next[i] = ++j; } base[0] = 1 , kmp(0) , kmp(1); l = 1 , r = min(n[0] , n[1]); while(l <= r) { mid = (l + r) >> 1; if(judge(mid)) ans = mid , l = mid + 1; else r = mid - 1; } printf("%d " , ans); return 0; }