Reincarnation
[Time Limit: 3000 msquad Memory Limit: 65536 kB
]
题意
给出一个字符串 (S),然后给出 (m) 次查询,每次给出一个 ([l, r]) 区间,每次求出这个区间内有多少个不同的子串。
思路
首先知道给出一个字符串,如何求出其不同子串的个数,有三种方法。
- 利用后缀数组,(sum_{i=2}^{n} len-sa[i]+1-height[i]) 就是答案
- 利用后缀自动机,(dp[i]) 表示第 (i) 个节点拥有的字符串个数,可以让 (v in u'next),(dp[u] += dp[v]),最后的 (dp[1]-1) 就是答案。
- 利用后缀自动机,(sum_{i=2}^{sz} node[i].len - node[father].len) 就是答案。
这里用到了第三种方法,因为字符串的长度很小,所以直接 (O(N^{2})) 暴力预处理出两两之间的字符串长度并保存起来,最后 (O(1)) 查询。
#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define lowbit(x) x & (-x)
#define mes(a, b) memset(a, b, sizeof a)
#define fi first
#define se second
#define pii pair<int, int>
#define INOPEN freopen("in.txt", "r", stdin)
#define OUTOPEN freopen("out.txt", "w", stdout)
typedef unsigned long long int ull;
typedef long long int ll;
const int maxn = 2e3 + 10;
const int maxm = 1e5 + 10;
const ll mod = 1e9 + 7;
const ll INF = 1e18 + 100;
const int inf = 0x3f3f3f3f;
const double pi = acos(-1.0);
const double eps = 1e-8;
using namespace std;
int n, m;
int cas, tol, T;
struct SAM {
struct Node {
int next[27];
int fa, len;
void init() {
mes(next, 0);
len = fa = 0;
}
} node[maxn<<1];
int last, sz;
void init() {
last = sz = 1;
node[sz].init();
}
void insert(int k) {
int p = last, np = last = ++sz;
node[np].init();
node[np].len = node[p].len + 1;
for(; p&&!node[p].next[k]; p=node[p].fa)
node[p].next[k] = np;
if(p == 0) {
node[np].fa = 1;
} else {
int q = node[p].next[k];
if(node[q].len == node[p].len+1) {
node[np].fa = q;
} else {
int nq = ++sz;
node[nq] = node[q];
node[nq].len = node[p].len+1;
node[np].fa = node[q].fa = nq;
for(; p&&node[p].next[k]==q; p=node[p].fa)
node[p].next[k] = nq;
}
}
}
int calc() {
return node[last].len - node[node[last].fa].len;
}
} sam;
char s[maxn];
int dp[maxn][maxn];
int main() {
scanf("%d", &T);
while(T--) {
scanf("%s", s+1);
int len = strlen(s+1);
mes(dp, 0);
for(int i=1; i<=len; i++) {
sam.init();
for(int j=i; j<=len; j++) {
sam.insert(s[j]-'a'+1);
dp[i][j] = dp[i][j-1] + sam.calc();
}
}
scanf("%d", &n);
for(int i=1, l, r; i<=n; i++) {
scanf("%d%d", &l, &r);
printf("%d
", dp[l][r]);
}
}
return 0;
}