zoukankan      html  css  js  c++  java
  • FFT/NTT字符串模糊匹配

    因为FFT精度问题太离谱了,所以墙裂推荐用NTT
    首先考虑精确匹配:https://www.acwing.com/problem/content/833/
    假设我们有短串(s1)(长度为(n)),长串(s2)(长度为(m)
    我们定义字符差

    [c(x,y) = s1(x) - s2(y) ]

    (c(x,y) = 0),表明(s1)的第(x)个字符与(s2)的第(y)个字符匹配,再定义

    [F(x) = sum_{i = 0}^{n - 1}c(i,x-n+i+1) ]

    (s2)子串的字符差之和,这个子串长为(n)并且以下标(x)为结尾,若(F(x) = 0),则表明这个子串与(s1)完全匹配,但这样可能会将(ab)(ba)算为完全匹配,因此我们考虑将(F(x))换个表达式

    [F(x) = sum_{i = 0}^{n - 1}[s1(i)-s2(x-n+i+1)]^{2} ]

    这样若(F(x) = 0),则表明这个子串与之完全匹配,将其暴力拆解

    [F(x) =sum_{i = 0}^{n - 1}s1(i)^2+sum_{i = 0}^{n - 1}s2(x-n+i+1)^2-sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1) ]

    其中(sum_{i = 0}^{n - 1}s1(i)^2)(sum_{i = 0}^{n - 1}s2(x-n+i+1)^2)都可以用前缀和解决,关键是(sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1)),我们将(s1)翻转,可得(s1'(x-n+i+1)=s1(i)),即

    [sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1)=sum_{i = 0}^{n - 1}2s1'(n-i-1)s2(x-n+i+1)=sum_{i+j=x}^{}s1'(i)s2(j) ]

    可以发现能用NTT啦!因此

    [F(x) = sum - S(x) + S(x-n) - 2sum_{i+j=x}^{}s1'(i)s2(j) ]

    (F(x)=0)时,表明完全匹配

    AC代码:
    不开O2会T

    #include <unordered_map>
    #include <algorithm>
    #include <iostream>
    #include <cstring>
    #include <cstdio>
    #include <vector>
    #include <string>
    #include <stack>
    #include <deque>
    #include <queue>
    #include <cmath>
    #include <map>
    #include <set>
    using namespace std;
    #pragma GCC optimize(2)
    typedef pair<int, int> PII;
    typedef pair<double, int> PDI;
    //typedef __int128 int128;
    typedef long long ll;
    typedef unsigned long long ull;
    const int INF = 0x3f3f3f3f;
    const int N = 1e7 + 10, M = 4e7 + 10;
    const int base = 1e9;
    const int P = 131;
    const int mod = 998244353;
    const double eps = 1e-9;
    const double PI = acos(-1.0);
    int n, m, tot, bit;
    char s1[N], s2[N];
    ll S[N], a[N], b[N];
    int R[N];
    ll ksm(ll a, ll b)
    {
        ll res = 1 % mod;
        while (b)
        {
            if (b & 1)
                res = res * a % mod;
            a = a * a % mod;
            b >>= 1;
        }
        return res;
    }
    void inif(int n)
    {
        tot = 1, bit = 0;
        while (tot <= n)
            tot *= 2, ++bit;
        for (int i = 0; i <= tot; ++i)
            R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }
    void NTT(ll f[], int total, int type)
    {
        for (int i = 0; i < total; ++i)
            if (i < R[i])
                swap(f[i], f[R[i]]);
        for (int tot = 2; tot <= total; tot *= 2)
        {
            ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / tot);
            //332748118为 3 在模 998244353 的逆元
            for (int pos = 0; pos < total; pos += tot)
            {
                ll w = 1;
                for (int i = pos; i < pos + tot / 2; ++i, w = w * w1 % mod)
                {
                    int x = f[i];
                    int y = w * f[i + tot / 2] % mod;
                    f[i] = (x + y) % mod;
                    f[i + tot / 2] = (x - y + mod) % mod;
                }
            }
        }
        if (type == -1)
        {
            int inv = ksm(tot, mod - 2);
            for (int i = 0; i <= n + m; ++i)
                a[i] = a[i] * inv % mod;
        }
    }
    
    int main()
    {
        scanf("%d%s%d%s", &n, &s1, &m, &s2);
        for (int i = 0; i < n; ++i)
            a[i] = s1[i] - 'a' + 1;
        for (int i = 0; i < m; ++i)
            b[i] = s2[i] - 'a' + 1;
        reverse(a, a + n);
        ll sum = 0;
        for (int i = 0; i < n; ++i)
            sum = (sum + a[i] * a[i] % mod) % mod;
        S[0] = b[0] * b[0];
        for (int i = 1; i < m; ++i)
            S[i] = (S[i - 1] + b[i] * b[i] % mod) % mod;
        inif(n + m);
        NTT(a, tot, 1), NTT(b, tot, 1);
        for (int i = 0; i < tot; ++i)
            a[i] = a[i] * b[i] % mod;
        NTT(a, tot, -1);
        for (int x = n - 1; x < m; ++x)
        {
            double P = (sum + S[x] - S[x - n] - 2 * a[x]) % mod;
            if (P == 0)
                printf("%d ", x - n + 1);
        }
        return 0;
    }
    

    接着我们考虑模糊匹配,即有通配符的情况:https://www.luogu.com.cn/problem/P4173
    设通配符的值为0,重新定义字符差

    [c(x,y) = [s1(x) - s2(y)]^2s1(x)s2(y) ]

    发现会完美解决问题,依然暴力拆解

    [F(x) = sum_{i = 0}^{n - 1}[s1(i)-s2(x-n+i+1)]^{2}s1(i)s2(x-n+i+1)\ =[sum_{i = 0}^{n - 1}s1(i)^2+sum_{i = 0}^{n - 1}s2(x-n+i+1)^2-sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1)]s1(i)s2(x-n+i+1)\ =sum_{i = 0}^{n - 1}s1(i)^3s2(x-n+i+1)+sum_{i = 0}^{n - 1}s1(i)s2(x-n+i+1)^3-sum_{i = 0}^{n - 1}2s1(i)^2s2(x-n+i+1)^2\ =sum_{i+j=x}^{}s1'(i)^3s2(j)+sum_{i+j=x}^{}s1'(i)s2(j)^3+sum_{i+j=x}^{}s1'(i)^2s2(j)^2 ]

    (F(x)=0)时,表明完全匹配

    AC代码:

    #include <unordered_map>
    #include <algorithm>
    #include <iostream>
    #include <cstring>
    #include <cstdio>
    #include <vector>
    #include <string>
    #include <stack>
    #include <deque>
    #include <queue>
    #include <cmath>
    #include <map>
    #include <set>
    using namespace std;
    typedef pair<int, int> PII;
    typedef pair<double, int> PDI;
    //typedef __int128 int128;
    typedef long long ll;
    typedef unsigned long long ull;
    const int INF = 0x3f3f3f3f;
    const int N = 1e7 + 10, M = 4e7 + 10;
    const int base = 1e9;
    const int P = 131;
    const int mod = 998244353;
    const double eps = 1e-9;
    const double PI = acos(-1.0);
    int n, m;
    int A[N], B[N];
    char s1[N], s2[N];
    int R[N], ans[N];
    int tot, bit, pos;
    ll a[N], b[N], p[N];
    ll ksm(ll a, ll b)
    {
    	ll res = 1 % mod;
    	while (b)
    	{
    		if (b & 1)
    			res = res * a % mod;
    		a = a * a % mod;
    		b >>= 1;
    	}
    	return res;
    }
    void inif(int n)
    {
    	tot = 1, bit = 0;
    	while (tot <= n)
    		tot *= 2, ++bit;
    	for (int i = 0; i <= tot; ++i)
    		R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }
    void NTT(ll f[], int total, int type)
    {
    	for (int i = 0; i < total; ++i)
    		if (i < R[i])
    			swap(f[i], f[R[i]]);
    	for (int tot = 2; tot <= total; tot *= 2)
    	{
    		ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / tot);
    		//332748118为 3 在模 998244353 的逆元
    		for (int pos = 0; pos < total; pos += tot)
    		{
    			ll w = 1;
    			for (int i = pos; i < pos + tot / 2; ++i, w = w * w1 % mod)
    			{
    				int x = f[i];
    				int y = w * f[i + tot / 2] % mod;
    				f[i] = (x + y) % mod;
    				f[i + tot / 2] = (x - y + mod) % mod;
    			}
    		}
    	}
    	if (type == -1)
    	{
    		int inv = ksm(tot, mod - 2);
    		for (int i = 0; i <= n + m; ++i)
    			a[i] = a[i] * inv % mod;
    	}
    }
    int main()
    {
    	scanf("%d%d%s%s", &n, &m, &s1, &s2);
    	reverse(s1, s1 + n);
    	for (int i = 0; i < n; ++i)
    		A[i] = s1[i] == '*' ? 0 : s1[i] - 'a' + 1;
    	for (int i = 0; i < m; ++i)
    		B[i] = s2[i] == '*' ? 0 : s2[i] - 'a' + 1;
    	inif(n + m);
    	//A[i]^3 B[i]
    	for (int i = 0; i < tot; ++i)
    		a[i] = A[i] * A[i] * A[i];
    	for (int i = 0; i < tot; ++i)
    		b[i] = B[i];
    	NTT(a, tot, 1), NTT(b, tot, 1);
    	for (int i = 0; i < tot; ++i)
    		p[i] = (p[i] + a[i] * b[i]) % mod;
    	//A[i] B[i]^3
    	for (int i = 0; i < tot; ++i)
    		a[i] = A[i];
    	for (int i = 0; i < tot; ++i)
    		b[i] = B[i] * B[i] * B[i];
    	NTT(a, tot, 1), NTT(b, tot, 1);
    	for (int i = 0; i < tot; ++i)
    		p[i] = (p[i] + a[i] * b[i]) % mod;
    	//A[i]^2 B[i]^2
    	for (int i = 0; i < tot; ++i)
    		a[i] = A[i] * A[i];
    	for (int i = 0; i < tot; ++i)
    		b[i] = B[i] * B[i];
    	NTT(a, tot, 1), NTT(b, tot, 1);
    	for (int i = 0; i < tot; ++i)
    		p[i] = (p[i] - 2 * a[i] * b[i] + mod) % mod;
    
    	NTT(p, tot, -1);
    	for (int i = n - 1; i < m; ++i)
    		if (p[i] == 0)
    			ans[++pos] = i - n + 2;
    
    	printf("%d
    ", pos);
    	for (int i = 1; i <= pos; ++i)
    		printf("%d ", ans[i]);
    	return 0;
    }
    

    然后是杭电多校让我知道了这个知识点
    HDU6975:https://acm.hdu.edu.cn/showproblem.php?pid=6975
    因为字符只包含0-9和,首先不考虑通配符,我们可以枚举0-9,将每个子串在0-9情况下的匹配数算出来,以8为例,将所有为8的地方值设为1,其他地方值设为0,则对单个字符的匹配数有

    [F(x)=sum_{i=0}^{n-1}s1(i)s2(x-n+1+i)=sum_{i=0}^{n-1}s1(n-i-1)s2(x-n+i+1)=sum_{i+j=x}s1(i)s2(j) ]

    求出每个子串的匹配数后就可以考虑通配符了,其实通配符匹配数=(s1)通配符数+(s2)子串通配符数-(s1)(s2)子串相同位置的通配符数,前缀和加卷积即可求出

    AC代码:

    #include <unordered_map>
    #include <algorithm>
    #include <iostream>
    #include <cstring>
    #include <cstdio>
    #include <vector>
    #include <string>
    #include <stack>
    #include <deque>
    #include <queue>
    #include <cmath>
    #include <map>
    #include <set>
    using namespace std;
    typedef pair<int, int> PII;
    typedef pair<double, int> PDI;
    //typedef __int128 int128;
    typedef long long ll;
    typedef unsigned long long ull;
    const int INF = 0x3f3f3f3f;
    const int N = 1e6 + 10, M = 4e7 + 10;
    const int base = 1e9;
    const int P = 131;
    const int mod = 998244353;
    const double eps = 1e-9;
    const double PI = acos(-1.0);
    FILE *fp;
    int n, m, tot, bit;
    char s1[N], s2[N];
    int R[N], ans[N];
    ll a[N], b[N], f[N], S[N];
    ll ksm(ll a, ll b)
    {
    	ll res = 1 % mod;
    	while (b)
    	{
    		if (b & 1)
    			res = res * a % mod;
    		a = a * a % mod;
    		b >>= 1;
    	}
    	return res;
    }
    void inif(int n)
    {
    	memset(s1, 0, sizeof(s1));
    	memset(s2, 0, sizeof(s2));
    	memset(ans, 0, sizeof(ans));
    	memset(f, 0, sizeof(f));
    	tot = 1, bit = 0;
    	while (tot <= n)
    		tot *= 2, ++bit;
    	for (int i = 0; i <= tot; ++i)
    		R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }
    void NTT(ll f[], int total, int type)
    {
    	for (int i = 0; i < total; ++i)
    		if (i < R[i])
    			swap(f[i], f[R[i]]);
    	for (int tot = 2; tot <= total; tot *= 2)
    	{
    		ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / tot);
    		//332748118? 3 ?? 998244353 ???
    		for (int pos = 0; pos < total; pos += tot)
    		{
    			ll w = 1;
    			for (int i = pos; i < pos + tot / 2; ++i, w = w * w1 % mod)
    			{
    				int x = f[i];
    				int y = w * f[i + tot / 2] % mod;
    				f[i] = (x + y) % mod;
    				f[i + tot / 2] = (x - y + mod) % mod;
    			}
    		}
    	}
    	if (type == -1)
    	{
    		int inv = ksm(tot, mod - 2);
    		for (int i = 0; i <= n + m; ++i)
    			f[i] = f[i] * inv % mod;
    	}
    }
    void get(char c, int type)
    {
    	for (int i = 0; i < tot; ++i)
    		a[i] = s1[i] == c;
    	for (int i = 0; i < tot; ++i)
    		b[i] = s2[i] == c;
    	NTT(a, tot, 1), NTT(b, tot, 1);
    	for (int i = 0; i < tot; ++i)
    	{
    		if (type == 1)
    			f[i] = (f[i] + a[i] * b[i] % mod) % mod;
    		else
    			f[i] = (f[i] - a[i] * b[i] % mod + mod) % mod;
    	}
    }
    int main()
    {
    	int T;
    	scanf("%d", &T);
    	while (T--)
    	{
    		scanf("%d%d", &m, &n);
    		inif(n + m);
    		scanf("%s%s", s2, s1);
    		reverse(s1, s1 + n);
    
    		for (char c = '0'; c <= '9'; ++c)
    			get(c, 1);
    		get('*', -1);
    		NTT(f, tot, -1);
    		ll sum = 0;
    		for (int i = 0; i < n; ++i)
    			sum += s1[i] == '*';
    		S[0] = s2[0] == '*';
    		for (int i = 1; i < m; ++i)
    			S[i] = (S[i - 1] + (s2[i] == '*')) % mod;
    		for (int i = 0; i < tot; ++i)
    		{
    			if (i >= n)
    				f[i] = (f[i] + sum + S[i] - S[i - n] + mod) % mod;
    			else
    				f[i] = (f[i] + sum + S[i]) % mod;
    		}
    		for (int i = n - 1; i < m; ++i)
    			++ans[n - f[i]];
    		for (int i = 0; i <= n; ++i)
    		{
    			if (i)
    				ans[i] += ans[i - 1];
    			printf("%d
    ", ans[i]);
    		}
    	}
    	return 0;
    }
    
  • 相关阅读:
    小谢第18问:如何让element-ui的弹出框每次显示的时候初始化,重新加载元素?
    小谢第7问:js前端如何实现大文件分片上传、上传进度、终止上传以及删除服务器文件?
    小谢第36问:elemet
    小谢第35问:已经 git commit 的代码怎么回退到本地
    小谢第34问:vue中路由传参params 和 query区别
    小谢第33问:获取对象所有的属性值
    小谢第32问:git 可视化管理工具
    小谢第31问:git拉取所有分支
    小谢第30问:get拼接字符串常用接口含义
    小谢第29问:Vue项目打包部署到服务器上,调接口就报js,css 文件404
  • 原文地址:https://www.cnblogs.com/xiaopangpangdehome/p/15080759.html
Copyright © 2011-2022 走看看