Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note: You may assume k is always valid, 1 ≤ k ≤ BST's total elements.

Follow up: What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Hint: Try to utilize the property of a BST. What if you could modify the BST node's structure? The optimal runtime complexity is O(height of BST).
Java Solution 1 - Inorder Traversal
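An in-order traversal of a BST visits the nodes in ascending order, so the kth node visited is the kth smallest element. Time is O(n) in the worst case; a traversal that stops at the kth node takes O(h + k), where h is the height of the tree.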
    import java.util.Stack;

    /**
     * Definition for a binary tree node.
     * public class TreeNode {
     *     int val;
     *     TreeNode left;
     *     TreeNode right;
     *     TreeNode(int x) { val = x; }
     * }
     */
    public class Solution {
        public int kthSmallest(TreeNode root, int k) {
            TreeNode node = root;
            Stack<TreeNode> st = new Stack<TreeNode>();
            int counter = 0; // number of nodes visited so far, in ascending order
            while (!st.isEmpty() || node != null) {
                if (node != null) {
                    // Walk down to the leftmost unvisited node.
                    st.push(node);
                    node = node.left;
                } else {
                    // Visit the next-smallest node.
                    node = st.pop();
                    counter++;
                    if (counter == k) return node.val;
                    node = node.right;
                }
            }
            return -1; // unreachable when 1 <= k <= number of nodes
        }
    }
Recursion method:
    public class Solution {
        private int count;   // nodes visited so far
        private int result;  // value of the kth node once it is found

        public int kthSmallest(TreeNode root, int k) {
            count = 0; // reset so the method can be called more than once
            helper(root, k);
            return result;
        }

        private void helper(TreeNode root, int k) {
            if (root == null) return;
            helper(root.left, k);
            if (count >= k) return; // answer already found in the left subtree
            count++;
            if (count == k) {
                result = root.val;
                return;
            }
            helper(root.right, k);
        }
    }
Java Solution 2 - Binary Search
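We can let each node track its order, i.e., the number of elements in its left subtree, all of which are smaller than the node itself. With that count stored in each node, a lookup takes O(h) time, where h is the height of the tree; that is O(log n) only when the tree is balanced.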
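The code at the end of this section does not store that count. It recounts the left subtree with countNodes at every step, so it takes O(n) time on a balanced tree and O(n^2) in the worst case on a degenerate tree.
If we modify the tree often and call the kth function often, how can we bring the time complexity down to O(h), where h is the height of the tree? Following the hint, we can add a rank member to TreeNode that records the number of nodes in that node's left subtree, i.e., how many nodes in its own subtree are smaller than it. With that, the problem can be solved by a binary search down the tree. The rank can be filled in while the tree is being built.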
    public int kthSmallest(TreeNode root, int k) {
        int count = countNodes(root.left); // size of the left subtree
        if (k <= count) {
            return kthSmallest(root.left, k);
        } else if (k > count + 1) {
            return kthSmallest(root.right, k - 1 - count); // 1 is counted as current node
        }

        return root.val; // k == count + 1, so the current node is the answer
    }

    public int countNodes(TreeNode n) {
        if (n == null) return 0;

        return 1 + countNodes(n.left) + countNodes(n.right);
    }
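The rank idea described above can be sketched as follows. This is only a minimal sketch under stated assumptions, not the original author's code: the names RankedBst, Node, and leftCount are hypothetical, values are assumed to be distinct, and only insertion is shown (a delete would have to decrement the same counters along its search path). Because the tree is not self-balancing, h can still grow to n for sorted input.

```java
// Sketch only: RankedBst, Node and leftCount are hypothetical names,
// not part of the original problem's TreeNode definition.
final class RankedBst {
    private static final class Node {
        final int val;
        int leftCount;   // the "rank": number of nodes in this node's left subtree
        Node left, right;
        Node(int val) { this.val = val; }
    }

    private Node root;
    private int size;

    // O(h) lookup, used to keep the stored counts correct when a duplicate arrives.
    boolean contains(int val) {
        Node cur = root;
        while (cur != null) {
            if (val == cur.val) return true;
            cur = val < cur.val ? cur.left : cur.right;
        }
        return false;
    }

    // O(h) insert. Assumption: distinct values; duplicates are ignored.
    void insert(int val) {
        if (contains(val)) return;
        size++;
        if (root == null) { root = new Node(val); return; }
        Node cur = root;
        while (true) {
            if (val < cur.val) {
                cur.leftCount++; // the new node lands in cur's left subtree
                if (cur.left == null) { cur.left = new Node(val); return; }
                cur = cur.left;
            } else {
                if (cur.right == null) { cur.right = new Node(val); return; }
                cur = cur.right;
            }
        }
    }

    // O(h) query: binary search on the stored left-subtree counts.
    int kthSmallest(int k) {
        if (k < 1 || k > size) throw new IllegalArgumentException("k out of range: " + k);
        Node cur = root;
        while (cur != null) {
            if (k <= cur.leftCount) {
                cur = cur.left;                 // answer is in the left subtree
            } else if (k == cur.leftCount + 1) {
                return cur.val;                 // current node is the kth smallest
            } else {
                k -= cur.leftCount + 1;         // skip the left subtree and this node
                cur = cur.right;
            }
        }
        throw new IllegalStateException("stored counts are inconsistent");
    }

    public static void main(String[] args) {
        RankedBst tree = new RankedBst();
        for (int v : new int[] {5, 3, 8, 1, 4}) tree.insert(v);
        System.out.println(tree.kthSmallest(3)); // prints 4 (sorted order: 1 3 4 5 8)
    }
}
```

Compared with the countNodes version, nothing is recounted at query time: each step reads one stored number and moves one level down, which is where the O(h) bound comes from.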