暴力部分
首先 (n^3) 的暴力很好想到,枚举起点,枚举终点,把串放 SAM 上随便匹配即可。DFS优化可以将复杂度优化成 (O(n^2))(没有人不会 (n^2) 的做法吧)
另一种做法
发现这是暴力的极限了。由于这是树上路径问题,我们考虑点分治。我们可以把路径 a --> b --> c 在 (b) (分治重心)处统计。直接把 a --> b 和 b --> c 的所有路径搞出来然后匹配即可。
发现复杂度仍然无法接受。不易发现,((a,b),(b,c)) 能够匹配,当且仅当它们都在 (s) 中出现过,并且能够在 (s) 上某些地方的 (b) 中匹配上,贡献为这种 (b) 的位置数量。一种方法是,我们可以先将 endpos 是一个特定的集合的情况的次数先搞出来,并且将反串 endpos 是一个特定的集合的情况的次数也搞出来,然后把贡献一直推到各个 endpos 里,最后 endpos 对应相乘后求和即可。但是这样可能会计重(俩链都出现在同一个子树中),容斥一下就好。
不过这要求我们前端加字符,求新的具有新的 endpos 集合的节点。我们发现 Parent Tree 上的父子关系实际上就是在父亲的“串”的前面加了字符。因此我们可以直接在 Parent Tree 上建出转移路径即可。当我们发现当前节点还能够容纳新的 endpos 集合时,就留在当前节点;否则按照转移路径转移。这个判断可以通过比较 (nwlen) 和 (len[np]) 来实现。
仔细思考一下,我们发现这种做法是 (O(n+m)) 的,其中 (n) 是当前分治区域的大小。加上 (m) 是因为我们需要在两个 SAM 上向下推一遍,推到各个 endpos 里。因此我们不能让这种方法进行太多次。
正解
根号分治。如果当前分治块大小大于 (sqrt{n}),就用第二种方法,否则用第一种方法。
复杂度:(O((n+m)sqrt n))。
证明。假设 (n=2^k),那么分治最多有 (k) 层,每层大小为 (n),一共有 (2n) 个“叶子”(共 (2n) 次 (calc))。如果我们在点分树的下一半的最上面一层用 (n^2) 暴力,复杂度为 (O((sqrt n)^3))(每次 (O((sqrt n)^2)),最多 (O(sqrt n)) 次)。对于上面的部分,它的节点数为 (O(2^{k/2})),即 (O(sqrt n)),那么复杂度为 (O(msqrt n))。
注意如果容斥子树太小,也要用暴力做法。否则一个菊花图就能卡到 (O(nm))
关键代码:
char ch[N], s[N];
bool vis[N];
struct SAM {
char s[N];
int slen;
int son[N][26], path[N][26], fa[N], len[N], tot, lst, mp[N], pos[N], siz[N];
int bin[N], id[N], tag[N];
vector<int> vec[N];
SAM()
inline void ins(int c)
void dfs(int cur) {
for (register uint i = 0; i < vec[cur].size(); ++i) {
int to = vec[cur][i]; dfs(to);
pos[cur] = pos[to]; path[cur][s[pos[to] - len[cur]] - 'a'] = to;
}
}
inline void init() {
slen = strlen(s + 1);
for (register int i = 1; i <= slen; ++i) ins(s[i] - 'a');
for (register int i = 2; i <= tot; ++i) vec[fa[i]].push_back(i);
dfs(1);
for (register int i = 1; i <= tot; ++i) ++bin[len[i]];
for (register int i = 1; i <= m; ++i) bin[i] += bin[i - 1];
for (register int i = tot; i; --i) id[bin[len[i]]--] = i;
for (register int i = tot; i > 1; --i) {
int p = id[i];
siz[fa[p]] += siz[p];
}
}
void dfs_sam(int cur, int faa, int np, int nwlen) {
if (nwlen == len[np]) np = path[np][ch[cur] - 'a'];
else if (ch[cur] != s[pos[np] - nwlen]) np = 0;//Attention!!!
if (!np) return ;
++nwlen;//Attention!!!
++tag[np];
for (register int i = head[cur]; i; i = e[i].nxt) {
int to = e[i].to; if (to == faa || vis[to]) continue;
dfs_sam(to, cur, np, nwlen);
}
}
void Pushdown(int p) {
for (register uint i = 0; i < vec[p].size(); ++i)
tag[vec[p][i]] += tag[p], Pushdown(vec[p][i]);
//Attention!!!!!!!
}
inline void clear_tag() {
memset(tag, 0, sizeof(tag));
}
}A, B;
void find_root(int cur, int faa)
void dfs_siz(int cur, int faa) {
++Siz;
siz[cur] = 1;
...
}
void dfs_dfs_dfs(int cur, int faa, int nw) {
if (A.son[nw][ch[cur] - 'a']) nw = A.son[nw][ch[cur] - 'a'], ans += A.siz[nw];
else return ;
for (register int i = head[cur]; i; i = e[i].nxt) {
int to = e[i].to; if (to == faa || vis[to]) continue;
dfs_dfs_dfs(to, cur, nw);
}
}
void dfs_dfs(int cur, int faa) {
dfs_dfs_dfs(cur, 0, 1);
for (register int i = head[cur]; i; i = e[i].nxt) {
int to = e[i].to; if (to == faa || vis[to]) continue;
dfs_dfs(to, cur);
}
}
void dfs_of_dfs_of_dfs(int cur, int faa, int np) {
int c = ch[cur] - 'a';//Attention!!!
np = A.son[np][c];
if (!np) return;
ans -= A.siz[np];//Attention!!!
for (register int i = head[cur]; i; i = e[i].nxt) {
int to = e[i].to; if (to == faa || vis[to]) continue;
dfs_of_dfs_of_dfs(to, cur, np);
}
}
int stk[N], stop, yuan;
inline void del_ans() {
int np = 1;
for (register int i = stop; i; --i) {
int c = ch[stk[i]] - 'a';
np = A.son[np][c];
if (!np) return ;
}
dfs_of_dfs_of_dfs(yuan, root, np);
}
void dfs_of_dfs(int cur, int faa) {
stk[++stop] = cur;
del_ans();
for (register int i = head[cur]; i; i = e[i].nxt) {
int to = e[i].to; if (to == faa || vis[to]) continue;
dfs_of_dfs(to, cur);
}
--stop;
}
void dfs_vis(int cur) {
vis[cur] = true;
...
}
void calc(int cur, int faa, int initc, int initlen, int type) {
A.dfs_sam(cur, faa, ~initc ? A.son[1][initc] : 1, initlen);
B.dfs_sam(cur, faa, ~initc ? B.son[1][initc] : 1, initlen);
A.Pushdown(1); B.Pushdown(1);
ll res = 0;
for (register int i = 1; i <= m; ++i) {
res += 1ll * A.tag[A.mp[i]] * B.tag[B.mp[m - i + 1]];
}
ans += res * type;
A.clear_tag(); B.clear_tag();
}
void dfs(int cur) {
Siz = 0;
dfs_siz(cur, 0);
if (Siz <= limi) {
dfs_dfs(cur, 0);
dfs_vis(cur);
return ;
}
calc(cur, 0, -1, 0, 1);
for (register int i = head[cur]; i; i = e[i].nxt) {
int to = e[i].to; if (vis[to]) continue;
if (siz[to] > limi) calc(to, cur, ch[cur] - 'a', 1, -1);
else stk[stop = 1] = cur, root = cur, yuan = to, dfs_of_dfs(to, cur), stop = 0;
}
...
}
int main() {
...
for (register int i = 1; i <= m; ++i) B.s[i] = A.s[m - i + 1];//Attention!!!
...
dfs(root);
...
}