You are given an integer array nums and have to return a new array counts, where counts[i] is the number of elements to the right of nums[i] that are smaller than nums[i].
Example 1:
Input: nums = [5,2,6,1]
Output: [2,1,1,0]
Explanation:
To the right of 5 there are 2 smaller elements (2 and 1).
To the right of 2 there is only 1 smaller element (1).
To the right of 6 there is 1 smaller element (1).
To the right of 1 there are 0 smaller elements.
Example 2:
Input: nums = [-1]
Output: [0]
Example 3:
Input: nums = [-1,-1]
Output: [0,0]
Constraints:
1 <= nums.length <= 10^5
-10^4 <= nums[i] <= 10^4
A brute-force solution that compares every pair takes O(N^2) time, which is too slow for N up to 10^5.
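For reference, a minimal sketch of that brute-force baseline. The class name `BruteForce` is hypothetical; it is mainly useful as a correctness oracle when testing the faster solutions below on small inputs.

```java
import java.util.ArrayList;
import java.util.List;

// Brute-force baseline (hypothetical helper): for each i, scan every j > i
// and count how many nums[j] are strictly smaller than nums[i].
// O(N^2) time, O(1) extra space besides the output list.
class BruteForce {
    static List<Integer> countSmaller(int[] nums) {
        List<Integer> counts = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            int c = 0;
            for (int j = i + 1; j < nums.length; j++) {
                if (nums[j] < nums[i]) c++;
            }
            counts.add(c);
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(countSmaller(new int[]{5, 2, 6, 1})); // [2, 1, 1, 0]
    }
}
```

At N = 10^5 the inner loop runs about 5 * 10^9 times, which is why the solutions below aim for O(N log N) or O(N log maxV).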
Solution 1. Count smaller numbers during merge sort, O(N * logN)
In merge sort, when we merge the subarray [L, R], each time we pick a number A[i] from the left half we check how many numbers from the right half have already been picked. Because the left element wins ties, every right-half number picked before A[i] is strictly smaller than A[i], and it comes from a later position in the original array, so this count contributes to A[i]'s answer. After the merge for [L, R] is done, that range is sorted, and in every larger merge it lies entirely in the left half or entirely in the right half, never in both. So each pair (i, j) with i < j and A[j] < A[i] is counted exactly once, at the merge where i and j first land in different halves. Since elements move during sorting, we first pair each A[i] with its original index so we know which slot of the final counts to update. Then we merge sort the pairs; the total number of right-half numbers picked before A[i], summed over all merge levels, is its answer.
```java
import java.util.Arrays;
import java.util.List;

// A value together with its original index, so counts survive reordering.
class Pair {
    int v;
    int idx;

    Pair(int v, int idx) {
        this.v = v;
        this.idx = idx;
    }
}

class Solution {
    private Integer[] ans;

    public List<Integer> countSmaller(int[] nums) {
        ans = new Integer[nums.length];
        Arrays.fill(ans, 0);
        Pair[] pairs = convertToPairs(nums);
        mergeSort(pairs, new Pair[nums.length], 0, nums.length - 1);
        return Arrays.asList(ans);
    }

    private Pair[] convertToPairs(int[] nums) {
        Pair[] pairs = new Pair[nums.length];
        for (int i = 0; i < nums.length; i++) {
            pairs[i] = new Pair(nums[i], i);
        }
        return pairs;
    }

    private void mergeSort(Pair[] pairs, Pair[] aux, int left, int right) {
        if (left < right) {
            int mid = left + (right - left) / 2;
            mergeSort(pairs, aux, left, mid);
            mergeSort(pairs, aux, mid + 1, right);
            for (int i = left; i <= right; i++) {
                aux[i] = pairs[i];
            }
            // j walks the left half, k walks the right half;
            // k - mid - 1 is the number of right-half elements already placed.
            int j = left, k = mid + 1;
            for (int i = left; i <= right; i++) {
                if (j <= mid && k <= right) {
                    // Ties go left, so only strictly smaller right values are counted.
                    if (aux[j].v <= aux[k].v) {
                        pairs[i] = aux[j];
                        ans[aux[j].idx] += (k - mid - 1);
                        j++;
                    } else {
                        pairs[i] = aux[k];
                        k++;
                    }
                } else if (j <= mid) {
                    // Right half exhausted: all of it was smaller than aux[j].
                    pairs[i] = aux[j];
                    ans[aux[j].idx] += (k - mid - 1);
                    j++;
                } else {
                    pairs[i] = aux[k];
                    k++;
                }
            }
        }
    }
}
```
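To see the counting rule of Solution 1 in isolation, here is a small sketch of a single merge step on plain arrays. The helper `mergeAndCount` and the `{value, originalIndex}` int-pair layout are hypothetical, chosen for illustration rather than taken from the solution above; replaying the three merges that sort Example 1 should reproduce [2, 1, 1, 0].

```java
import java.util.Arrays;

// Hypothetical helper isolating one merge step from Solution 1.
// left and right are each sorted by value; every element is {value, originalIndex}.
// For each left element, adds to counts[originalIndex] the number of right-half
// elements emitted before it (all strictly smaller, because ties take the left side).
class MergeStepDemo {
    static int[][] mergeAndCount(int[][] left, int[][] right, int[] counts) {
        int[][] out = new int[left.length + right.length][];
        int j = 0, k = 0;
        for (int i = 0; i < out.length; i++) {
            boolean takeLeft = k == right.length
                    || (j < left.length && left[j][0] <= right[k][0]);
            if (takeLeft) {
                counts[left[j][1]] += k; // k right-half elements already emitted
                out[i] = left[j++];
            } else {
                out[i] = right[k++];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // nums = [5, 2, 6, 1]
        int[] counts = new int[4];
        int[][] a = mergeAndCount(new int[][]{{5, 0}}, new int[][]{{2, 1}}, counts);
        int[][] b = mergeAndCount(new int[][]{{6, 2}}, new int[][]{{1, 3}}, counts);
        mergeAndCount(a, b, counts);
        System.out.println(Arrays.toString(counts)); // [2, 1, 1, 0]
    }
}
```

The first two calls stand in for the lower recursion level (5 against 2, and 6 against 1); the last call is the top-level merge, where 2 and 5 each see the 1 already emitted from the right half.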
Solution 2. Binary Indexed Tree, O(N * log(maxV))
Create a binary indexed tree that covers the full range of possible nums[i] values. Shift every value by 10^4 + 1 so that the minimum possible value, -10^4, maps to index 1 (binary indexed trees are 1-indexed).
Then scan from right to left. For each nums[i], query the prefix sum over all keys smaller than nums[i] + 10001; that sum is the answer for nums[i]. Then increment the frequency count at key nums[i] + 10001.
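Written out, the key mapping and the query bound described above are as follows, where f_i(k) denotes the frequency stored at key k once every element to the right of index i has been inserted (this notation is introduced here for illustration):

```latex
\begin{aligned}
\mathrm{key}(v) &= v + 10001, & v \in [-10^4,\ 10^4] &\implies \mathrm{key}(v) \in [1,\ 20001] \\
\mathrm{counts}[i] &= \sum_{k=1}^{\mathrm{key}(\mathrm{nums}[i]) - 1} f_i(k), & f_i(k) &= \bigl|\{\, j > i : \mathrm{key}(\mathrm{nums}[j]) = k \,\}\bigr|
\end{aligned}
```

The largest key is 20001, so the tree array needs 20002 slots with slot 0 unused.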
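```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Fenwick tree over keys 1..n-1; slot 0 is unused.
class BinaryIndexedTree {
    int[] ft;

    BinaryIndexedTree(int n) {
        ft = new int[n];
    }

    // Add v to the count at key k.
    void update(int k, int v) {
        for (; k < ft.length; k += (k & (-k))) {
            ft[k] += v;
        }
    }

    // Sum of the counts at keys 1..r.
    int rangeSumQuery(int r) {
        int sum = 0;
        for (; r > 0; r -= (r & (-r))) {
            sum += ft[r];
        }
        return sum;
    }
}

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        // Keys run from -10^4 + 10001 = 1 up to 10^4 + 10001 = 20001.
        BinaryIndexedTree bit = new BinaryIndexedTree(20002);
        List<Integer> ans = new ArrayList<>();
        for (int i = nums.length - 1; i >= 0; i--) {
            // Count already-inserted (right-side) values with a smaller key.
            ans.add(bit.rangeSumQuery(nums[i] + 10001 - 1));
            bit.update(nums[i] + 10001, 1);
        }
        // Answers were collected right to left.
        Collections.reverse(ans);
        return ans;
    }
}
```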
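The `k & (-k)` step (the lowest set bit of k) is what keeps both `update` and `rangeSumQuery` at O(log n). The sketch below is a hypothetical helper, not part of the solution, that only records which slots each loop visits, using a small tree of length 16 for readability.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo: lists the slots that BinaryIndexedTree.update(k, v) and
// rangeSumQuery(r) would touch, using the same lowest-set-bit step k & (-k).
class FenwickWalkDemo {
    // Slots written by update(k, ...) in a tree array of the given length.
    static List<Integer> updatePath(int k, int length) {
        List<Integer> path = new ArrayList<>();
        for (; k < length; k += (k & (-k))) path.add(k);
        return path;
    }

    // Slots summed by rangeSumQuery(r).
    static List<Integer> queryPath(int r) {
        List<Integer> path = new ArrayList<>();
        for (; r > 0; r -= (r & (-r))) path.add(r);
        return path;
    }

    public static void main(String[] args) {
        System.out.println(updatePath(5, 16)); // [5, 6, 8]
        System.out.println(queryPath(13));     // [13, 12, 8]
    }
}
```

Slot 8 holds the total for keys 1..8, so an update at key 5 reaches it and the query for keys 1..13 reads it exactly once. Each loop visits at most about log2(n) slots, roughly 15 for the 20002-slot tree in Solution 2.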
Solution 3. Segment Tree, O(N * log(maxV)), same idea as the binary indexed tree.
```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Node-based segment tree storing range sums over positions leftMost..rightMost.
class SegmentTree {
    SegmentTree lChild, rChild;
    int leftMost, rightMost, sum;

    SegmentTree(int[] a, int leftMost, int rightMost) {
        this.leftMost = leftMost;
        this.rightMost = rightMost;
        if (leftMost == rightMost) {
            sum = a[leftMost];
        } else {
            int mid = leftMost + (rightMost - leftMost) / 2;
            lChild = new SegmentTree(a, leftMost, mid);
            rChild = new SegmentTree(a, mid + 1, rightMost);
            recalc();
        }
    }

    void recalc() {
        if (leftMost != rightMost) {
            sum = lChild.sum + rChild.sum;
        }
    }

    // Add v at position idx.
    void update(int idx, int v) {
        if (leftMost == rightMost) {
            sum += v;
            return;
        }
        if (idx <= lChild.rightMost) lChild.update(idx, v);
        else rChild.update(idx, v);
        recalc();
    }

    // Sum over positions l..r that fall inside this node's range.
    int rangeSum(int l, int r) {
        if (l > rightMost || r < leftMost) return 0;
        else if (l <= leftMost && r >= rightMost) return sum;
        return lChild.rangeSum(l, r) + rChild.rangeSum(l, r);
    }
}

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        List<Integer> ans = new ArrayList<>();
        // One leaf per shifted value, using the same v + 10001 mapping as Solution 2.
        int[] a = new int[20002];
        SegmentTree st = new SegmentTree(a, 0, a.length - 1);
        for (int i = nums.length - 1; i >= 0; i--) {
            ans.add(st.rangeSum(0, nums[i] + 10001 - 1));
            st.update(nums[i] + 10001, 1);
        }
        Collections.reverse(ans);
        return ans;
    }
}
```
Related Problems