题目大意:
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
题解:
为了解决这个问题,首先我们需要掌握后缀自动机的两个性质:
- 每个串s代表的串的长度是区间((len_{fa},len_s])
- 每个状态代表的所有串在原串中的出现次数及每次出现的右端点相同
所以我们知道,如果一个串可以表达s代表的所有串,那么匹配次数即为((len_s - len_{fa})*|right_s|)
所以还需要记录一下right集合的大小.
我们设(f(i))表示如果包含了s表达的所有串,那么匹配次数增加多少.
显然有(f(i) = (len_s - len_{fa})*|right_s|)可以(O(n))求出
然后我们又发现:如果我们完全包括了一个状态,那么我们也一定会包括其(parent)树中的所有祖先状态.
所以我们也应该加上祖先状态的f值,所以我们对(f(i))求一个前缀和
即(f(i) = f(fa_i) + (len_s - len_{fa})*|right_s|)
然后我们把串扔上去跑就好了啊
还有一个问题!!!
如果我们没有完全匹配到一个状态,就不能直接应用f的值
记当前匹配的长度为L,如果出现(L in (len_{fa},len_s))我们就不能使用f
因为我们没有包含长度为((L,len_s))的串
所以我们手动计算一下((L - len_{fa})*|right_s|)累加到答案
然后加上(fa)的f即可..
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
inline void read(ll &x){
x=0;char ch;bool flag = false;
while(ch=getchar(),ch<'!');if(ch == '-') ch=getchar(),flag = true;
while(x=10*x+ch-'0',ch=getchar(),ch>'!');if(flag) x=-x;
}
const ll maxn = 200010;
struct Node{
ll nx[26];
ll len,fa,cnt,f;
}T[maxn<<1];
ll last,nodecnt;
inline void init(){
last = nodecnt = 0;
T[0].len = T[0].cnt = T[0].f = 0;
T[0].fa = -1;
}
void insert(ll c){
ll cur = ++ nodecnt,p;
T[cur].len = T[last].len + 1;T[cur].cnt = 1;
for(p = last;p != -1 && !T[p].nx[c];p = T[p].fa) T[p].nx[c] = cur;
if(p == -1) T[cur].fa = 0;
else{
ll q = T[p].nx[c];
if(T[q].len == T[p].len + 1) T[cur].fa = q;
else{
ll co = ++ nodecnt;
T[co].len = T[p].len + 1;T[co].fa = T[q].fa;
for(ll k=0;k<26;++k) T[co].nx[k] = T[q].nx[k];
for(;p != -1 && T[p].nx[c] == q;p = T[p].fa) T[p].nx[c] = co;
T[cur].fa = T[q].fa = co;
}
}last = cur;
}
ll q[maxn<<1],sum[maxn<<1];
inline void Sort(){
for (ll i=0;i<=nodecnt;i++) sum[T[i].len]++;
for (ll i=1;i<=T[last].len;i++) sum[i]+=sum[i-1];
for (ll i=nodecnt;i>=0;--i) q[sum[T[i].len]--]=i;
}
inline void build(char *s){
init();ll len = strlen(s);
for(ll i=0;i<len;++i) insert(s[i]-'a');
}
char s[maxn];
int main(){
scanf("%s",s);build(s);Sort();
for(ll i=nodecnt+1;i>1;--i) T[T[q[i]].fa].cnt += T[q[i]].cnt;
for(ll i=2,x;i<=nodecnt+1;++i){
x = q[i];
T[x].f = T[T[x].fa].f + T[x].cnt*(T[x].len - T[T[x].fa].len);
}scanf("%s",s);
ll len = strlen(s),ans = 0;
for(ll i=0,p=0,l=0,c;i<len;++i){
c = s[i] - 'a';
while(p && !T[p].nx[c]) p = T[p].fa,l = T[p].len;
if(T[p].nx[c]) p = T[p].nx[c],l++;
if(p) ans += T[T[p].fa].f + T[p].cnt*(l - T[T[p].fa].len);
}printf("%lld
",ans);
getchar();getchar();
return 0;
}
问题的拓展 !!!
假如我们要求求本质不同的相同字串的个数呢?
解答:
首先我们知道肯定不能直接套用上述算法:
如:
input:
aa
aa
答案应该是2,但是上述算法会给出答案3
这时候就需要我们从每个节点表示的串的长度范围入手.
记录一下每一个状态的最大匹配长度,然后(ams = sum(maxlen_i - len_{fa}))即可.
见COGS 2610. [HZOI 2015]找相同子串V2
学长出的题呢 ... ...