zoukankan      html  css  js  c++  java
  • UOJ-581 NOIP2020 字符串匹配

    Description

    给定小写字母组成的字符串 (S)。定义 (AB) 表示字符串 (A, B) 拼接,(A^n=A^{n-1}A) 表示 (A) 复制 (n) 遍。求三元组 ((A, B, C)) 的个数,满足 (S) 可以写成 ((AB)^i C) 的形式。共 (T) 组数据。

    Constraints

    (1le |S| le 2^{20}, 1le Tle 5)​。

    Solution

    首先是一个比较大众也比较好想的做法。记 (pre(i))​ 表示前缀 (i)​ 中出现奇数次的字符个数,同理对后缀定义 (suf(i))​。字符集用 (Sigma) 表示。

    考虑枚举 (AB)​​ 的长 (x)​​,那么前缀 (S[1:x])​​ 就是 (AB)​​。考虑 (S)​​ 将会由 (AB)​​ 循环若干次构成,剩下的就是 (C)​​,那么考虑找到一个最大的循环次数 (k)​​,哈希即可。找到 (k)​​ 之后,就能对于每个循环次数 (iin [1, k])​​,求出 (pre)​ 上 ([1, ix))​ 中有多少个 (le suf(ix+1))​ 就是对答案的贡献。考虑到值域是 ([0, 26])​,树状数组维护单次操作是 (O(log |Sigma|))​。直接做是 (O(Tsum_{i=1}^n frac n i) = O(Tnlog nlog |Sigma|))​ 的,想要通过比较困难。

    优化其实并不难,考虑一下两个要点:

    • (k)​ 的合法性是单调的;
    • 对于一个 (k),不需要枚举 (i),奇数偶数分开算,同为奇数或偶数贡献是一样的。

    第二个比较简单。第一个我们可以考虑二分或者倍增找到 (k)。这样的话复杂度大概是 (O(sum_{i=1}^nlog ( frac n i))approx O(n))。参考 这里

    二分我不太会保证复杂度,这里介绍一种倍增方法。

    (X=S[1:x]),其哈希值为 (H(X))。那么我们可以得到复制 (t) 倍的串的哈希值:

    [H(X^t)=sum_{i=0}^{k-1}b^{ix} H(X)=frac{b^{tx}-1}{b^x-1}H(X) ]

    其中 (b)​ 为哈希的基数。计算一个哈希值,如果使用快速幂的话,需要 (O(log tx))​ 的时间。不过如果是倍增的话,我们只需要计算 (lfloorlog_2 frac n x floor)​ 个 (b^{tx})​ 值即可,每一项等于前一项的平方。逆元直接算是 (O(log mod))​ 的,尽管每个 (x)​ 都只算一次也是不可接受的。那么只好用一个 离线求逆元的 trick,预处理所有 (x)​ 的 ((b^x-1)^{-1})​。这样复杂度就只有 (O(Tnlog |Sigma|))​ 了。不会二分是因为倍增可以预处理 (lfloorlog_2 frac n x floor)​ 个 (b^{tx}) 而二分我就不知道了。

    Code

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <cstring>
    const int N = 1 << 20 | 5;
    
    typedef unsigned long long ull;
    const ull base = 19260817;
    const ull mod = 1e9 + 7;
    ull pw[N];
    
    int n;
    long long ans;
    char s[N];
    int pre[N], suf[N];
    ull hs[N];
    
    inline ull fastpow(ull a, ull b) {
      ull r = 1llu;
      for (; b; b >>= 1, (a *= a) %= mod)
        if (b & 1) (r *= a) %= mod;
      return r;
    }
    
    ull buf[N], inv[N];
    namespace inversion {
      ull pre[N], suf[N];
      void process(int n) {
        memset(pre, 0, sizeof(pre));
        memset(suf, 0, sizeof(suf));
        memset(inv, 0, sizeof(inv));
        pre[0] = suf[n + 1] = 1llu;
        for (int i = 1; i <= n; i++)
          pre[i] = pre[i - 1] * buf[i] % mod;
        for (int i = n; i >= 1; i--)
          suf[i] = suf[i + 1] * buf[i] % mod;
        ull all = fastpow(pre[n], mod - 2);
        for (int i = 1; i <= n; i++)
          inv[i] = pre[i - 1] * suf[i + 1] % mod * all % mod;
      }
    }
    
    struct bit {
      int t[28];
      inline int get(int x) {
        int r = 0;
        for (++x; x; x -= x & -x) r += t[x];
        return r;
      }
      inline void add(int x) {
        for (++x; x <= 27; x += x & -x) ++t[x];
      }
      inline void reset() {
        memset(t, 0, sizeof(t));
      }
    } tr;
    
    signed main() {
      pw[0] = 1llu;
      for (int i = 1; i < N; i++)
        pw[i] = pw[i - 1] * base % mod;
    
      int T;
      scanf("%d", &T);
      while (T--) {
        scanf("%s", s + 1);
        n = strlen(s + 1);
        ans = 0;
        tr.reset();
    
        memset(hs, 0, sizeof(hs));
        memset(pre, 0, sizeof(pre));
        memset(suf, 0, sizeof(suf));
        memset(buf, 0, sizeof(buf));
    
        for (int i = 1; i <= n; i++)
          hs[i] = (hs[i - 1] * base + s[i]) % mod;
    
        pre[0] = suf[n + 1] = 0;
        for (int i = 1, v = 0; i <= n; i++) {
          int nv = v ^ (1 << (s[i] - 'a'));
          if (nv > v) pre[i] = pre[i - 1] + 1;
          else pre[i] = pre[i - 1] - 1;
          v = nv;
        }
        for (int i = n, v = 0; i >= 1; i--) {
          int nv = v ^ (1 << (s[i] - 'a'));
          if (nv > v) suf[i] = suf[i + 1] + 1;
          else suf[i] = suf[i + 1] - 1;
          v = nv;
        }
        
        for (int x = 2; x < n; x++)
          buf[x - 1] = pw[x] - 1;
        inversion::process(n - 2);
        tr.add(pre[1]);
        for (int x = 2; x < n; x++) {
          int k = 0, maxb = log2(n / x);
          ull cst = inv[x - 1] * hs[x] % mod;
          ull fix = 0;
    
          ull tpw[maxb + 1];
          tpw[0] = pw[x];
          for (int j = 1; j <= maxb; j++)
            tpw[j] = tpw[j - 1] * tpw[j - 1] % mod;
    
          for (int j = maxb; j >= 0; j--) {
            ull cur = ((tpw[j] - 1) * cst % mod * pw[k * x] % mod + fix) % mod;
            if (cur == hs[x * (k + (1 << j))])
              k += (1 << j), fix = cur;
          }
          
          if (x * k == n) --k;
          ans += tr.get(suf[x + 1]) * ((k + 1) / 2);
          if (k > 1) ans += tr.get(suf[x * 2 + 1]) * (k / 2);
          tr.add(pre[x]);
        }
    
        printf("%lld
    ", ans);
      }
      return 0;
    }
    

    本文来自博客园,作者:-Wallace-,转载请注明原文链接:https://www.cnblogs.com/-Wallace-/p/uoj581.html

  • 相关阅读:
    SQL 触发器[1]
    SQL 存储过程[1]-常用参数及示例
    前端软件开发体系
    人工智能AI Boosting HMC Memory Chip
    先进一站式IP及定制
    BTC芯片介绍
    ONNX MLIR方法
    MLIR中间表示和编译器框架
    Non-Maximum Suppression,NMS非极大值抑制
    华为计算平台MDC810发布量产
  • 原文地址:https://www.cnblogs.com/-Wallace-/p/uoj581.html
Copyright © 2011-2022 走看看