zoukankan      html  css  js  c++  java
  • 「模拟赛20180306」回忆树 memory LCA+KMP+AC自动机+树状数组

    题目描述

    回忆树是一棵树,树边上有小写字母。

    一次回忆是这样的:你想起过往,触及心底……唔,不对,我们要说题目。

    这题中我们认为回忆是这样的:给定 \(2\) 个点 \(u,v\) (\(u\) 可能等于 \(v\))和一个非空字符串 \(s\) ,问从 \(u\)\(v\) 的简单路径上的所有边按照到 \(u\) 的距离从小到大的顺序排列后,询问边上的字符依次拼接形成的字符串中给定的串 \(s\) 出现了多少次。

    输入

    第一行 \(2\) 个整数,依次为树中点的个数 \(n\) 和回忆的次数 \(m\)
    接下来 \(n-1\) 行,每行 \(2\) 个整数 \(u,v\)\(1\) 个小写字母 \(c\) ,表示回忆树的点\(u,v\)之间有一条边,边上的字符为\(c\)
    接下来 \(2m\) 行表示 \(m\) 次回忆,每次回忆 \(2\) 行:第 \(1\)\(2\) 个整数 \(u,v\),第 \(2\) 行给出回忆的字符串 \(s\)

    输出

    对于每次回忆,输出串 \(s\) 出现的次数。

    样例

    样例输入

    12 3
    1 2 w
    2 3 w
    3 4 x
    4 5 w
    5 6 w
    6 7 x
    7 8 w
    8 9 w
    9 10 x
    10 11 w
    11 12 w
    1 7
    wwx
    1 12
    www
    1 12
    w
    

    样例输出

    2
    0
    8
    

    数据范围

    \(1≤n,m≤10^5\)
    询问字符串的总长度不超过\(3\times10^5\)

    题解

    这是一道神题,做法优美而且巧妙(同时也很恶心)。

    既然是树链上的询问,就不能不让人想到利用\(LCA\)\(u\xrightarrow{}v\)的路径转化成\(u\xrightarrow{}lca\)\(lca\xrightarrow{}v\)的两条路径了。

    那么我们就可以把询问分成三部分。

    1. \(lca\xrightarrow{}u\)\(s\)的反串出现了多少次
    2. \(lca\xrightarrow{}v\)\(s\)出现了多少次
    3. 跨越\(lca\)时,\(s\)出现了多少次

    可以发现,第一部分和第二部分其实是类似的问题,我们先放一放。


    那么我们考虑第三个问题,好像没有什么很简单的方法,于是我们考虑暴力。
    很容易发现这一种情况下涉及的字符串不长,只有\(u\xrightarrow{}lca\)路径上的\(\left|s\right|\)个和\(v\xrightarrow{}lca\)路径上的\(\left|s\right|\)个。我们可以暴力取出这一段字符,然后做一次\(KMP\),这样一次的复杂度是\(O(\left|s\right|)\),总时间复杂度就是\(O(\sum\left|s\right|)\),完全可以过。


    现在就剩前两个问题了。我们发现询问串太多,一个个做显然很吃力,这时,\(AC\)自动机的方法就呼之欲出了。我们把所有询问串做成一个\(AC\)自动机,把整棵树带进去匹配即可。

    匹配的过程很简单,模拟字符串匹配的时候即可,从根开始,依次访问子树,进栈的时候答案加,出栈的时候答案减即可,然后把询问的区间标记一下,到达合适的区间就计算答案。

    但是这样还有一个问题,\(AC\)自动机上的答案是要给\(fail\)链上的所有点增加的,暴力加显然会超时。于是我们修改一下做法,预处理出\(fail\)树的先序遍历序列,然后建立树状数组(一个比较显然的性质,同一颗子树的遍历序列是连续的)。于是修改的时候单点修改,查询的时候查询\(fail\)树上的子树和即可。


    然而,这道题说起来很轻巧,却是一道码农题……并且还卡常数……卡常数!!!
    所以,我还是把我\(250\)行的代码拿出来吧……
    \(Code:\)

    #include <queue> 
    #include <vector> 
    #include <cstdio> 
    #include <cstring> 
    #include <algorithm> 
    using namespace std; 
    #define M 600005 
    queue<int>q; 
    int n, m; 
    int f[25][M], dep[M], fa[M]; 
    int L[M], R[M], ans[M], ens[M], plc[M]; 
    vector<int>B[M], E[M]; 
    char len[M], top[M], S[M]; 
    struct node 
    { 
        int fir[M], tar[M], nex[M], cnt; 
    }T1, T2; 
    void add(int a, int b, char c) 
    { 
        ++T1.cnt; 
        T1.tar[T1.cnt] = b; 
        len[T1.cnt] = c; 
        T1.nex[T1.cnt] = T1.fir[a]; 
        T1.fir[a] = T1.cnt; 
    } 
    void add(int a, int b) 
    { 
        ++T2.cnt; 
        T2.tar[T2.cnt] = b; 
        T2.nex[T2.cnt] = T2.fir[a]; 
        T2.fir[a] = T2.cnt; 
    } 
    //dfs-begin 
    void dfs(int r) 
    { 
        for (int i = T1.fir[r]; i; i = T1.nex[i]) 
        { 
            int v = T1.tar[i]; 
            if (v != fa[r]) 
            { 
                fa[v] = r; 
                dep[v] = dep[r] + 1; 
                top[v] = len[i]; 
                dfs(v); 
            } 
        } 
    } 
    //dfs-end 
    //LCA-begin 
    int LCA(int u, int v) 
    { 
        if (dep[u] < dep[v]) 
            swap(u, v); 
        int k = dep[u] - dep[v]; 
        for (int i = 20; i >= 0; i--) 
            if (k & 1 << i) 
                u = f[i][u]; 
        if (u == v) 
            return u; 
        for (int i = 20; i >= 0; i--) 
            if (f[i][u] != f[i][v]) 
                u = f[i][u], v = f[i][v]; 
        return f[0][u]; 
    } 
    int getk(int u, int k) 
    { 
        for (int i = 0; i <= 20; i++) 
            if (k & 1 << i) 
                u = f[i][u]; 
        return u; 
    } 
    //LCA-end 
    //KMP-begin 
    char K[M]; 
    int nex[M]; 
    void KMP(int a, int b, int c, int ls, int w) 
    { 
        int len = 0; 
        while (a != c) 
            K[len++] = top[a], a = fa[a]; 
        int z = dep[b] - dep[c]; 
        len += dep[b] - dep[c]; 
        while (b != c) 
            K[--len] = top[b], b = fa[b]; 
        len += z; 
        K[len] = 0; 
        nex[0] = -1; 
        int i = 0, j = -1, ans = 0; 
        while(i < ls) 
        { 
            if (j == -1 || S[i] == S[j]) 
                nex[++i] = ++j; 
            else
                j = nex[j]; 
        } 
        i = 0, j = 0; 
        while(i < len) 
        { 
            if (j == ls) 
            { 
                ans++; 
                j = nex[j]; 
                continue; 
            } 
            if(j == -1 || K[i] == S[j]) 
                i++, j++; 
            else
                j = nex[j]; 
        } 
        if (j == ls) 
            ans++; 
        ens[w] += ans; 
    } 
    //KMP-end 
    //ACTrie-begin 
    struct ACTrie 
    { 
        int nex[M][30], fail[M], in[M], out[M]; 
        int root, cnt, tim, dfn[M], id[M]; 
        int tree[M]; 
        ACTrie(){root = cnt = 1;} 
        void Insert(char *S, int w) 
        { 
            int r = root, len = strlen(S); 
            for (int i = 0; i < len; i++) 
            { 
                int val = S[i] - 'a'; 
                if (!nex[r][val]) 
                    nex[r][val] = ++cnt; 
                r = nex[r][val]; 
            } 
            plc[w] = r; 
        } 
        void Build() 
        { 
            int r = root; 
            fail[r] = r; 
            q.push(root); 
            while (!q.empty()) 
            { 
                r = q.front(); 
                q.pop(); 
                for (int i = 0; i < 26; i++) 
                { 
                    if (nex[r][i]) 
                    { 
                        int tmp = nex[fail[r]][i]; 
                        if (tmp && tmp != nex[r][i]) 
                            fail[nex[r][i]] = tmp; 
                        else
                            fail[nex[r][i]] = root; 
                        q.push(nex[r][i]); 
                    } 
                    else
                    { 
                        int tmp = nex[fail[r]][i]; 
                        if (tmp) 
                            nex[r][i] = tmp; 
                        else
                            nex[r][i] = root; 
                    } 
                } 
                if (r != root) 
                    add(fail[r], r); 
            } 
        } 
        void DFS(int r) 
        { 
            dfn[r] = ++tim; 
            in[r] = tim; 
            id[tim] = r; 
            for (int i = T2.fir[r]; i; i = T2.nex[i]) 
            { 
                int v = T2.tar[i]; 
                DFS(v); 
            } 
            out[r] = tim; 
        } 
        void Update(int x, int v) 
        { 
            for (int i = x; i <= cnt; i += i & -i) 
                tree[i] += v; 
        } 
        int Getsum(int x) 
        { 
            int ans = 0; 
            for (int i = x; i; i -= i & -i) 
                ans += tree[i]; 
            return ans; 
        } 
    }AC; 
    //ACTrie-end 
    void dfs2(int r, int now) 
    { 
        AC.Update(AC.dfn[now], 1); 
        int s = B[r].size(); 
        for (int i = 0; i < s; i++) 
            ens[(B[r][i] + 1)/ 2] -= AC.Getsum(AC.out[plc[B[r][i]]]) - AC.Getsum(AC.in[plc[B[r][i]]] - 1); 
        s = E[r].size(); 
        for (int i = 0; i < s; i++) 
            ens[(E[r][i] + 1)/ 2] += AC.Getsum(AC.out[plc[E[r][i]]]) - AC.Getsum(AC.in[plc[E[r][i]]] - 1); 
        for (int i = T1.fir[r]; i; i = T1.nex[i]) 
        { 
            int v = T1.tar[i]; 
            if (v != fa[r]) 
                dfs2(v, AC.nex[now][len[i] - 'a']); 
        } 
        AC.Update(AC.dfn[now], -1); 
    } 
    int main() 
    { 
        //freopen("memory.in", "r", stdin); 
        //freopen("memory.out", "w", stdout); 
        scanf("%d%d", &n, &m); 
        for (int i = 1; i < n; i++) 
        { 
            int a, b; 
            char c[5]; 
            scanf("%d%d%s", &a, &b, c); 
            add(a, b, c[0]); 
            add(b, a, c[0]); 
        } 
        dfs(1); 
        for (int i = 1; i <= n; i++) 
            f[0][i] = fa[i]; 
        for (int i = 1; i <= 20; i++) 
            for (int j = 1; j <= n; j++) 
                f[i][j] = f[i - 1][f[i - 1][j]]; 
        int w = 0; 
        for (int i = 1; i <= m; i++) 
        { 
            int u, v, c; 
            scanf("%d%d%s", &u, &v, S); 
            c = LCA(u, v); 
            int l1 = dep[u] - dep[c], l2 = dep[v] - dep[c], ls = strlen(S); 
            int a = getk(u, max(0, l1 - ls + 1)); 
            int b = getk(v, max(0, l2 - ls + 1)); 
            KMP(a, b, c, ls, i); 
            w++; 
            AC.Insert(S, w); 
            B[b].push_back(w); 
            E[v].push_back(w); 
            w++; 
            for (int i = 0; i < ls / 2; i++) 
                swap(S[i], S[ls - i - 1]); 
            AC.Insert(S, w); 
            B[a].push_back(w); 
            E[u].push_back(w); 
        } 
        AC.Build(); 
        AC.DFS(AC.root); 
        dfs2(1, AC.root); 
        for (int i = 1; i <= m; i++) 
            printf("%d\n", ens[i]); 
    } 
    
    
  • 相关阅读:
    #include <functional>
    3.3内联函数
    如何查看内存占用和运行速度
    属性和方法的动态绑定和限制
    __slots__节约空间
    函数进阶之一等对象
    python继承之super
    python的方法VSjava方法
    python面向对象基础(三)内置方法 __xx__
    python面向对象基础(二)反射
  • 原文地址:https://www.cnblogs.com/ModestStarlight/p/8533686.html
Copyright © 2011-2022 走看看