zoukankan      html  css  js  c++  java
  • 背单词(AC自动机+线段树+dp+dfs序)

    G. 背单词

    内存限制:256 MiB 时间限制:1000 ms 标准输入输出
    题目类型:传统 评测方式:文本比较
     

    题目描述

    给定一张包含N个单词的表,每个单词有个价值W。要求从中选出一个子序列使得其 中的每个单词是后一个单词的子串,最大化子序列中W的和。

    输入格式

    第一行一个整数TEST,表示数据组数。 接下来TEST组数据,每组数据第一行为一个整数N。 接下来N行,每行为一个字符串和一个整数W。

    输出格式

    TEST行,每行一个整数,表示W的和的最大值。

    数据规模 设字符串的总长度为Len 30.的数据满足,TEST≤5,N≤500,Len≤10^4 100.的数据满足,TEST≤10,N≤20000,Len≤3*10^5


    析:又是一道好题,这道题将AC自动机与线段树,dp,以及 dfs 序结合了起来;

      首先我们要明确这样一个事情,S是T的字串,相当于 T  的一个前缀可以通过 Fail 树遍历到 S 的末尾结点,也就是说, S 的末尾节点是 T 的某个前缀在 Fail 树上的祖先;

      那么这道题思路就清晰了,首先可以写出 dp 方程 :f[i]=max(x)+w[i] ,表示 在前 i 个单词中,当前枚举到第 i 个单词且选择它的最大值, max(x) 表示当前单词前缀的最大值;

      那么此时我们的问题就在于 1.如何求得前缀? 2.如何求得区间(单点)最大值?

      对于第一个问题,我们可以使用 fa 数组进行回溯查询:

     if(!use[now].son[p])
            {
                use[now].son[p]=++num;
                fa[use[now].son[p]]=now;
            }
    

       这是在构建字典树的过程中记录每一个字符的父亲节点,我们在计算过程中就可以:

     while(p)
            {
                f[i]=max(f[i],query(rt,1,num,l[p]));
                p=fa[p];
            }
    

       那么考虑第二个问题,区间查询,我们同时还要考虑到,每次在我们枚举一次选择的单词后,我们都要判断选或不选哪个是最优解,然后对一段区间进行更新,所以说,我们不仅需要区间查询,还需要区间修改的操作,

       看这数据范围,显然我们可以想到线段树。那么我们在查询,修改的过程中如何确定区间呢? 这里利用 dfs 序就是一种很妙的思路,我们求出 in[],与out[] 就可以知道当前单词的控制区间。

      那么问题来了,我们要构建一颗什么样的 dfs 树呢?

      显然,若考虑当前单词 i ,fail[i] ,fa[i],fail[fa[i]] ,那么很明显,fail[fa[i]] 应该是控制区间最大的那一个,所以,我们就要从每个节点的 fail 指针向当前单词 连一条边,进行 dfs ;

      这里我们再解释一下为什么是单点查询,注意题目要求:

    >  从中选出一个子序列使得其中的每个单词是后一个单词的子串

      1.假设现在有个序列 ABCD ,那么假设我的单词分别为 AB , 和 CD,那么如果我两个同时拿的话就无法满足题义

      2.若每次我们都之考虑某一个前缀的最大值,那么递推过来的一定是满足条件的最大值!!

      那么,在这颗线段树中,我们要维护的就是每个单词选或不选的最大值,所以在区间更新的时候我们都要取 max,到这里应该解释的差不多了;

     代码:

    #include<bits/stdc++.h>
    #define re register int
    using namespace std;
    const int N=3e6+10;
    int T,n,cnt,tot,rt,timi,num=1;
    char s[N];
    int w[N],ed[N],l[N],r[N],fa[N];
    int head[N],to[N<<1],next[N<<1];
    long long f[N];
    long long maxx;
    bool vis[N];
    queue<int> q;
    struct CUN
    {
        int flag,fail;
        int son[30];
        void clean()
        {
            flag=0;
            fail=0;
            memset(son,0,sizeof(son));
        }
    }use[N];
    struct C2
    {
        int lc,rc,sum,lazy;
        void clean()
        {
        	lc=0;
        	rc=0;
        	sum=0;
        	lazy=0;
        }
    }t[N];
    void in()
    {
        for(re i=0;i<=max(cnt,tot);i++)
            use[i].clean();
        for(re i=0;i<=max(cnt,tot);i++)
        	t[i].clean();
        cnt=0;
        num=1;
        maxx=0;
        tot=0;
        timi=0;
        memset(l,0,sizeof(l));
        memset(r,0,sizeof(r));
        memset(vis,0,sizeof(vis));
        memset(f,0,sizeof(f));
        memset(w,0,sizeof(w));
        memset(ed,0,sizeof(ed));
        memset(head,0,sizeof(head));
        memset(to,0,sizeof(to));
        memset(next,0,sizeof(next));
        memset(fa,0,sizeof(fa));
        while(!q.empty())
            q.pop();
    }
    void insert(char ss[],int pos)
    {
        int now=1;
        int l=strlen(ss);
        for(re i=0;i<l;i++)
        {
            int p=ss[i]-'a';
            if(!use[now].son[p])
            {
                use[now].son[p]=++num;
            	fa[use[now].son[p]]=now;
            }
            now=use[now].son[p];
        }
        ed[pos]=now;
    }
    void Add(int x,int y)
    {
        to[++tot]=y;
        next[tot]=head[x];
        head[x]=tot;
    }
    void dfs(int x)
    {
        if(x)
            l[x]=++timi;
        for(re i=head[x];i;i=next[i])
            dfs(to[i]);
        r[x]=timi;    
    }
    void build(int &p,int L,int R)
    {
        p=++cnt;
        if(L==R)
            return;
        int mid=(L+R)>>1;
        build(t[p].lc,L,mid);
        build(t[p].rc,mid+1,R);
    }
    void get_fail()
    {
        for(re i=0;i<26;i++)
            use[0].son[i]=1;
        use[1].fail=0;
        q.push(1);
        while(!q.empty())
        {
            int u=q.front();
            q.pop();
            int Fail=use[u].fail;
            for(re i=0;i<26;i++)
            {
                int v=use[u].son[i];
                if(!v)
                {
                    use[u].son[i]=use[Fail].son[i];
                    continue;
                }
                use[v].fail=use[Fail].son[i];
                q.push(v);
            }
        }
        for(re i=1;i<=num;i++)
            Add(use[i].fail,i);
        dfs(0);
        build(rt,1,num);
    }
    void pd(int p)
    {
        if(t[p].lazy==0)
            return;
        t[t[p].lc].lazy=max(t[t[p].lc].lazy,t[p].lazy);
        t[t[p].rc].lazy=max(t[t[p].rc].lazy,t[p].lazy);
        t[t[p].lc].sum=max(t[t[p].lc].sum,t[p].lazy);
        t[t[p].rc].sum=max(t[t[p].rc].sum,t[p].lazy);
        t[p].lazy=0;
    }
    void pp(int rt)
    {
        t[rt].sum=max(t[t[rt].lc].sum,t[t[rt].rc].sum);
    }
    long long query(int rt,int L,int R,int p)
    {
        if(L==R)
            return t[rt].sum;
        int mid=(L+R)>>1;
        pd(rt);
        if(p<=mid)
            return query(t[rt].lc,L,mid,p);
        return query(t[rt].rc,mid+1,R,p);
    }
    void updata(int p,int L,int R,int l,int r,int z)
    {
        if(l<=L&&R<=r)
        {
            t[p].sum=max(t[p].sum,z);
            t[p].lazy=max(t[p].lazy,z);
            return;
        }
        pd(p);
        int mid=(L+R)>>1;
        if(mid>=l)
            updata(t[p].lc,L,mid,l,r,z);
        if(mid<r)
            updata(t[p].rc,mid+1,R,l,r,z);
        pp(p);
    }
    void dp()
    {
        //f[i]=max{x}+w[i];
        for(re i=1;i<=n;i++)
        {
            int p=ed[i];
            while(p)
            {
                f[i]=max(f[i],query(rt,1,num,l[p]));
                p=fa[p];
            }
            f[i]=max(0*1ll,f[i]+w[i]);
            updata(rt,1,num,l[ed[i]],r[ed[i]],f[i]);
        }
        for(re i=1;i<=n;i++)
            maxx=max(maxx,f[i]);
        printf("%lld\n",maxx);
    }
    signed main()
    {
        scanf("%d",&T);
        while(T--)
        {
            scanf("%d",&n);
            in();
            for(re i=1;i<=n;i++)
            {
            	scanf("%s",s);
            	scanf("%d",&w[i]);
                insert(s,i);
            }
            get_fail();
            dp();
        }
    }
    
  • 相关阅读:
    javascript 学习笔记714章
    数据库设计的四个范式
    【转】utf8的中文是一个汉字占三个字节长度
    java 中文url的解决
    so动态链接库的使用
    linux常用命令
    控制台编译Qt程序
    构造函数初始化列表 组合类构造函数
    const volatile
    std::pair
  • 原文地址:https://www.cnblogs.com/WindZR/p/14881776.html
Copyright © 2011-2022 走看看