诸神眷顾的幻想乡
题目链接:luogu P3346
题目大意
给你一棵树,且度数为 1 的点不超过 20 个。
然后每个点有一个颜色,然后问你这个树上有多少个不同的颜色序列。
思路
考虑颜色很像字符什么的,不难想到 SA / SAM 什么的。
然后你会想到 SAM 可以支持插入一个 Trie 树。
但你会发现这里并没有确定什么是根。
那你再看会发现度数为 (1) 的点不超过 (20) 个。
这提示我们可以直接暴力枚举这些点,建出 Trie 树插入进 SAM 里。
因为你只是要求不用的个数而不是求相同的个数,所以不会影响计算。
然后其实是不用真的建 Trie 树的,你只要从那个点开始 dfs 过程中不断插入即可。
然后这里讲讲 SAM 如何插入一个 Trie 树。
其实就是 dfs,每个点插入了之后的 (lst) 每个点都记录一个,然后插入一个点 (x) 的儿子的时候,用的 (lst) 就是插入好点 (x) 的 (lst)。
然后再广义 SAM 中,你可能要插入的时候,发现已经有这个儿子了,那你是没有必要新建点的(新建也行),你只需要走过去就可以了。
但是后面判断是否要新建那个复制点你还是要判断的。
代码
#include<queue>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
struct node {
int to, nxt;
}e[200001];
int n, c, a[100001], x, y, du[100001];
int le[100001], KK, lst[100001], tot;
bool in[100001];
queue <int> q;
ll ans;
struct SAM {
int fa, len, son[10], sz;
}d[4000001];
void add(int x, int y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
e[++KK] = (node){x, le[y]}; le[y] = KK;
}
int insert(int pl, int x) {//SAM 插入
if (d[pl].son[x]) {//因为是 Trie 树,所以如果之前有这个儿子就走下去,不用新建点(但后面的新建还是要考虑的)
int p = pl;
int q = d[p].son[x];
if (d[q].len == d[p].len + 1) return q;
else {
int nq = ++tot;
d[nq] = d[q];
d[nq].sz = 0;
d[nq].len = d[p].len + 1;
d[q].fa = nq;
for (; p && d[p].son[x] == q; p = d[p].fa)
d[p].son[x] = nq;
return nq;
}
}
int p = pl;
int np = ++tot;
d[np].len = d[p].len + 1;
d[np].sz = 1;
for (; p && !d[p].son[x]; p = d[p].fa)
d[p].son[x] = np;
if (!p) d[np].fa = 1;
else {
int q = d[p].son[x];
if (d[q].len == d[p].len + 1) d[np].fa = q;
else {
int nq = ++tot;
d[nq] = d[q];
d[nq].sz = 0;
d[nq].len = d[p].len + 1;
d[q].fa = d[np].fa = nq;
for (; p && d[p].son[x] == q; p = d[p].fa)
d[p].son[x] = nq;
}
}
return np;
}
void work(int s) {//不用真的建 Trie,直接 dfs 一遍就可以
lst[0] = 1;
lst[s] = insert(lst[0], a[s]);
q.push(s);
while(!q.empty()) {
int now = q.front();
q.pop();
in[now] = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (!in[e[i].to]) {
lst[e[i].to] = insert(lst[now], a[e[i].to]);
q.push(e[i].to);
}
}
}
int main() {
scanf("%d %d", &n, &c);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
du[x]++; du[y]++;
add(x, y);
}
tot = 1;
for (int i = 1; i <= n; i++)
if (du[i] == 1) {
work(i);
memset(in, 0, sizeof(in));
}
for (int i = 1; i <= tot; i++) {
ans += 1ll * (d[i].len - d[d[i].fa].len);
}
printf("%lld", ans);
return 0;
}