zoukankan      html  css  js  c++  java
  • 「字符串算法」第3章 KMP 算法课堂过关

    「字符串算法」第3章 KMP 算法课堂过关

    关于KMP

    upd on 2021/4/1 优化一些细节

    声明:本文的字符串下标均从1开始,对于某个字符串a,a.substr(i,j)表示a从第i位开始,长度为j的字串

    模板题

    传送门

    KMP算法的大致原理

    个人认为其他博客已经讲得很好,这里简单讲,把重点放在next数组上

    先推几篇博客:

    首先,我们把模板题中的(s_1)串称为文本串,重命名为(s)(s_2)称为模式串,重命名为(t)(本文中不区分s与t的大小写)

    (n)(s)的长度,(m)(t)的长度(会在代码片中出现)

    看图,在第一轮匹配中,匹配到了一个不相等的位置,如果用暴力,那就是从头再匹配,但是可以看到(t)串中有一段重复的“ABC”,无需重复匹配,所以第二轮直接跳到如图所示的位置比较两个蓝色的部分

    这就是KMP算法的大致思路

    next数组

    定义

    看了KMP的大致原理,相信大家都产生了疑问:我怎么知道要让T串跳到哪个位置呢?这就要用到next数组了,这是KMP的核心,也是难点

    先不用管怎么求next数组,看定义(我自己写的):

    (j=next_i),则有(j<i)(t.substr(1,j)==t.substr(i-j+1,i)),且对于任意(k(j<k<i)),(t.substr(1,k)≠t.substr(i-k+1,k))

    也就是说,next[i]表示“T中以i结尾的非前缀字串”与“T的前缀”能匹配的最长长度,当不存在这样的j时,next[i]=0

    举个例子:

    若T="ABCDABCE",则对应的next={0 0 0 0 1 2 3 0}

    应用

    根据next数组的定义,next中存储的是长度,但是由于它是T的某个前缀字串的长度,我们也可以将next当做下标使用(一定要弄清楚,不然后面很蒙)

    仍然用上面的图片真懒呐

    设S的指针为i,T的指针为j,表示当前完成匹配的位置(也就是说S[i]和T[j]是相等的)

    第一轮匹配中,当(j==7)时,我们发现(t)的下一位和(s)的下一位不等,但是(t)的第57位和13位是一样的,即next[7]=3,所以我们需要将(t)的指针(j)跳到第3位,也就是j=next[j],这里有一些细节不是很好理解,KMP在实现时是很巧妙的,我们放到整段代码理解:

    		while(j != 0 && s[i] != t[j+1])
    			j = next[j];
    		if(s[i] == t[j+1])
    			j++;
    		if(j == m){//j==m标志着已经全部完成匹配
    			printf("%d
    ",i - m + 1);
    			j = next[j];
    		}
    

    求法

    这里是整个KMP最难理解的部分,所以放到最后

    先贴出代码

    	next[1] = 0;//初始化
    	for(int i = 2 , j = 0 ; i <= m ; i++){
    		while(j != 0 && t[j+1] != t[i])
    			j=next[j];//全算法最confusing的语句
    		if(t[j+1] == t[i])
    			j++;
    		next[i] = j;
    	}
    

    考虑暴力枚举:最外层循环枚举每一位(i),第二层枚举next[i],里层判断第二层枚举的是否合法

    显然,时间复杂度是在(O(n^2)~O(n^3)),还不如(O(ncdot m))的暴力匹配

    优化求法:

    先提前声明:求next[i]是要用到next[1~i-1]的,所以我们要从前向后顺序枚举i

    定义“候选项”的概念(可能跟《算法竞赛……》的不大一样):如果j满足 t.substr(1,j)==t.substr(i-1-j-1,j)&&j<i-1则j是next[i]的一个候选项

    例子:

    绿色表示相等的两个字串,则j是next[i]的一个候选项,若标成蓝色的两个字符相等,则候选项j是合法的,next[i]就是所有合法的(j)中的最大值+1

    很显然,对于next[i]而言,next[i-1]是它的候选项,但是,问题是next[next[i-1]],next[next[next[i-1]]],......都是候选项,为什么呢?还是看图:

    假设next[13]=5,根据(next)的定义,标绿色部分是相等的,再细化一下绿色部分中相等的部分:假设next[5]=2,同理,第二行(不计最上面的下标行)的黄色部分相等,又因为绿色部分相等,我们可以得到第三行的黄色部分都是相等的,再简化为第4行,会发现:这不是和第一行一样了吗(只是长度小了)!

    以此类推,可以得到next[i-1],next[next[i-1]],next[next[next[i-1]]],......都是候选项,且他们的值是从左向右递减的,因此,按照这个顺序找到第一个合法的候选值之后,我们就可以确定next[i]

    重新看一下代码:

    	next[1] = 0;
    	for(int i = 2 , j = 0 ; i <= m ; i++){
    		while(j != 0 && t[j+1] != t[i])//找到第一个合法的候选项
    			j=next[j];//缩小长度
    		if(t[j+1] == t[i])
    			j++;
    		next[i] = j;
    	}
    

    发现,每一轮循环没有j=next[i-1]的语句。原因很简单:上一轮结束时语句next[i]=j决定了这一轮刚开始就有j==next[i-1],注意这里的前后的(i)不一样(都不是同一轮循环了)不要学傻了

    时间复杂度

    上结论:(O(n+m))

    (next)数组的求值为例:

    	next[1] = 0;
    	for(int i = 2 , j = 0 ; i <= m ; i++){
    		while(j != 0 && t[j+1] != t[i])
    			j=next[j];
    		if(t[j+1] == t[i])
    			j++;
    		next[i] = j;
    	}
    

    最外层显然是(O(m))的,问题是里面

    while循环中,(j)是递减的,但是又不会变成负数,所以整个过程中,(j)的减小幅度不会超过(j)增加的幅度,而(j)每次才增加1,最多增加(m)次,故(j)的总变化次数不超过(2m),整个时间复杂度近似认为是(O(m))

    如果还不能理解,就想像一个平面直角坐标系,(x)轴为(i)(y)轴为(j),从原点出发,(i)每向右一个单位,(j)最多向上一个单位,(j)也可以往下掉(while循环),但不能掉到第四象限,(j)向下掉的高度之和就是while内语句执行的总次数,是绝对不会超过(m)

    匹配的循环与上述相近,时间为(O(n+m)),不再赘述

    所以,总的时间复杂度为(O(n+m))

    模板题代码

    不要问模板题输出的最后一行是什么意思,我也不知道,反正输出(next)数组就对了

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #define nn 1000010
    using namespace std;
    int sread(char s[]) {
    	int siz = 1;
    	do
    		s[siz] = getchar();
    	while(s[siz] < 'A' || s[siz] > 'Z');
    	while(s[siz] >= 'A' && s[siz] <= 'Z') {
    		++siz;
    		s[siz] = getchar();
    	}
    	--siz;
    	return siz;
    }
    char s[nn];
    char t[nn];
    int next[nn];
    int n , m;
    int main() {
    	n = sread(s);
    	m = sread(t);
    	next[1] = 0;
    	for(int i = 2 , j = 0 ; i <= m ; i++){
    		while(j != 0 && t[j+1] != t[i])
    			j=next[j];
    		if(t[j+1] == t[i])
    			j++;
    		next[i] = j;
    	}
    	for(int i = 1 , j = 0 ; i <=n ; i++){
    		while(j != 0 && s[i] != t[j+1])
    			j = next[j];
    		if(s[i] == t[j+1])
    			j++;
    		if(j == m){
    			printf("%d
    ",i - m + 1);
    			j = next[j];
    		}
    	}
    	for(int i = 1 ; i <= m ; i++)
    		printf("%d " , next[i]);
    	return 0;
    }
    

    A. 【例题1】子串查找

    题目

    代码

    #include <iostream>
    #include <cstdio>
    #define nn 1000010
    using namespace std;
    int ssiz , tsiz;
    int sread(char *s) {
    	int siz = 0;
    	char c = getchar();
    	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')))
    		c = getchar();
    		
    	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
    		s[++siz] = c , c = getchar();
    	return siz;
    }
    char s[nn] , t[nn];
    int nxt[nn];
    int main() {
    	ssiz = sread(s);
    	tsiz = sread(t);
    	
    	nxt[1] = 0;
    	for(int i = 2 ; i <= tsiz ; i++) {
    		int j = nxt[i - 1];
    		while(j != 0 && t[j + 1] != t[i])
    			j = nxt[j];
    		nxt[i] = j + 1;
    	}
    	
    	int ans = 0;
    	for(int i = 1 , j = 0 ; i <= ssiz ; i++) {
    		while(j != 0 && s[i] != t[j + 1])
    			j = nxt[j];
    		if(s[i] == t[j + 1])
    			++j;
    		if(j == tsiz)
    			j = nxt[j] , ++ans;
    	}
    	cout << ans;
    	return 0;
    }
    

    B. 【例题2】重复子串

    题目

    题目有误

    输入若干行,每行有一个字符串,字符串仅含英文字母。特别的,字符串可能为.即一个半角句号,此时输入结束。

    第五组数据的字符串包含数字字符,有图为证:

    思路

    设字符串长度为(siz)

    Hash

    关于字符串Hash

    不难想也很好写的一种方法

    直接枚举最小周期长度(i),显然,(siz)一定是(i)的倍数,所以,这只需要(O(sqrt n))的时间复杂度

    假设我们已经枚举到(p)的因数(x),就可以直接用(O(frac{siz}{x}))的时间复杂度验证该子字符串是否是周期,代码如下:

    inline bool check(int x) {
    	ul key = hs[x];
    	for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
    		if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)//获取字符串s从下标i开始,长度为x的子串的Hash值 , 判断和key是否相等
    			return false;
    	return true;
    }
    

    KMP

    下面讲好些不太好想的KMP做法

    先上结论:
    命名输入进来的字符串为(S),预处理得到(S)(nxt)数组
    (siz\%(siz-nxt_{siz})==0),则(siz-nxt_{siz})(S)的最小周期,也就是说,此时答案为(siz / (siz - nxt_{siz}))
    否则,答案为"1"

    献上图解:

    代码

    Hash

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #define nn 1000010
    #define ul unsigned long long
    using namespace std;
    #define value(_) (_ >= 'A' && _ <= 'Z' ? (1 + _ - 'A') : (_ >= 'a' && _ <= 'z' ? (27 + _ - 'a') : (_ - '0' + 53) ))
    const ul p = 131;
    
    ul hs[nn];
    ul pw[nn];
    int siz;
    char c[nn];
    
    int sread(char *s) {
    	int siz = 0;
    	char c = getchar();
    	if(c == '.')return -1;
    	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
    		if((c = getchar()) == '.')	return -1;
    		
    	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
    		s[++siz] = c , c = getchar();
    	return siz;
    }
    inline bool check(int x) {
    	ul key = hs[x];
    	for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
    		if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)
    			return false;
    	return true;
    }
    int main() {
    	pw[0] = 1;
    	for(int i = 1 ; i <= nn - 5 ; i++)
    		pw[i] = pw[i - 1] * p;
    	while((siz = sread(c)) != -1) {
    		for(int i = 1 ; i <= siz ; i++)
    			hs[i] = hs[i - 1] * p + value(c[i]);
    		
    		int ans = 0;
    		for(int i = 1 ; i * i <= siz ; i++) {
    			if(siz % i == 0)
    				if(check(i)) {
    					ans = i;
    					break;
    				}
    				else {
    					if(check(siz / i))
    						ans = siz / i;
    				}
    		}
    		printf("%d
    " , siz / ans);
    		memset(c, 0  , sizeof(c));
    		memset(hs , 0 , sizeof(hs));
    	}
    	return 0;
    }
    

    KMP

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #define nn 1000010
    using namespace std;
    int siz;
    int sread(char *s) {
    	int siz = 0;
    	char c = getchar();
    	if(c == '.')return -1;
    	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
    		if((c = getchar()) == '.')	return -1;
    		
    	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
    		s[++siz] = c , c = getchar();
    	return siz;
    }
    char s[nn];
    int nxt[nn];
    int main() {
    	while(true) {
    		memset(s , 0 , sizeof(s));
    		memset(nxt , 0 , sizeof(nxt));
    		siz = sread(s);
    		if(siz == -1)	break;
    		nxt[1] = 0;
    		for(int i = 2 , j = 0 ; i <= siz ; i++) {
    			while(s[i] != s[j + 1] && j != 0)
    				j = nxt[j];
    			if(s[i] == s[j + 1])
    				++j;
    			nxt[i] = j;
    		}
    		if(siz % (siz - nxt[siz]) == 0)
    			printf("%d
    " , siz / (siz - nxt[siz]));
    		else
    			printf("1
    ");
    	}
    	return 0;
    }
    

    C. 【例题3】周期长度和

    题目

    传送门

    思路&代码

    以前写过,传送门

    题目

    传送门

    这题意不是一般人能读懂的,为了读懂题目,我还特意去翻了题解[手动笑哭]

    题目大意:

    给定一个字符串s

    对于(s)的每一个前缀子串(s1),规定一个字符串(Q),(Q)满足:(Q)(s1)的前缀子串且(Q)不等于(s1)(s1)是字符串(Q+Q)的前缀.设(siz)为所有满足条件的(Q)(Q)的最大长度(注意这里仅仅针对(s1)而不是(s),即一个(siz)的值对应一个(s1))

    求出所有(siz)的和

    不要被这句话误导了:

    求给定字符串所有前缀的最大周期长度之和

    正确断句:求给定字符串 所有/前缀的最大周期长度/之和

    我就想了半天:既然是"最大周期长度",那不是唯一的吗?为什么还要求和呢?

    思路

    其实这题要AC并不难(看通过率就知道)

    看图

    要满足(Q)(s1)的前缀,则(Q)(1)~(5)位和(s1)的1~5位是一样的,又因为(s1)(Q+Q)的前缀,所以又要满足(s1)的6~8位和(Q+Q)的6~8位一样,即(s1)的6~8位和Q的1~3位相等,回到(s1),标蓝色的两个位置相等.

    回顾下KMP中(next)数组的定义:next[i]表示对于某个字符串a,"a中长度为next[i]的前缀子串"与"a中以第i为结尾,长度为next[i]的非前缀子串"相等,且next[i]取最大值

    是不是悟到了什么,是不是感觉这题和(next)数组冥冥之中有某种相似之处?

    但是,这仅仅只是开始

    按照题目的意思,我们要让(Q)的长度最大,也就是图中蓝色部分长度最小,但是(next)中存的是蓝色部分的最大值,显然,两者相违背,难道我们要改造(next)数组吗?明显不行,若(next)存储的改为最小值,则原来求(next)的方法行不通.考虑换一种思路(一定要对KMP中(next)的求法理解透彻,不然下面看不懂,不行的复习一下),我们知道对于next[i],next[next[i-1]],next[next[next[i]]]...都能满足"前缀等于以(i)结尾的子串"这个条件,且越往后,值越小,所以,我们的目标就定在上面序列中从后往前第一个不为0的(next)

    极端条件下,暴力跑可以去到(O(n^2)),理论上会超时(我没试过)

    两种优化:

    1. 记忆化,时间效率应该是O(n)这里不详细讲,可以去到洛谷题解查看
    2. 倍增(我第一时间想到并AC的做法):
      我们将j=next[j]这一语句称作"j跳了一次"(感觉怪怪的),将next拓展为2维,next[i][k]表示结尾为i,j跳了2^k的前缀字符长度(也就是next[i][0]等价于原来的next[i])
      借助倍增LCA的思想(没学没关系,现学现用),这里不做赘述,上代码
    		int tmp = i;
    		for(rr int j = siz[i] ; j >= 0 ; --j)//siz[i]是next[i][j]中第一个为0的小标j,注意倒序枚举
    			if(next[tmp][j] != 0)//如果不为0则跳
    				tmp = next[tmp][j];
    

    倍增方法在字符串长度去到(10^6)时是非常危险的,带个(log)理论是(2cdot 10^7)左右,常数再大那么一丢丢就TLE了,还好数据比较水,但是作为倍增和KMP的练习做一下也是不错的

    最后,记得开longlong(不然我就一次AC了)

    完整代码

    #include <iostream>
    #include <cmath>
    #include <cstdio>
    #define nn 1000010
    #define rr register
    #define ll long long
    using namespace std;
    int next[nn][30] ;
    int siz[nn];
    char s[nn];
    int n;
    int main() {
    //	freopen("P3435_3.in" , "r" , stdin);
    	cin >> n;
    	do
    		s[1] = getchar();
    	while(s[1] < 'a' || s[1] > 'z');
    	for(rr int i = 2 ; i <= n ; i++)
    		s[i] = getchar();
    	
    	next[1][0] = 0;
    	for(rr int i = 2 , j = 0 ; i <= n ; i++) {
    		while(j != 0 && s[i] != s[j + 1])
    			j = next[j][0];
    		if(s[j + 1] == s[i])
    			++j;
    		next[i][0] = j;
    	}
    	
    	rr int k = log(n) / log(2) + 1;
    	for(rr int j = 1 ; j <= k ; j++)
    		for(rr int i = 1 ; i <= n ; i++) {
    			next[i][j] = next[next[i][j - 1]][j - 1];
    			if(next[i][j] == 0)
    				siz[i] = j;
    		}
    	ll ans = 0;
    	for(rr int i = 1 ; i <= n ; ++i) {
    		int tmp = i;
    		for(rr int j = siz[i] ; j >= 0 ; --j)
    			if(next[tmp][j] != 0)
    				tmp = next[tmp][j];
    		if(2 * (i - tmp) >= i && tmp != i)
    			ans += (ll)i - tmp;
    	}
    	cout << ans;
    	return 0;
    } 
    

    D. 【例题4】子串拆分

    题目

    思路

    说明,以下思路时间大致复杂度为(O(n^2 )),最坏可以去到(O(n^3)),但数据较水可以通过,看了书,上面的解法也是(O(n^2)),对于(1leq |S|leq 1.5×10^4)来说已经是很极限了

    其实思路很简单,我们直接枚举子串的左右边界(L,R),在右边界扩张的同时把新加入的字符的(nxt)求出来.至此,我们得到了子串(c),和(c)(nxt)数组,时间复杂度为(O(n^2))
    那么我们如何判断(c)是否符合(c=A+B+C(kle len(A),1le len(B) ))呢?看代码(其实做了B,C题这里很好理解)

    			int p = nxt[m];//m为c数组的长度,p即是可能的A的长度
    			while(p >= k && p > 0) {
    				if(m - p - p >= 1) {
    					++ans;
    					break;//直接退出,优化
    				}
    				p = nxt[p];
    			}
    

    这个判断的复杂度是可以达到(O(n))的,在数据范围下十分危险

    下面看下书中是怎么说的:

    我还以为书里有严格(O(n^2))的做法

    下面(p_i)(nxt_i)意义相同

    考虑没枚举左端点,假设左端点为(l),(A=S[l,|S|]),那么对字符串(A)跑一次KMP,在匹配的过程中,设匹配到第(i)个位置,那么我们就要考虑当前得出的(j),显然(A[1,j]=A[i-j+1,i]).如果(ile 2cdot j),那么令(j=p_j),此时(A[i,j]=A[i-j+1,i]),(j)沿指针(p)不断回跳,直到(2cdot j<i).然后判断(j)是否大于(k),如果是,那么累加答案.

    因为每次KMP的复杂度是(O(n)),所以总时间复杂度为(O(n^2))

    核心代码

    //每次KMP匹配 
    inline void solve(char *a) {
    	p[1] = 0;
    	int n = strlen(a + 1);
    	for(int i = 1 , j = 0 ; i < n ; i++) {
    		while(j && a[j + 1] != a[i + 1])
    			j = p[j];
    		if(a[j + 1] == a[i + 1])
    			++j;
    		p[i + 1] = j;
    	}
    	for(int i = 1 , j = 0 ; i < n ; i++) {
    		while(j && a[j + 1] != a[i + 1])
    			j = p[j];
    		if(a[j + 1] == a[i + 1])
    			j++;
    		while(j * 2 >= i + 1)
    			j = p[j];
    		if(j >= k)
    			++ans;
    	}
    }
    
    //枚举左端点 
    int len = strlen(str + 1) - (k << 1);
    for(int i = 0 ; i < len ; i++)
    	solve(str + i);
    

    代码

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #define nn 15000
    using namespace std;
    int sread(char *s) {
    	int siz = 0;
    	char c = getchar();
    	if(c == '.')return -1;
    	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
    		if((c = getchar()) == '.')	return -1;
    		
    	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
    		s[++siz] = c , c = getchar();
    	return siz;
    }
    int n , k , m;
    int ans;
    int nxt[nn];
    char c[nn] , s[nn];
    
    int main() {
    	n = sread(s);
    	cin >> k;
    	for(int L = 1 ; L <= n ; L++) {
    		memset(nxt , 0 , sizeof(nxt));
    		memset(c , 0 , sizeof(c));
    		for(int i = L ; i <= L + k + k ; i++)
    			c[i - L + 1] = s[i];
    		m = k + k;
    		
    		nxt[1] = 0;
    		for(int i = 2 ; i <= m ; i++) {
    			int j = nxt[i - 1];
    			while(c[j + 1] != c[i] && j != 0)
    				j = nxt[j];
    			if(c[j + 1] == c[i])	++j;
    			nxt[i] = j;
    		}
    		
    		for(int R = L + k + k; R <= n ; R++) {
    			m = R - L + 1;
    			c[m] = s[R];
    			
    			int j = nxt[m - 1];
    			while(c[j + 1] != c[m] && j != 0)
    				j = nxt[j];
    			if(c[j + 1] == c[m])	++j;
    			if(m != 1)	nxt[m] = j;
    					
    			int p = nxt[m];
    			while(p >= k && p > 0) {
    				if(m - p - p >= 1) {
    					++ans;
    					break;
    				}
    				p = nxt[p];
    			}
    			
    		}
    	}
    	cout << ans;
    	return 0;
    }
    
  • 相关阅读:
    WINCE6.0+S3C6410睡眠和唤醒的实现
    WINCE6.0+S3C6410的触摸屏驱动
    S3C6410的Bootloader的两个阶段BL1和BL2编译相关学习
    amix vim vimrc 3.6 [_vimrc x64 vim (WorkPlace)]配置
    异常的开销
    A C# Reading List by Eric Lippert (ZZ)
    SQL SERVER 2008中定时备份数据库任务的创建与删除
    ASP.NET26个常用性能优化方法
    如何使用四个语句来提高 SQL Server 的伸缩性
    Cookies揭秘 [Asp.Net, Javascript]
  • 原文地址:https://www.cnblogs.com/dream1024/p/14612598.html
Copyright © 2011-2022 走看看