编程之美 2013 全国挑战赛 资格赛 题目三 树上的三角形
题目三 树上的三角形
时间限制: 2000ms 内存限制: 256MB
描述
有一棵树,树上有只毛毛虫。它在这棵树上生活了很久,对它的构造了如指掌。所以它在树上从来都是走最短路,不会绕路。它还还特别喜欢三角形,所以当它在树上爬来爬去的时候总会在想,如果把刚才爬过的那几根树枝/树干锯下来,能不能从中选三根出来拼成一个三角形呢?
输入
输入数据的第一行包含一个整数 T,表示数据组数。
接下来有 T 组数据,每组数据中:
第一行包含一个整数 N,表示树上节点的个数(从 1 到 N 标号)。
接下来的 N-1 行包含三个整数 a, b, len,表示有一根长度为 len 的树枝/树干在节点 a 和节点 b 之间。
接下来一行包含一个整数 M,表示询问数。
接下来M行每行两个整数 S, T,表示毛毛虫从 S 爬行到了 T,询问这段路程中的树枝/树干是否能拼成三角形。
输出
对于每组数据,先输出一行"Case #X:",其中X为数据组数编号,从 1 开始。
接下来对于每个询问输出一行,包含"Yes"或“No”,表示是否可以拼成三角形。
数据范围
1 ≤ T ≤ 5
小数据:1 ≤ N ≤ 100, 1 ≤ M ≤ 100, 1 ≤ len ≤ 10000
大数据:1 ≤ N ≤ 100000, 1 ≤ M ≤ 100000, 1 ≤ len ≤ 1000000000
样例输入
2
5
1 2 5
1 3 20
2 4 30
4 5 15
2
3 4
3 5
5
1 4 32
2 3 100
3 5 45
4 5 60
2
1 4
1 3
样例输出
Case #1:
No
Yes
Case #2:
No
Yes
解题思路
这道题如果直接按照题意去写,那么可以利用广度优先搜索得到最短路径(因为这是一颗树,而不是图,所以不必使用最短路算法),然后判断路径上的边是否能组成一个三角形(先对路径排序,然后用两边之和大于第三边进行判断)。不过搜索的时间复杂度是 O(N),判断三角形的时间复杂度为 O(llgl)(其中 l 是最短路径的长度),小数据没问题,但大数据肯定会挂。
判断三角形是否存在,我并没有更好的办法,那么只能在求最短路径上下手了,以下面的树作为例子(题目没说是几叉树,不过没有关系):
图 1 一颗树的示例
求一棵树上两个节点的最短路径,其实就是求两个节点的最近公共祖先(Least Common Ancestors,LCA)。最近公共祖先指的是在一颗有根树中,找到两个节点 u 和 v 最近的公共祖先。这个概念很容易理解,例如上面节点 5 和 10 的 LCA 就是 1,3 和 11 的 LCA 是 3,7 和 9 的 LCA 是 3。
显然,两个节点与它们的最近公共祖先之间的路径(可以不断向上查找父节点得到)加起来,就是两个节点间的最短路径。上面节点 5 和 10 的最短路径就为 5、2、1、3、8、10;节点 3 和 11 的最短路径就为 3、8、9、11。
求 LCA 有两种算法,一种是离线的 Tarjan 算法,计算出所有 M 个询问所需的时间复杂度是 O(N+M);另一种是基于区间最值查询(Range Minimum/Maximum Query,RMQ)的在线算法,预处理时间是 O(NlgN),每次询问的时间复杂度为 O(1),总得时间复杂度就是 O(NlgN+M)。两个算法使用那个都可以,不过感觉还是用 Tarjan 更好点,占用内存更少,速度也更快。关于这两个算法的详细解释,可以参见算法之LCA与RMQ问题,这里就不详细说明了。
在线算法的代码
1 #include <stdio.h> 2 #include <cmath> 3 #include <algorithm> 4 #include <list> 5 #include <string.h> 6 using namespace std; 7 // 树的节点 8 struct Node { 9 int next, len; 10 Node (int n, int l):next(n), len(l) {} 11 }; 12 int pow2[20]; 13 list<Node> nodes[100010]; 14 bool visit[100010]; 15 int ns[200010]; 16 int nIdx; 17 int length[100010]; 18 int parent[100010]; 19 int depth[200010]; 20 int first[100010]; 21 int mmin[20][200010]; 22 int edges[100010]; 23 // DFS 对树进行预处理 24 void dfs(int u, int dep) 25 { 26 ns[++nIdx] = u; depth[nIdx] = dep; 27 visit[u] = true; 28 if (first[u] == -1) first[u] = nIdx; 29 list<Node>::iterator it = nodes[u].begin(), end = nodes[u].end(); 30 for (;it != end; it++) 31 { 32 int v = it->next; 33 if(!visit[v]) 34 { 35 length[v] = it->len; 36 parent[v] = u; 37 dfs(v, dep + 1); 38 ns[++nIdx] = u; 39 depth[nIdx] = dep; 40 } 41 } 42 } 43 // 初始化 RMQ 44 void init_rmq() 45 { 46 nIdx = 0; 47 memset(visit, 0, sizeof(visit)); 48 memset(first, -1, sizeof(first)); 49 depth[0] = 0; 50 length[1] = parent[1] = 0; 51 dfs(1, 1); 52 memset(mmin, 0, sizeof(mmin)); 53 for(int i = 1; i <= nIdx; i++) { 54 mmin[0][i] = i; 55 } 56 int t1 = (int)(log((double)nIdx) / log(2.0)); 57 for(int i = 1; i <= t1; i++) { 58 for(int j = 1; j + pow2[i] - 1 <= nIdx; j++) { 59 int a = mmin[i-1][j], b = mmin[i-1][j+pow2[i-1]]; 60 if(depth[a] <= depth[b]) { 61 mmin[i][j] = a; 62 } else { 63 mmin[i][j] = b; 64 } 65 } 66 } 67 } 68 // RMQ 询问 69 int rmq(int u, int v) 70 { 71 int i = first[u], j = first[v]; 72 if(i > j) swap(i, j); 73 int t1 = (int)(log((double)j - i + 1) / log(2.0)); 74 int a = mmin[t1][i], b = mmin[t1][j - pow2[t1] + 1]; 75 if(depth[a] <= depth[b]) { 76 return ns[a]; 77 } else { 78 return ns[b]; 79 } 80 } 81 82 int main() { 83 for(int i = 0; i < 20; i++) { 84 pow2[i] = 1 << i; 85 } 86 int T, n, m, a, b, len; 87 scanf("%d ", &T); 88 for (int caseIdx = 1;caseIdx <= T;caseIdx++) { 89 scanf("%d", &n); 90 for (int i = 0;i <= n;i++) { 91 nodes[i].clear(); 92 } 93 for (int i = 1;i < n;i++) { 94 scanf("%d%d%d", &a, &b, &len); 95 nodes[a].push_back(Node(b, len)); 96 nodes[b].push_back(Node(a, len)); 97 } 98 init_rmq(); 99 scanf("%d", &m); 100 printf("Case #%d:\n", caseIdx); 101 for (int i = 0;i < m;i++) { 102 scanf("%d%d", &a, &b); 103 // 利用 RMQ 得到 LCA 104 int root = rmq(a, b); 105 bool success = false; 106 int l = 0; 107 while (a != root) { 108 edges[l++] = length[a]; 109 a = parent[a]; 110 } 111 while (b != root) { 112 edges[l++] = length[b]; 113 b = parent[b]; 114 } 115 if (l >= 3) { 116 sort(edges, edges + l); 117 for (int j = 2;j < l;j++) { 118 if (edges[j - 2] + edges[j - 1] > edges[j]) { 119 success = true; 120 break; 121 } 122 } 123 } 124 if (success) { 125 puts("Yes"); 126 } else { 127 puts("No"); 128 } 129 } 130 } 131 return 0; 132 }
离线算法的代码
1 #include <stdio.h> 2 #include <string.h> 3 #include <list> 4 #include <algorithm> 5 using namespace std; 6 // 树和查询的节点 7 struct Node { 8 int next, len; 9 Node (int n, int l):next(n), len(l) {} 10 }; 11 list<Node> nodes[100010]; 12 list<Node> querys[100010]; 13 bool visit[100010]; 14 int ancestor[100010]; 15 int parent[100010]; 16 int length[100010]; 17 int edges[100010]; 18 // 查询的结果 19 bool result[100010]; 20 // 并查集 21 int uset[100010]; 22 int find(int x) { 23 int p = x, t; 24 while (uset[p] >= 0) p = uset[p]; 25 while (x != p) { t = uset[x]; uset[x] = p; x = t; } 26 return x; 27 } 28 void un_ion(int a, int b) { 29 if ((a = find(a)) == (b = find(b))) return; 30 if (uset[a] < uset[b]) { uset[a] += uset[b]; uset[b] = a; } 31 else { uset[b] += uset[a]; uset[a] = b; } 32 } 33 void init_uset() { 34 memset(uset, -1, sizeof(uset)); 35 } 36 37 void tarjan(int u) { 38 visit[u] = true; 39 ancestor[find(u)] = u; 40 list<Node>::iterator it = nodes[u].begin(), end = nodes[u].end(); 41 for (;it != end; it++) 42 { 43 int v = it->next; 44 if(!visit[v]) 45 { 46 length[v] = it->len; 47 parent[v] = u; 48 tarjan(v); 49 un_ion(u, v); 50 ancestor[find(u)] = u; 51 } 52 } 53 it = querys[u].begin(); end = querys[u].end(); 54 for (;it != end; it++) 55 { 56 int v = it->next; 57 if(visit[v]) 58 { 59 // 处理从 u 起始的查询 60 int root = ancestor[find(v)]; 61 int l = 0; 62 int a = u; 63 while (a != root) { 64 edges[l++] = length[a]; 65 a = parent[a]; 66 } 67 while (v != root) { 68 edges[l++] = length[v]; 69 v = parent[v]; 70 } 71 sort(edges, edges + l); 72 for (int j = 2;j < l;j++) { 73 if (edges[j - 2] + edges[j - 1] > edges[j]) { 74 result[it->len] = true; 75 break; 76 } 77 } 78 } 79 } 80 } 81 82 int main() { 83 int T, n, m, a, b, len; 84 scanf("%d ", &T); 85 for (int caseIdx = 1;caseIdx <= T;caseIdx++) { 86 scanf("%d", &n); 87 for (int i = 0;i <= n;i++) { 88 nodes[i].clear(); 89 querys[i].clear(); 90 } 91 for (int i = 1;i < n;i++) { 92 scanf("%d%d%d", &a, &b, &len); 93 nodes[a].push_back(Node(b, len)); 94 nodes[b].push_back(Node(a, len)); 95 } 96 scanf("%d", &m); 97 for (int i = 0;i < m;i++) { 98 scanf("%d%d", &a, &b); 99 // 查询要添加两遍,以防止出现遗漏 100 querys[a].push_back(Node(b, i)); 101 querys[b].push_back(Node(a, i)); 102 } 103 printf("Case #%d:\n", caseIdx); 104 init_uset(); 105 memset(visit, 0, sizeof(visit)); 106 memset(result, 0, sizeof(result)); 107 length[1] = parent[1] = 0; 108 tarjan(1); 109 for (int i = 0;i < m;i++) { 110 if (result[i]) { 111 puts("Yes"); 112 } else { 113 puts("No"); 114 } 115 } 116 } 117 return 0; 118 }
这两个算法应该是没问题的,但大数据的时候都 TLE 了,看来 list 真不能随便用,动态开辟内存还是太慢了。离线算法的内存使用大概只有在线算法的 70%。
后来我翻代码的时候(所有人的代码都可以看到,这点挺给力),看到有人没用上面的 LCA 算法,而是在用 DFS 建好树后,使要判断的两个节点 u 和 v 分别沿着父节点链向上遍历,同时保持 u 和 v 的深度是相同的,这样同样能得到最短路径和 LCA,只不过时间复杂度要高一些。但在这道题中也没有关系,因为在找三角形时还是需要把路径遍历一编才可以,LCA 的计算反而会带来额外的复杂性,看来的确是自己想复杂了。
这段遍历算法大概类似于下面这样:
1 while (deep(u) > deep(v)){ 2 //
算法:Eratosthenes 筛选求质数
public class Eratosthenes { /** * @param args */ public static void main(String[] args) { int N = 100; int i = 0, j = 0 , count = 0; int prime[] = new int[N + 1]; //初始化数据 for (i = 2; i <= N; i++) { prime[i] = 1; } //循环1(N 开方 次) for (i = 2; i * i <= N; i++) { if (prime[i] == 0) { count++; continue; } //循环2(N/i 次) 筛选被i整除的数 for (j = i * i; j <= N; j = j + i) { prime[j] = 0; count++; } } System.out.println("Times of calculation : " + count); j=0; for (i = 2; i <= N; i++) { if (prime[i] == 1) { System.out.print("\t"); System.out.print(i); j++; if(j % 10 == 0){ System.out.println(); } } } } }
循环次数 O(N):
N | 进入循环的次数 | 循环次数/N |
100 | 109 | 1.09 |
1000 | 1430 | 1.43 |
10000 | 17055 | 1.70 |
100000 | 193328 | 1.93 |
1000000 | 2122879 | 2.12 |
10000000 | 22852766 | 2.28 |