zoukankan      html  css  js  c++  java
  • 2019牛客多校第四场 I题 后缀自动机_后缀数组_求两个串de公共子串的种类数

    @

    求若干个串的公共子串个数相关变形题

    • 牛客这题题意大概是求一个长度为(2e5)的字符串有多少个不同子串,若(s==t)(s==rev(t))则认为子串(s,t)相同。我们知道回文串肯定和他的反串相同。
    • 链接:传送门

    做法1:

    • (yx)大佬秒出思路%%,对(s)串建后缀自动机,可以得到串(s)本质不同的子串的个数(all),然后只要能减去有多少个串(x)(rev(x))同时也出现了即可。
    • 考虑先求出(s)(rev(s))的本质不同的公共子串的数量(res),串(s)本质不同的回文串数量为(q),显然(res-q)肯定是(2)的倍数。求回文串数量是个板子题:here
    • 因为(s)(rev(s))本质不同的公共子串除了回文串,就只有非回文串且(x=rev(x))的串了。又因为(x)(rev(x))只能算一次贡献,所以最后答案就是(all-frac {res-q} 2)
    • 所以我们现在只要能求出串(t)与串(s)的公共子串种类数量即可。(还有一种题是求长度至少为k的公共子串数量

    做法2:

    广义后缀自动机直接求即可。

    用普通后缀自动机也有更简单做法,我在第一个做法下面有讲解。

    做法3:

    后缀数组


    对一个串建后缀自动机,另一个串在上面跑同时计数

    • 构建好(s)串的后缀自动机后,从根节点开始用(t)串在上面匹配,记录一下已经匹配的(lcs)长度(LEN)。若(u)节点有(t[i])这个后继,则(u)跳到(nex[u][t[i]-'a'],LEN++);如果没有这个后继,就从(u)开始沿着后缀连接树向上走直到碰到一个节点(x)(t[i])这个后继或者到了根节点(x),则(u = nex[x][t[i]-'a'],LEN=len[x]+1)
    • 算贡献就是我当前在(u)节点,(lcs)长度为(len),那么(LEN-len[link[u]])就是符合条件的子串。但是这不完全,就是如果(len[link[u]])也大于(0)的话,那么他的父亲状态(link[u])是有符合条件的子串,而且符合条件的子串的数量是固定的:(len[u]-len[link[u]])
    • 听说如果你每次走后缀连接树算完所有贡献的话会(tle),一个优化就是匹配结束后,逆拓扑排序更新父亲结点的出现次数。像线段树一样用一个(lazy)标记记录它是否需要更新,要记得把(lazy)标记向父亲上传。
    • 但是这样不够,因为还有一部分贡献没有计算,你可能多次匹配到自动机上的一个节点,我们需要记录一下匹配到每个节点的最长(lcs)长度即(vis[u]),若(vis[u])等于(0),则贡献如上,反之贡献为(LEN-vis[u]),最后更新(vis[u])(LEN)
    • 本题结束。

    其实还有一个更简单的方法,把串(s)和串(rev(s))用一个没有出现过得字符拼接起来,求出新字符串的本质不同的子串个数(x),我们知道包含那个未出现过字符的子串数量为(y = (Len+1) imes (Len+1)),(注意串((ba))和串((ab))只能计一个贡献)然后在求出(s)本质不同的回文串个数(p),答案就是(frac{x-y+p}2)

    #pragma comment(linker, "/STACK:102400000,102400000")
    #include<bits/stdc++.h>
    #define fi first
    #define se second
    #define endl '
    '
    #define o2(x) (x)*(x)
    #define BASE_MAX 30
    #define mk make_pair
    #define eb emplace_back
    #define all(x) (x).begin(), (x).end()
    #define clr(a, b) memset((a),(b),sizeof((a)))
    #define iis std::ios::sync_with_stdio(false); cin.tie(0)
    #define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
    using namespace std;
    #pragma optimize("-O3")
    typedef long long LL;
    typedef pair<int, int> pii;
    inline LL read() {
        LL x = 0;int f = 0;
        char ch = getchar();
        while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
        while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x = f ? -x : x;
    }
    inline void write(LL x) {
        if (x == 0) {putchar('0'), putchar('
    ');return;}
        if (x < 0) {putchar('-');x = -x;}
        static char s[23];
        int l = 0;
        while (x != 0)s[l++] = x % 10 + 48, x /= 10;
        while (l)putchar(s[--l]);
        putchar('
    ');
    }
    int lowbit(int x) { return x & (-x); }
    template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
    template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
    template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
    template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
    void debug_out() { cerr << '
    '; }
    template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
    #define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__);
    
    #define print(x) write(x);
    
    const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
    const int HMOD[] = {1000000009, 1004535809};
    const LL BASE[] = {1572872831, 1971536491};
    const int mod = 998244353;
    const int MOD = 1e9 + 7;
    const int INF = 0x3f3f3f3f;
    const int MXN = 1e6 + 7;
    
    int n;
    char s[MXN], t[MXN];
    LL all, ANS;
    int vis[MXN], lazy[MXN];
    struct Palindromic_Tree {
        static const int MAXN = 600005 ;
        static const int CHAR_N = 26 ;
        int next[MAXN][CHAR_N];//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
        int fail[MAXN];//fail指针,失配后跳转到fail指针指向的节点
        int cnt[MAXN];
        int num[MAXN];
        int len[MAXN];//len[i]表示节点i表示的回文串的长度
        int S[MAXN];//存放添加的字符
        int last;//指向上一个字符所在的节点,方便下一次add
        int n;//字符数组指针
        int p;//节点指针
        int pos[MAXN];
        int newnode(int l) {//新建节点
            for (int i = 0; i < CHAR_N; ++i) next[p][i] = 0;
            cnt[p] = 0;
            num[p] = 0;
            len[p] = l;
            return p++;
        }
        void init() {//初始化
            p = 0;
            newnode(0);
            newnode(-1);
            last = 0;
            n = 0;
            S[n] = -1;//开头放一个字符集中没有的字符,减少特判
            fail[0] = 1;
        }
        int get_fail(int x) {//和KMP一样,失配后找一个尽量最长的
            while (S[n - len[x] - 1] != S[n]) x = fail[x];
            return x;
        }
        void add(int c, int id) {
            c -= 'a';
            S[++n] = c;
            int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置
            if (!next[cur][c]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
                int now = newnode(len[cur] + 2);//新建节点
                fail[now] = next[get_fail(fail[cur])][c];//和AC自动机一样建立fail指针,以便失配后跳转
                next[cur][c] = now;
                num[now] = num[fail[now]] + 1;
            }
            last = next[cur][c];
            cnt[last] ++;
            pos[last] = id;
        }
        void count() {
            for (int i = p - 1; i >= 0; --i) cnt[fail[i]] += cnt[i];
            //父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
        }
    } pt;
    struct Suffix_Automaton {
        static const int maxn = 1e6 + 105;
        static const int MAXN = 1e6 + 5;
        //basic
    //    map<char,int> nex[maxn * 2];
        int nex[maxn*2][26];
        int link[maxn * 2], len[maxn * 2];
        int last, cnt;
        LL tot_c;//不同串的个数
        //extension
        int cntA[MAXN * 2], A[MAXN * 2];/*辅助拓扑更新*/
        int nums[MAXN * 2];/*每个节点代表的所有串的出现次数*/
        void clear() {
            tot_c = 0;
            last = cnt = 1;
            link[1] = len[1] = 0;
            memset(nex[1], 0, sizeof(nex[1]));
        }
        void init_str(char *s) {
            while (*s) {
                add(*s - 'a');
                ++ s;
            }
        }
        void add(int c) {
            int p = last;
            int np = ++cnt;
    //        nex[cnt].clear();
            memset(nex[cnt], 0, sizeof(nex[cnt]));
            len[np] = len[p] + 1;
            last = np;
            while (p && !nex[p][c])nex[p][c] = np, p = link[p];
            if (!p)link[np] = 1, tot_c += len[np] - len[link[np]];
            else {
                int q = nex[p][c];
                if (len[q] == len[p] + 1)link[np] = q, tot_c += len[np] - len[link[np]];
                else {
                    int nq = ++cnt;
                    len[nq] = len[p] + 1;
    //                nex[nq] = nex[q];
                    memcpy(nex[nq], nex[q], sizeof(nex[q]));
                    link[nq] = link[q];
                    link[np] = link[q] = nq;
                    tot_c += len[np] - len[link[np]];
                    while (nex[p][c] == q)nex[p][c] = nq, p = link[p];
                }
            }
        }
        void build(int n) {
            memset(cntA, 0, sizeof cntA);
            memset(nums, 0, sizeof nums);
            for (int i = 1; i <= cnt; i++)cntA[len[i]]++;
            for (int i = 1; i <= n; i++)cntA[i] += cntA[i - 1];
            for (int i = cnt; i >= 1; i--)A[cntA[len[i]]--] = i;
            /*更行主串节点*/
            int temps = 1;
            for (int i = 0; i < n; i++) {
                nums[temps = nex[temps][s[i] - 'a']] = 1;
            }
            for (int i = cnt, x; i >= 1; i--) {
                x = A[i];
                nums[link[x]] += nums[x];
            }
        }
        void query() {
            int u = 1, LEN = 0;
            for(int i = 0; i < n; ++i) {
                if(nex[u][t[i]-'a']) {
                    u = nex[u][t[i]-'a'];
                    ++ LEN;
                }else {
                    while (u && nex[u][t[i] - 'a'] == 0) u = link[u];
                    if (u == 0) u = 1, LEN = 0;
                    else {
                        LEN = len[u] + 1;
                        u = nex[u][t[i] - 'a'];
                    }
                }
                if(vis[u] == 0) {
                    ANS += 1 * (LEN - len[link[u]]);
    //                debug(i, t[i], LEN - len[link[u]])
                    if (len[link[u]]) lazy[link[u]] = 1;
                    vis[u] = LEN;
                }else if(LEN > vis[u]) {
                    ANS += 1 * (LEN - vis[u]);
    //                debug(i, t[i], LEN - vis[u])
                    vis[u] = LEN;
                }
            }
            for(int i = cnt, x; i >= 1; --i) {
                x = A[i];
                if(vis[x] == 0 && len[x] && lazy[x]) {
                    ANS += len[x] - len[link[x]];
                    vis[x] = len[x];
                    if(len[link[x]]) lazy[link[x]] = 1;
                }else if(lazy[x] && vis[x] < len[x]) {
                    ANS += len[x] - vis[x];
                    vis[x] = len[x];
                    if(len[link[x]]) lazy[link[x]] = 1;
                }
                if(len[link[x]]) lazy[link[x]] = 1;
            }
        }
        void DEBUG() {
            for (int i = cnt; i >= 1; i--) {
                printf("nums[%d]=%d numt[%d]=%d len[%d]=%d link[%d]=%d
    ", i, nums[i], i, nums[i], i, len[i], i, link[i]);
            }
        }
    } sam;
    
    int main() {
    #ifndef ONLINE_JUDGE
        freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
        //freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
    #endif
    //    int tim = read();
        scanf("%s", s);
        memcpy(t, s, sizeof(s));
        n = strlen(s);
        reverse(t, t + n);
        sam.clear();
        sam.init_str(s);
        all = sam.tot_c;
        sam.build(n);
        sam.query();
        pt.init();
        for(int i = 0; i < n; ++i) pt.add(s[i], i);
        int hui = pt.p - 2;
        debug(n, hui, all, ANS)
        printf("%lld
    ", all - (ANS - hui) / 2);
    #ifndef ONLINE_JUDGE
        cout << "time cost:" << clock() << "ms" << endl;
    #endif
        return 0;
    }
    

    广义后缀自动机

    • 直接离线构建广义后缀自动机(插入函数和普通后缀自动机一模一样),先插入(s)串,置(last=1),再插入(rev(s)),然后对这个后缀自动机求出本质不同的子串个数(all)(回文串只计算一次贡献,其他串计算了两次,因为(x=rev(x))),设(p)表示(s)串本质不同的回文串个数,最后答案即为(frac{all+p}2)

    后缀数组


    其他:POJ 3415 求两个串长度至少为k的公共子串数量

    本题不需要去重。可后缀数组也可后缀自动机写。

    后缀自动机
    解法和牛客那题基本一样,甚至更简单,因为本题不需要去重,是算总数。
    不需要记录每个节点被匹配到的(lcs)长度,因此当前节点每次被匹配到的贡献都是(LEN-max(len[link[u]],k-1))
    因为是算所有子串的数量,只需要用(lazy[])标记表示这个节点被匹配到的次数,最后逆拓扑序向上传(lazy[])标记即可。

    后缀数组
    按套路,把(s,t)拼成一个串,两遍单调栈,分别算(t)串对(s)串的贡献和(s)串对(t)串的贡献

    #pragma comment(linker, "/STACK:102400000,102400000")
    //#include<bits/stdc++.h>
    #include<cstdio>
    #include<cstring>
    #include<string>
    #include<vector>
    #include<stack>
    #include<map>
    #include<iostream>
    #include<assert.h>
    #define fi first
    #define se second
    #define endl '
    '
    #define o2(x) (x)*(x)
    #define BASE_MAX 30
    #define mk make_pair
    #define eb emplace_back
    #define all(x) (x).begin(), (x).end()
    #define clr(a, b) memset((a),(b),sizeof((a)))
    #define iis std::ios::sync_with_stdio(false); cin.tie(0)
    #define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
    using namespace std;
    #pragma optimize("-O3")
    typedef long long LL;
    typedef pair<int, int> pii;
    inline LL read() {
        LL x = 0;int f = 0;
        char ch = getchar();
        while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
        while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x = f ? -x : x;
    }
    inline void write(LL x) {
        if (x == 0) {putchar('0'), putchar('
    ');return;}
        if (x < 0) {putchar('-');x = -x;}
        static char s[23];
        int l = 0;
        while (x != 0)s[l++] = x % 10 + 48, x /= 10;
        while (l)putchar(s[--l]);
        putchar('
    ');
    }
    int lowbit(int x) { return x & (-x); }
    template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
    //template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
    //template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
    //template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
    //void debug_out() { cerr << '
    '; }
    //template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
    //#define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__);
    
    #define print(x) write(x);
    
    const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
    const int HMOD[] = {1000000009, 1004535809};
    const LL BASE[] = {1572872831, 1971536491};
    const int mod = 998244353;
    const int MOD = 1e9 + 7;
    const int INF = 0x3f3f3f3f;
    const int MXN = 2e5 + 7;
    
    int n, m, k;
    LL ANS;
    char s[MXN], t[MXN];
    LL lazy[MXN];
    struct Suffix_Automaton {
        static const int maxn = 2e5 + 105;
        static const int MAXN = 2e5 + 5;
        //basic
    //    map<char,int> nex[maxn * 2];
        int nex[maxn][58];
        int link[maxn * 2], len[maxn * 2];
        int last, cnt;
        LL tot_c;//不同串的个数
        //extension
        int cntA[MAXN * 2], A[MAXN * 2];/*辅助拓扑更新*/
        int nums[MAXN * 2];/*每个节点代表的所有串的出现次数*/
        void clear() {
            tot_c = 0;
            last = cnt = 1;
            link[1] = len[1] = 0;
    //        nex[1].clear();
            memset(nex[1], 0, sizeof(nex[1]));
        }
        void init_str(char *s) {
            while (*s) {
                add(*s - 'A');
                ++ s;
            }
        }
        void add(int c) {
            int p = last;
            int np = ++cnt;
    //        nex[cnt].clear();
            memset(nex[cnt], 0, sizeof(nex[cnt]));
            len[np] = len[p] + 1;
            last = np;
            while (p && !nex[p][c])nex[p][c] = np, p = link[p];
            if (!p)link[np] = 1, tot_c += len[np] - len[link[np]];
            else {
                int q = nex[p][c];
                if (len[q] == len[p] + 1)link[np] = q, tot_c += len[np] - len[link[np]];
                else {
                    int nq = ++cnt;
                    len[nq] = len[p] + 1;
    //                nex[nq] = nex[q];
                    memcpy(nex[nq], nex[q], sizeof(nex[q]));
                    link[nq] = link[q];
                    link[np] = link[q] = nq;
                    tot_c += len[np] - len[link[np]];
                    while (nex[p][c] == q)nex[p][c] = nq, p = link[p];
                }
            }
        }
        void build(int n) {
            for(int i = 0; i <= cnt; ++i) nums[i] = cntA[i] = 0;
            for (int i = 1; i <= cnt; i++) cntA[len[i]]++;
            for (int i = 1; i <= n; i++)cntA[i] += cntA[i - 1];
            for (int i = cnt; i >= 1; i--)A[cntA[len[i]]--] = i;
            /*更行主串节点*/
            int temps = 1;
            for (int i = 0; i < n; i++) {
                nums[temps = nex[temps][s[i] - 'A']] = 1;
            }
            for (int i = cnt, x; i >= 1; i--) {
                x = A[i];
                nums[link[x]] += nums[x];
            }
        }
        void query() {
            int u = 1, LEN = 0;
            for(int i = 0; i < m; ++i) {
                if(nex[u][t[i]-'A']) {
                    u = nex[u][t[i]-'A'];
                    ++ LEN;
                }else {
                    while (u && nex[u][t[i] - 'A'] == 0) u = link[u];
                    if (u == 0) u = 1, LEN = 0;
                    else {
                        LEN = len[u] + 1;
                        u = nex[u][t[i] - 'A'];
                    }
                }
                if(LEN >= k) {
                    ANS += (LL)nums[u] * (LEN - big(len[link[u]], k - 1));
                    if (len[link[u]]) lazy[link[u]] ++;
                }
            }
            for(int i = cnt, x; i >= 1; --i) {
                x = A[i];
                if(len[x] >= k && lazy[x]) {
                    ANS += lazy[x] * nums[x] * (len[x] - big(len[link[x]], k - 1));
                    if(len[link[x]]) lazy[link[x]] += lazy[x];
                }
            }
        }
        void DEBUG() {
            for (int i = cnt; i >= 1; i--) {
                printf("nums[%d]=%d numt[%d]=%d len[%d]=%d link[%d]=%d
    ", i, nums[i], i, nums[i], i, len[i], i, link[i]);
            }
        }
    } sam;
    
    int main() {
    #ifndef ONLINE_JUDGE
        freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
        //freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
    #endif
        while(~scanf("%d", &k) && k) {
            scanf("%s%s", s, t);
            n = strlen(s), m = strlen(t);
            sam.clear();
            sam.init_str(s);
            sam.build(n);
            ANS = 0;
            sam.query();
            for(int i = 0; i <= 2 * n + 5; ++i) lazy[i] = 0;
            printf("%lld
    ", ANS);
        }
    #ifndef ONLINE_JUDGE
        cout << "time cost:" << clock() << "ms" << endl;
    #endif
        return 0;
    }
    
  • 相关阅读:
    Python pip配置国内源
    【VLC】VLC命令行参数
    发个在owasp上演讲web应用防火墙的ppt
    Tips of Linux C programming
    linux程序调试
    scrapy结合webkit抓取js生成的页面
    Using Internet Explorer from .NET
    http长连接200万尝试及调优
    nginx url解码引发的waf漏洞
    poj 2513
  • 原文地址:https://www.cnblogs.com/Cwolf9/p/11257007.html
Copyright © 2011-2022 走看看