zoukankan      html  css  js  c++  java
  • AC自动机

    对于多串匹配一种能够理论上时间复杂度为O(n+m)的多串匹配方式,但是时间复杂度并不稳定。

    原理:在trie树上建立类似KMP的next指针的东西,也就是AC自动机的fail指针,在每次匹配的时候,不停的跳fail指针直到根节点。这是最裸的实现,但是许多情况下,这种最朴素的实现方式过不去...因为这样它的时间复杂度很不稳定,可能被卡到O(nm)这种模式下,我们就需要将AC自动机的BFS序拿出来,单独进行操作,甚至将fail指针建成fail树实现,还有什么打上差分标记之类的实现方式降低时间复杂度。

    例题:

    BZOJ3940: [Usaco2015 Feb]Censoring

    分析:

    AC自动机裸题,同年同月银组题是KMP+栈实现,而这道题是AC自动机+栈实现,因为题目满足模式串互不包含,所以直接贪心+栈模拟维护一下即可。

    附上代码:

    #include <cstdio>
    #include <algorithm>
    #include <cmath>
    #include <cstdlib>
    #include <cstring>
    #include <queue>
    #include <iostream>
    using namespace std;
    #define N 1000005
    struct Aho
    {
    	int ch[N][26],last[N],fail[N],cnt,rot;
    	int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;}
    	void init(){cnt=0;rot=new_node();}
    	void insert(char *s,int x)
    	{
    		int len=strlen(s),rt=0;
    		for(int i=0;i<len;i++)
    		{
    			if(ch[rt][s[i]-'a']==-1)ch[rt][s[i]-'a']=new_node();
    			rt=ch[rt][s[i]-'a'];
    		}
    		last[rt]=x;
    	}
    	void get_fail()
    	{
    		queue <int>q;fail[0]=0;
    		for(int i=0;i<26;i++)
    		{
    			if(ch[0][i]==-1)ch[0][i]=0;
    			else fail[ch[0][i]]=0,q.push(ch[0][i]);
    		}
    		while(!q.empty())
    		{
    			int x=q.front();q.pop();
    			for(int i=0;i<26;i++)
    			{
    				if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i];
    				else fail[ch[x][i]]=ch[fail[x]][i],q.push(ch[x][i]);
    			}
    		}
    	}
    }ac;
    int n,m,len[N],sta[N],top,pos[N];
    char sub[N],str[N];
    int main()
    {
    	scanf("%s%d",str+1,&n);m=strlen(str+1);ac.init();
    	for(int i=1;i<=n;i++)scanf("%s",sub),ac.insert(sub,i),len[i]=strlen(sub);
    	ac.get_fail();int rt=0;
    	for(int i=1;i<=m;i++)
    	{
    		rt=ac.ch[rt][str[i]-'a'];pos[++top]=rt,sta[top]=str[i];
    		if(ac.last[rt])top-=len[ac.last[rt]],rt=pos[top];
    	}
    	for(int i=1;i<=top;i++)printf("%c",sta[i]);puts("");
    	return 0;
    }
    

    BZOJ3172: [Tjoi2013]单词

    分析:

    这道题,最裸的AC自动机实现方式过不了...不用测试了,我试过了...

    将AC自动机建起来,在每个节点打一个标记,每次插入的时候遍历到这个节点,这个节点的出现次数就++,之后每次讲节点的出现次数传递给fail节点,最后统计每一个串的终止节点出现次数即可。

    附上代码:

    #include <cstdio>
    #include <algorithm>
    #include <queue>
    #include <cstring>
    #include <cstdlib>
    #include <cmath>
    #include <iostream>
    #include <set>
    using namespace std;
    #define N 1000005
    int ans[N];
    struct Aho
    {
        int ch[N][26],pos[N],fail[N],vis[N],cnt,rot,que[N],fa[N],l,r;
        int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));vis[cnt++]=0;return cnt-1;}
        void init(){l=r=cnt=0,rot=new_node();}
        void insert(char *s,int x)
        {
            int len=strlen(s),rt=rot;
            for(int i=0;i<len;i++)
            {
                int t=s[i]-'a';
                if(ch[rt][t]==-1)ch[rt][t]=new_node();
                rt=ch[rt][t];vis[rt]++;
            }
            pos[x]=rt;
        }
        void get_fail()
        {
            for(int i=0;i<26;i++)
            {
                if(ch[0][i]==-1)ch[0][i]=0;
                else que[r++]=ch[0][i],fail[ch[0][i]]=0;
            }
            while(l<r)
            {
                int x=que[l++];
                for(int i=0;i<26;i++)
                {
                    if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i];
                    else fail[ch[x][i]]=ch[fail[x]][i],que[r++]=ch[x][i];
                }
            }
        }
        int match()
        {
            for(int i=cnt;i>=0;i--)
            {
                int x=que[i];
                vis[fail[x]]+=vis[x];
            }
        }
    }ac;
    char s[N];
    int main()
    {
        int n;
        scanf("%d",&n);ac.init();
        for(int i=1;i<=n;i++)
        {
            scanf("%s",s);
            ac.insert(s,i);
        }
        ac.get_fail();ac.match();
        for(int i=1;i<=n;i++)
        {
            //printf("%d
    ",ac.pos[i]);
            printf("%d
    ",ac.vis[ac.pos[i]]);
        }
    }
    

    BZOJ1030: [JSOI2007]文本生成器

    分析:

    这道题看起来和GT考试很像,用全部的生成数量去掉完全不可读的数量,就是答案。全部的数量是26^m,而完全不可读的是类似GT考试的求法,建立AC自动机,之后类似遍历AC自动机的方式进行DP,顺便注意一点,即:如果跳fail指针能跳到某一个字符串的终止节点,那么这个节点就不能作为答案出现,即:我们统计所有的不存在一个终止节点作为祖先的节点。

    附上代码:

    #include <cstdio>
    #include <cmath>
    #include <algorithm>
    #include <iostream>
    #include <queue>
    #include <cstdlib>
    #include <cstring>
    using namespace std;
    #define N 200005
    #define mod 10007
    int f[120][N],n,m;char s[N];
    struct Aho
    {
    	int ch[N][26],fail[N],last[N],cnt,q[N],l,r,rot;
    	int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;}
    	void init(){cnt=0;rot=new_node();}
    	void insert(char *s)
    	{
    		int len=strlen(s),rt=rot;
    		for(int i=0;i<len;i++)
    		{
    			int t=s[i]-'A';
    			if(ch[rt][t]==-1)ch[rt][t]=new_node();
    			rt=ch[rt][t];
    		}
    		last[rt]=1;
    	}
    	void get_fail()
    	{
    		l=r=0;
    		for(int i=0;i<26;i++)
    		{
    			if(ch[0][i]==-1)ch[0][i]=rot;
    			else q[r++]=ch[0][i],fail[ch[0][i]]=rot;
    			last[0]|=last[fail[0]];
    		}
    		while(l<r)
    		{
    			int x=q[l++];
    			for(int i=0;i<26;i++)
    			{
    				if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i];
    				else q[r++]=ch[x][i],fail[ch[x][i]]=ch[fail[x]][i];
    			}
    			last[x]|=last[fail[x]];
    		}
    	}
    	int match()
    	{
    		f[0][0]=1;
    		for(int i=1;i<=m;i++)
    			for(int j=0;j<cnt;j++)
    				if(!last[j]&&f[i-1][j])
    					for(int k=0;k<26;k++)
    						f[i][ch[j][k]]=(f[i][ch[j][k]]+f[i-1][j])%mod;
    		int num=1,ans=0;
    		for(int i=1;i<=m;i++)num=num*26%mod;
    		for(int i=0;i<cnt;i++)
    		{
    			if(!last[i])ans=(ans+f[m][i])%mod;
    		}
    		return (num-ans+mod)%mod;
    	}
    }ac;
    int main()
    {
    	scanf("%d%d",&n,&m);ac.init();
    	for(int i=1;i<=n;i++)
    	{
    		scanf("%s",s);ac.insert(s);
    	}
    	ac.get_fail();
    	printf("%d
    ",ac.match());
    	return 0;
    }
    

    BZOJ2553: [BeiJing2011]禁忌

    分析:

    看到len<=10^9就知道和矩阵乘法有关。建立矩阵,如果一个节点是终止节点(或者它的祖先存在终止节点),那么将矩阵i,0变成抽到对应字符的概率,和i,cnt也变成对应概率,如果不是终止节点,那么将i和对应子节点的矩阵改成概率即可。之后矩阵乘法实现一下即可。

    附上代码:

    #include <cstdio>
    #include <cmath>
    #include <algorithm>
    #include <iostream>
    #include <queue>
    #include <cstdlib>
    #include <cstring>
    using namespace std;
    #define N 2005
    #define mod 10007
    int n,m,alpha,cnt;char s[N];
    struct node
    {
        long double a[100][100];
        friend node operator*(const node &a,const node &b)
        {
            node c;memset(c.a,0,sizeof(c.a));
            for(int i=0;i<=cnt;i++)
            {
                for(int j=0;j<=cnt;j++)
                {
                    for(int k=0;k<=cnt;k++)
                    {
                        c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j]);
                    }
                }
            }
            return c;
        }
        void print()
        {
            for(int i=0;i<=cnt;i++)
            {
                for(int j=0;j<=cnt;j++)
                {
                    printf("%.5lf ",a[i][j]);
                }
                puts("");
            }
        }
    }ret,map;
    double q_pow(int n)
    {
        for(int i=0;i<=cnt;i++)ret.a[i][i]=1;
        while(n)
        {
            if(n&1)ret=ret*map;
            map=map*map;n=n>>1;
        }
        return ret.a[0][cnt];
    }
    struct Aho
    {
        int ch[N][26],fail[N],last[N],q[N],l,r,rot;
        int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;}
        void init(){cnt=0;rot=new_node();}
        void insert(char *s)
        {
            int len=strlen(s),rt=rot;
            for(int i=0;i<len;i++)
            {
                int t=s[i]-'a';
                if(ch[rt][t]==-1)ch[rt][t]=new_node();
                rt=ch[rt][t];
            }
            //printf("%d",rt);
            last[rt]=1;
        }
        void get_fail()
        {
            l=r=0;
            for(int i=0;i<alpha;i++)
            {
                if(ch[0][i]==-1)ch[0][i]=rot;
                else q[r++]=ch[0][i],fail[ch[0][i]]=rot;
                last[0]|=last[fail[0]];
            }
            while(l<r)
            {
                int x=q[l++];
                for(int i=0;i<alpha;i++)
                {
                    if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i];
                    else q[r++]=ch[x][i],fail[ch[x][i]]=ch[fail[x]][i];
                }
                last[x]|=last[fail[x]];
            }
        }
        void match()
        {
            long double addv=1.0/(1.0*alpha);
            for(int i=0;i<cnt;i++)
            {
                for(int j=0;j<alpha;j++)
                {
                    if(last[ch[i][j]])map.a[i][0]+=addv,map.a[i][cnt]+=addv;
                    else map.a[i][ch[i][j]]+=addv;
                }
            }
            map.a[cnt][cnt]=1;
        }
    }ac;
    int main()
    {
        scanf("%d%d%d",&n,&m,&alpha);ac.init();
        for(int i=1;i<=n;i++)
        {
            scanf("%s",s);ac.insert(s);
        }
        ac.get_fail();ac.match();//map.print();
        printf("%.10f
    ",q_pow(m));
        return 0;
    }

    BZOJ2938: [Poi2000]病毒

    分析:

    如果在trie树上存在一个环,那么就一定可以出现无穷的情况。判一下环是否存在即可。

    附上代码:

    #include <cstdio>
    #include <cmath>
    #include <algorithm>
    #include <iostream>
    #include <queue>
    #include <cstdlib>
    #include <cstring>
    using namespace std;
    #define N 30005
    struct Aho
    {
        int ch[N][2],fail[N],last[N],cnt,rot,vis[N],inq[N];
        int new_node(){ch[cnt][1]=ch[cnt][0]=-1;last[cnt++]=0;return cnt-1;}
        void init(){cnt=0;rot=new_node();}
        void insert(char *s)
        {
            int len=strlen(s),rt=rot;
            for(int i=0;i<len;i++)
            {
                int t=s[i]-'0';
                if(ch[rt][t]==-1)ch[rt][t]=new_node();
                rt=ch[rt][t];
            }
            last[rt]=1;
        }
        void get_fail()
        {
            queue <int>q;
            for(int i=0;i<2;i++)
            {
                if(ch[0][i]==-1)ch[0][i]=0;
                else fail[ch[0][i]]=0,q.push(ch[0][i]);
            }
            while(!q.empty())
            {
                int x=q.front();q.pop();
                for(int i=0;i<2;i++)
                {
                    if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i];
                    else fail[ch[x][i]]=ch[fail[x]][i],q.push(ch[x][i]);
                }
                last[x]|=last[fail[x]];
            }
        }
        int match(int x)
        {
            inq[x]=vis[x]=1;
            for(int i=0;i<2;i++)
            {
                int t=ch[x][i];
                if(inq[t]||((!last[t])&&(!vis[t])&&match(t)))return 1;
            }
            inq[x]=0;
            return 0;
        }
    }ac;
    char s[N];
    int main()
    {
        int n;scanf("%d",&n);ac.init();
        for(int i=1;i<=n;i++)
        {
            scanf("%s",s);ac.insert(s);
        }
        ac.get_fail();
        if(ac.match(ac.rot))puts("TAK");
        else puts("NIE");
        return 0;
    }
    

      

    BZOJ2434: [Noi2011]阿狸的打字机

    分析:

    先将给你的串建立成AC自动机,之后单独拎出fail树,求出每个节点在fail树上的入栈出栈序维护出来,之后再遍历一遍所有节点,将每个节点对应的询问求出即可。而正确性是因为fail树的每个节点的父节点的串都是这个节点的串的后缀。

    附上代码:

    #include <cstdio>
    #include <cmath>
    #include <algorithm>
    #include <iostream>
    #include <queue>
    #include <cstdlib>
    #include <cstring>
    using namespace std;
    #define N 100005
    struct node
    {
        int to,next,val;
    }ask[N],e[N];
    int head[N],head_ask[N],cnt2,cnt1,in1[N],tims,out1[N],sum[N<<1],Q,flg[N],ans[N];char s[N];
    void add(int x,int y){e[cnt2].to=y;e[cnt2].next=head[x];head[x]=cnt2++;}
    void add_ask(int x,int y,int z){ask[cnt1].to=y;ask[cnt1].next=head_ask[x];ask[cnt1].val=z;head_ask[x]=cnt1++;}
    void fix(int x,int c){for(;x<=tims;x+=x&-x)sum[x]+=c;}
    int find(int x){int ret=0;for(;x;x-=x&-x)ret+=sum[x];return ret;}
    struct Aho
    {
        int ch[N][26],fail[N],q[N],last[N],rot,cnt,fa[N];
        int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;}
        void init(){cnt=0,rot=new_node();}
        void insert(char *s)
        {
            int len=strlen(s),rt=rot,tot=0;
            for(int i=0;i<len;i++)
            {
                if(s[i]=='B')rt=fa[rt];
                else if(s[i]=='P')last[rt]++,flg[++tot]=rt;
                else
                {
                    if(ch[rt][s[i]-'a']==-1)ch[rt][s[i]-'a']=new_node(),fa[cnt-1]=rt;
                    rt=ch[rt][s[i]-'a'];
                }
            }
        }
        void get_fail()
        {
            int l=0,r=0;
            for(int i=0;i<26;i++)
            {
                if(ch[0][i]==-1)ch[0][i]=0;
                else fail[ch[0][i]]=0,q[r++]=ch[0][i];
            }
            while(l<r)
            {
                int x=q[l++];
                for(int i=0;i<26;i++)
                {
                    if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i];
                    else fail[ch[x][i]]=ch[fail[x]][i],q[r++]=ch[x][i];
                }
            }
            for(int i=1;i<cnt;i++)add(fail[i],i);//printf("%d %d
    ",fail[i],i);
        }
        void solve(char *s)
        {
            for(int i=0,len=strlen(s),rt=0;i<len;i++)
            {
                if(s[i]=='B')fix(out1[rt],-1),rt=fa[rt];
                else if(s[i]=='P')
                {
                    for(int j=head_ask[rt];j!=-1;j=ask[j].next)
                    {
                        int to1=ask[j].to,v=ask[j].val;
                        ans[v]=find(out1[to1])-find(in1[to1]-1);
                    }
                }else
                {
                    rt=ch[rt][s[i]-'a'];
                    fix(in1[rt],1);
                }
            }
        }
    }ac;
    void dfs(int x)
    {
        in1[x]=++tims;
        for(int i=head[x];i!=-1;i=e[i].next)
        {
            dfs(e[i].to);
        }
        out1[x]=++tims;
    }
    int main()
    {
        scanf("%s%d",s,&Q);memset(head,-1,sizeof(head));memset(head_ask,-1,sizeof(head_ask));
        ac.init();ac.insert(s);ac.get_fail();dfs(0);
        for(int i=1;i<=Q;i++){int x,y;scanf("%d%d",&x,&y);add_ask(flg[y],flg[x],i);}ac.solve(s);
        for(int i=1;i<=Q;i++)printf("%d
    ",ans[i]);return 0;
     
    

      

  • 相关阅读:
    运算符
    格式化输出
    while循环
    if 判断语句
    Swift # 字典
    Swift # 数组
    Swift # 字符串
    [ Swift # 函数 ]
    [ Bubble Sort ]& block
    数据结构 # 二叉树/堆/栈
  • 原文地址:https://www.cnblogs.com/Winniechen/p/9196959.html
Copyright © 2011-2022 走看看