CF871E Restore the Tree
题目大意
yzh 有一棵 (n) 个节点的妹子树。大给给 ztr 对 yzh 使用了男♂酮♂魔♂法,导致 yzh 把妹子树忘记了!
好在 yzh 有 (k) 个最喜欢的妹子(称为关键节点),他记得每个最喜欢的妹子到其它妹子的距离。
现在 yzh 告诉了你 (n), (k) 以及 (k imes n) 个数,表示每个关键节点到其它节点的树上距离。请你帮 yzh 求出任意一种符合要求的树,或告诉他不存在这样的树。
数据范围:(1leq nleq 30000),(1leq kleq min{n, 200})。
本题题解
首先我们可以确定 (k) 个关键节点的编号:如果有解,每个关键节点一定和恰好一个节点距离为 (0),这个节点就是它自己。记第 (i) 个关键节点的编号为 (mathrm{id}_i)。
记 (mathrm{dis}(u, v)) 表示 (u, v) 两点的树上距离。
不妨以第 (1) 个关键节点为根,记为 (mathrm{root} = mathrm{id}_1)。
对于其他每个关键节点 (mathrm{id}_i),可以求出根到它的路径。具体来说,一个节点 (v) 在 (mathrm{root}) 到 (mathrm{id}_i) 的路径上,当且仅当 (mathrm{dis}(mathrm{root}, v) + mathrm{dis}(mathrm{id}_i, v) = mathrm{dis}(mathrm{root}, mathrm{id}_i))。对所有满足这样条件的 (v),按 (mathrm{dis}(mathrm{root}, v)) 从小到大排序,就可以得到从 (mathrm{root}) 到 (mathrm{id}_i) 的路径(注意,因为 (mathrm{dis}(mathrm{root}, v) < n),所以“排序”可以用桶排实现,时间复杂度是 (mathcal{O}(n)) 的)。
对每个关键节点,求出根到它的路径后,类似于向 ( ext{trie}) 里插入字符串一样,插入这条路径。至此,我们确定了树的结构的一部分:所有关键节点到根的路径的并(以下称这部分为“主体部分”)。这部分时间复杂度 (mathcal{O}(nk))。
考虑不在主体部分里的节点。我们已知它们的深度(也就是 (mathrm{root}) 到它的距离)。按深度从小到大考虑这些节点,问题相当于要给每个节点确定一个父亲。
对节点 (v),考虑它的祖先里,深度最大的、主体部分里的节点,记为 (u)。对每个 (v),(u) 就是所有【关键节点和 (v) 的 LCA】中的深度最大者。求一个关键节点 (mathrm{id}_i) 和 (v) 的 LCA(记为 (l)),可以考虑:(mathrm{dis}(mathrm{root}, mathrm{id}_i) + mathrm{dis}(mathrm{root}, v) - 2cdot mathrm{dis}(mathrm{root}, l) = mathrm{dis}(mathrm{id}_i, v)),解出 (mathrm{dis}(mathrm{root}, l) = x),那么 (mathrm{root}) 到 (mathrm{id}_i) 的路径上第 (x) 个点就是 (l) 了。
确定了 (u) 后,同时也可以知道 (u) 到 (v) 的距离。若 (u) 就是 (v) 的父亲,则已经完成任务(确定了 (v) 的父亲)。否则找出任意一个 (u) 子树里、不在主体部分中的、距离 (u) 为 (mathrm{dis}(u, v) - 1) 的节点,以它作为 (v) 的父亲即可。因为我们是按深度从小到大考虑不在主体部分中的节点的,所以 (v) 的父亲一定已经被加入过。
枚举 (mathcal{O}(n)) 个 (v),(mathcal{O}(k)) 求出对应的 (u),所以这部分时间复杂度也是 (mathcal{O}(nk))。
至此,我们在 (mathcal{O}(nk)) 的时间内解决了本题。
参考代码
以下代码来自 CF 官方题解,比较清晰易懂。
#include <bits/stdc++.h>
using namespace std;
void impossible() {
puts("-1");
exit(0);
}
int main() {
int n, k;
assert(scanf("%d %d", &n, &k) == 2);
assert(k > 0);
vector <int> dist[k];
vector <int> idx(n, 0);
for (int i = 0; i < k; ++i) {
dist[i] = vector <int> (n, 0);
for (int j = 0; j < n; ++j)
assert(scanf("%d", &dist[i][j]) == 1);
idx[i] = -1;
for (int j = 0; j < n; ++j) {
if (dist[i][j] == 0) {
if (idx[i] != -1)
impossible();
else
idx[i] = j;
}
}
if (idx[i] == -1)
impossible();
for (int j = 0; j < i; ++j)
if (idx[i] == idx[j])
impossible();
}
vector <int> p(n, -1);
vector <int> h = dist[0];
int root = idx[0];
vector <int> ps[k];
for (int i = 0; i < k; ++i)
ps[i] = vector <int> ();
for (int t = 1; t < k; ++t) {
int v = idx[t];
vector <int> &d = dist[t];
vector <int> pars(n, -1);
for (int u = 0; u < n; ++u) {
if (h[u] + d[u] == h[v]) {
if (pars[h[u]] != -1)
impossible();
pars[h[u]] = u;
}
}
for (int i = 0; i < n; ++i) {
if (pars[i] == -1 && i <= h[v])
impossible();
if (pars[i] != -1 && i > h[v])
impossible();
}
for (int i = 0; i < h[v]; ++i) {
int pu = pars[i];
int u = pars[i + 1];
if (p[u] != -1 && p[u] != pu)
impossible();
p[u] = pu;
}
ps[t] = pars;
}
if (p[root] != -1)
impossible();
vector <int> subs[n];
for (int i = 0; i < n; ++i)
subs[i] = vector <int> ();
vector <int> vh[n];
for (int i = 0; i < n; ++i)
vh[i] = vector <int> ();
for (int v = 0; v < n; ++v)
vh[h[v]].push_back(v);
for (int hh = 0; hh < n; ++hh) {
for (int v: vh[hh]) {
if (v == root || p[v] != -1)
continue;
int lp = root;
int lt = 0;
for (int t = 1; t < k; ++t) {
int u = idx[t];
vector <int> &d = dist[t];
vector <int> &pars = ps[t];
// h[v] + h[u] - 2 * h[l] == d[v]
// 2 * h[l] == h[v] + h[u] - d[v]
int hl2 = h[v] + h[u] - d[v];
if (hl2 % 2 != 0)
impossible();
int hl = hl2 / 2;
if (hl < 0 || hl > h[u] || hl > h[v])
impossible();
int l = pars[hl];
if (h[l] >= h[lp]) {
if (pars[h[lp]] != lp)
impossible();
lp = l;
lt = t;
} else {
if (ps[lt][h[l]] != l)
impossible();
}
}
subs[lp].push_back(v);
}
}
for (int i = 0; i < n; ++i) {
int prev = i;
int nxt = -1;
for (int v: subs[i]) {
assert(prev != -1);
if (h[v] == h[prev] + 1) {
p[v] = prev;
nxt = v;
} else {
prev = nxt;
if (prev != -1 && h[v] == h[prev] + 1) {
p[v] = prev;
nxt = v;
} else
impossible();
}
}
}
for (int i = 0; i < n; ++i) {
if (i == root)
assert(p[i] == -1);
else {
printf("%d %d
", i + 1, p[i] + 1);
}
}
return 0;
}