题目链接:https://nanti.jisuanke.com/t/42400
题意:给一个模式串集合,序号1~n,有T个操作,或者是交换序号,或者是查询模式串集合中序号在L到R之 间的字符串有多少个和目标串公共前缀长度大于等于K。
—————————————————————————————————
对模式串集合建字典树,则与目标串(LCP)大于等K的字符串,都在以目标串第(K)个字符所在的结点为根的子树中,每个模式串插入字典树的过程中记录序号,修改时交换序号,询问就相于在子树中询问有多少结点的序号在(L)到(R)之间。
涉及到子树,容易想到建立(DFS)序,再用数据结构维护
但是不建议在字典树的(DFS)序上用数据结构维护,因为可能出现两个一样的模式串,但是序号不一样,可能会覆盖之前的序号(个人观点,我写的时候是遇到了这种问题,可能写法问题)
所以可以用一个数组(h[i]),表示序号为(i)的字符串在字典树中的编号为(h[i]),对(h)使用数据结构维护,需要支持单点修改(交换视为两次单点修改)和查询区间([L,R])中(h[i])的值在范围([in[x],out[x]])中的数量((x)为目标串第(K)字符在字典树中的结点编号)
#include <bits/stdc++.h>
using namespace std;
const int maxn = 200010;
int ptov[maxn], tv[maxn];
/************************************************************/
struct Trie
{
int trie[maxn][26], tot = 0;
int ins(char s[])
{
int len = strlen(s);
int root = 0;
for (int i = 0; i < len; i++)
{
int id = s[i] - 'a';
if (trie[root][id] == 0)
trie[root][id] = ++tot;
root = trie[root][id];
}
return root;
}
int fin(int k, char s[])
{
int root = 0;
for (int i = 0; i < k; i++)
{
int id = s[i] - 'a';
if (trie[root][id] == 0)
return -1;
root = trie[root][id];
}
return root;
}
} t;
/************************************************************/
int in[maxn], out[maxn];
int tim = 0;
void dfs(int x)
{
in[x] = ++tim;
for (int i = 0; i < 26; i++)
{
if (t.trie[x][i] == 0)
continue;
dfs(t.trie[x][i]);
}
out[x] = tim;
}
/************************************************************/
int T[maxn], S[maxn], L[maxn * 150], R[maxn * 150], sum[maxn * 150];
int h[maxn];
int ul[maxn], ur[maxn];
int tot, num, n, m;
struct node
{
int l, r, k;
bool flag;
} Q[maxn << 1];
void build(int &rt, int l, int r)
{
rt = ++tot;
sum[rt] = 0;
if (l == r)
return;
int mid = (l + r) >> 1;
build(L[rt], l, mid);
build(R[rt], mid + 1, r);
}
void update(int &rt, int pre, int l, int r, int x, int val)
{
rt = ++tot;
L[rt] = L[pre];
R[rt] = R[pre];
sum[rt] = sum[pre] + val;
if (l == r)
return;
int mid = (l + r) >> 1;
if (x <= mid)
update(L[rt], L[pre], l, mid, x, val);
else
update(R[rt], R[pre], mid + 1, r, x, val);
}
int lowbit(int x) { return x & (-x); }
void add(int x, int val)
{
int res = lower_bound(h + 1, h + 1 + num, ptov[x]) - h;
while (x <= n)
{
update(S[x], S[x], 1, num, res, val);
x += lowbit(x);
}
}
int Sum(int x, int flag)
{
int res = 0;
while (x > 0)
{
if (flag == 1)
res += sum[L[ur[x]]];
else if (flag == 2)
res += sum[L[ul[x]]];
else if (flag == 3)
res += sum[ur[x]];
else
res += sum[ul[x]];
x -= lowbit(x);
}
return res;
}
int query(int s, int e, int ts, int te, int l, int r, int k)
{
if (l == r)
return Sum(e, 3) - Sum(s, 4) + sum[te] - sum[ts];
int mid = (l + r) >> 1;
int res = Sum(e, 1) - Sum(s, 2) + sum[L[te]] - sum[L[ts]];
if (k <= mid)
{
for (int i = e; i; i -= lowbit(i))
ur[i] = L[ur[i]];
for (int i = s; i; i -= lowbit(i))
ul[i] = L[ul[i]];
return query(s, e, L[ts], L[te], l, mid, k);
}
else
{
for (int i = e; i; i -= lowbit(i))
ur[i] = R[ur[i]];
for (int i = s; i; i -= lowbit(i))
ul[i] = R[ul[i]];
return res + query(s, e, R[ts], R[te], mid + 1, r, k);
}
}
/************************************************************/
char s[maxn];
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%s", s);
ptov[i] = t.ins(s);
}
dfs(0);
for (int i = 1; i <= n; i++)
{
ptov[i] = in[ptov[i]];
h[i] = ptov[i];
tv[i] = ptov[i];
}
int q = 0;
for (int i = 1; i <= m; i++)
{
int opt;
scanf("%d", &opt);
if (opt == 2) //询问
{
scanf("%s%d%d%d", s, &Q[q].k, &Q[q].l, &Q[q].r);
Q[q].k = t.fin(Q[q].k, s);
Q[q].flag = true;
q++;
}
else //修改
{
int tl, tr;
scanf("%d%d", &tl, &tr);
Q[q].l = tl;
Q[q].r = tv[tr];
Q[q].flag = false;
q++;
Q[q].l = tr;
Q[q].r = tv[tl];
Q[q].flag = false;
q++;
swap(tv[tl], tv[tr]);
}
}
sort(h + 1, h + 1 + n);
num = unique(h + 1, h + 1 + n) - h - 1;
tot = 0;
build(T[0], 1, num);
for (int i = 1; i <= n; i++)
update(T[i], T[i - 1], 1, num, lower_bound(h + 1, h + 1 + num, ptov[i]) - h, 1);
for (int i = 1; i <= n; i++)
S[i] = T[0];
for (int i = 0; i < q; i++)
{
if (Q[i].flag)
{
if (Q[i].k == -1)
puts("0");
else if (Q[i].k == 0)
printf("%d
", Q[i].r - Q[i].l + 1);
else
{
int tin = lower_bound(h + 1, h + 1 + num, in[Q[i].k]) - h;
auto tout = lower_bound(h + 1, h + 1 + num, out[Q[i].k]);
for (int j = Q[i].r; j; j -= lowbit(j))
ur[j] = S[j];
for (int j = Q[i].l - 1; j; j -= lowbit(j))
ul[j] = S[j];
int ans1 = query(Q[i].l - 1, Q[i].r, T[Q[i].l - 1], T[Q[i].r], 1, num, tout - h);
if (tin == 1)
{
if (*tout == out[Q[i].k])
printf("%d
", ans1);
else
printf("%d
", ans1 - 1);
continue;
}
for (int j = Q[i].r; j; j -= lowbit(j))
ur[j] = S[j];
for (int j = Q[i].l - 1; j; j -= lowbit(j))
ul[j] = S[j];
int ans2 = query(Q[i].l - 1, Q[i].r, T[Q[i].l - 1], T[Q[i].r], 1, num, tin - 1);
if (*tout == out[Q[i].k])
printf("%d
", ans1 - ans2);
else
printf("%d
", ans1 - 1 - ans2);
}
}
else
{
add(Q[i].l, -1);
ptov[Q[i].l] = Q[i].r;
add(Q[i].l, 1);
}
}
return 0;
}