You are given a list of integers nums and an integer k. Consider an operation where you increment any one element once (increase it by 1). Given that you can perform this operation at most k times, return the length of the longest subarray containing equal elements.
Constraints
n ≤ 100,000 where n is the length of nums
k < 2 ** 31
Example 1
Input
nums = [2, 4, 8, 5, 9, 6]
k = 6
Output
3
Explanation
We can increment 8 once and 5 four times to get a sublist of [9, 9, 9].
This is a really nice problem that can be solved using different data structures and algorithms. We start with the following 2 observations:
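1. For a window of length L whose maximum value is maxV, the cost (the number of increments needed to make all its elements equal) is maxV * L - window sum. Since we can only increment, every element must end up at least at maxV, and raising each element exactly to maxV is the cheapest choice.
2. If we can find a window of length L whose cost is <= k, then we can definitely find a window of any smaller length: dropping an element from either end of the window removes its non-negative contribution maxV - x and can only lower the maximum, so the cost never increases. Conversely, if we cannot find such a window of length L, we definitely cannot find any longer window that fits within the cost limit.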
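Observation 1 alone already gives a quadratic baseline, which is handy for checking the faster solutions on small inputs. The sketch below is not one of the author's solutions; the class name BruteForce and method longest are hypothetical, and it uses long arithmetic on the assumption that element values may be large enough for maxV * L to overflow an int.

```java
// Hypothetical O(N^2) baseline applying observation 1 directly:
// cost(window) = maxV * len - windowSum.
class BruteForce {
    static int longest(int[] nums, int k) {
        int best = 0;
        for (int left = 0; left < nums.length; left++) {
            long sum = 0;
            long maxV = Long.MIN_VALUE;
            for (int right = left; right < nums.length; right++) {
                sum += nums[right];
                maxV = Math.max(maxV, nums[right]);
                long len = right - left + 1;
                long cost = maxV * len - sum;
                // Extending a window to the right never lowers its cost,
                // so once the cost exceeds k we can stop for this left bound.
                if (cost > k) break;
                best = Math.max(best, (int) len);
            }
        }
        return best;
    }
}
```

For Example 1, BruteForce.longest(new int[]{2, 4, 8, 5, 9, 6}, 6) returns 3, coming from the window [8, 5, 9] with cost 9 * 3 - 22 = 5.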
Observation 2 means feasibility is monotonic in L, which is a good hint that we can binary search the answer space [1, N], provided we can do the following efficiently for a fixed length L:
(a) Find the maximum value of every window of length L, so we can tell whether any of them fits within the cost limit.
(b) Compute the sum of every window of length L.
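The outer search can be written independently of how the check is implemented. The solutions below hand-roll it as while(l < r - 1) followed by a final check on r; the sketch here is an equivalent formulation that rounds the midpoint up instead. AnswerSearch and largestFeasible are hypothetical names, and the sketch assumes n >= 1, that length 1 is always feasible (its cost is 0), and that feasibility is monotonic as observation 2 shows.

```java
import java.util.function.IntPredicate;

// Hypothetical helper: the largest L in [1, n] with feasible.test(L) == true,
// assuming feasible.test(1) is true and feasibility is monotonic in L.
class AnswerSearch {
    static int largestFeasible(int n, IntPredicate feasible) {
        int lo = 1, hi = n; // invariant: lo is feasible, the answer lies in [lo, hi]
        while (lo < hi) {
            // Round the midpoint up so that "lo = mid" always makes progress.
            int mid = lo + (hi - lo + 1) / 2;
            if (feasible.test(mid)) {
                lo = mid;     // mid works, so the answer is mid or larger
            } else {
                hi = mid - 1; // mid fails, and so does every longer length
            }
        }
        return lo;
    }
}
```

Plugging in a per-length check such as the one in Solution 1 as the predicate yields the answer after O(logN) checks.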
Solution 1. O(N * logN * logN), sliding window + tree map + binary search
For (a), we can keep a running tree map of the value frequencies inside the sliding window; its largest key is the window max. Each slide costs O(logN), so one check takes O(N * logN). For (b), since the input array is static, we can precompute prefix sums and answer each range sum query in O(1). With O(logN) binary search steps, the overall runtime is O(N * logN * logN), which is better than O(N^2) but still a bit slow.
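The prefix arrays in the code here store ps[i] = nums[0] + ... + nums[i], which forces special cases such as (left > 0 ? ps[left - 1] : 0) in Solution 3. A common alternative is to keep a leading 0 so that every range sum is a single subtraction. The sketch below assumes that layout; PrefixSums and rangeSum are hypothetical names, and the sums are stored as long on the assumption that the total may exceed the int range.

```java
// Hypothetical helper: prefix sums with a leading 0, stored as long.
// pre[i] = nums[0] + ... + nums[i - 1], so pre[0] = 0.
class PrefixSums {
    private final long[] pre;

    PrefixSums(int[] nums) {
        pre = new long[nums.length + 1];
        for (int i = 0; i < nums.length; i++) {
            pre[i + 1] = pre[i] + nums[i];
        }
    }

    // Sum of nums[l..r], both ends inclusive, in O(1) with no special case for l == 0.
    long rangeSum(int l, int r) {
        return pre[r + 1] - pre[l];
    }
}
```

On Example 1, new PrefixSums(new int[]{2, 4, 8, 5, 9, 6}).rangeSum(2, 4) gives 22, the window sum used in the cost 9 * 3 - 22 = 5.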
import java.util.TreeMap;

class Solution {
    public int solve(int[] nums, int k) {
        if (nums.length == 0) return 0;
        int n = nums.length, l = 1, r = n;
        long[] ps = computePrefixSum(nums);
        // Invariant: length l is feasible and the answer lies in [l, r].
        while (l < r - 1) {
            int mid = l + (r - l) / 2;
            if (check(nums, ps, mid, k)) {
                l = mid;
            } else {
                r = mid - 1;
            }
        }
        if (check(nums, ps, r, k)) return r;
        return l;
    }

    // ps[i] = nums[0] + ... + nums[i]; long avoids overflow on large inputs.
    private long[] computePrefixSum(int[] nums) {
        long[] ps = new long[nums.length];
        ps[0] = nums[0];
        for (int i = 1; i < ps.length; i++) {
            ps[i] = ps[i - 1] + nums[i];
        }
        return ps;
    }

    // True if some window of length len needs at most op increments.
    private boolean check(int[] nums, long[] ps, int len, int op) {
        // Value -> frequency in the current window; lastKey() is the window max.
        TreeMap<Integer, Integer> tm = new TreeMap<>();
        for (int i = 0; i < len; i++) {
            tm.put(nums[i], tm.getOrDefault(nums[i], 0) + 1);
        }
        long cost = (long) tm.lastKey() * len - ps[len - 1];
        if (cost <= op) return true;
        for (int i = len; i < nums.length; i++) {
            // Slide: drop nums[i - len], add nums[i].
            tm.put(nums[i - len], tm.get(nums[i - len]) - 1);
            tm.put(nums[i], tm.getOrDefault(nums[i], 0) + 1);
            if (tm.get(nums[i - len]) == 0) {
                tm.remove(nums[i - len]);
            }
            cost = (long) tm.lastKey() * len - (ps[i] - ps[i - len]);
            if (cost <= op) return true;
        }
        return false;
    }
}
Solution 2. O(N * logN), sliding window + range max query sparse table + binary search
Since the input array is static and we only need range max queries, we can use a sparse table, which takes O(N * logN) preprocessing time and answers each range max query in O(1). This improves Solution 1 by a logN factor: each check now takes O(N) instead of O(N * logN), so the O(logN) checks cost O(N * logN) in total, the same as the preprocessing.
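The two formulas the sparse table relies on can be written out explicitly, using the same rangeMax and log names as the code. The table entry for a block of length 2^j combines two adjacent blocks of length 2^(j-1), and a query covers [L, R] with two blocks of length 2^j that may overlap, which is harmless for max. The worked query at the end uses Example 1's array.

```latex
\mathrm{rangeMax}[i][j] = \max\bigl(\mathrm{rangeMax}[i][j-1],\ \mathrm{rangeMax}[i + 2^{j-1}][j-1]\bigr)

\max(\mathit{nums}[L..R]) = \max\bigl(\mathrm{rangeMax}[L][j],\ \mathrm{rangeMax}[R - 2^{j} + 1][j]\bigr),
\qquad j = \lfloor \log_2 (R - L + 1) \rfloor

% Example: nums = [2, 4, 8, 5, 9, 6], query [1, 3] (length 3, so j = 1):
% max(rangeMax[1][1], rangeMax[2][1]) = max(max(4, 8), max(8, 5)) = 8
```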
class Solution {
    class RangeMaxSparseTable {
        int n, k;
        int[] log;
        int[][] rangeMax;

        RangeMaxSparseTable(int[] nums) {
            n = nums.length;
            log = new int[n + 1];
            log[1] = 0;
            for (int i = 2; i <= n; i++) {
                log[i] = log[i / 2] + 1;
            }
            k = log[n];
            rangeMax = new int[n][k + 1];
            for (int i = 0; i < n; i++) {
                rangeMax[i][0] = nums[i];
            }
            // rangeMax[i][j]: max in range [i, i + 2^j - 1] of length 2^j
            // rangeMax[i][j - 1]: max in range [i, i + 2^(j - 1) - 1] of length 2^(j - 1)
            // rangeMax[i + (1 << (j - 1))][j - 1]: max in range [i + 2^(j - 1), i + 2^j - 1] of length 2^(j - 1)
            for (int j = 1; j <= k; j++) {
                for (int i = 0; i + (1 << j) <= n; i++) {
                    rangeMax[i][j] = Math.max(rangeMax[i][j - 1], rangeMax[i + (1 << (j - 1))][j - 1]);
                }
            }
        }

        // Max of nums[L..R] from two (possibly overlapping) blocks of length 2^j.
        int query(int L, int R) {
            int j = log[R - L + 1];
            return Math.max(rangeMax[L][j], rangeMax[R - (1 << j) + 1][j]);
        }
    }

    private long[] ps;
    private RangeMaxSparseTable rmst;

    public int solve(int[] nums, int k) {
        if (nums.length == 0) return 0;
        ps = computePrefixSum(nums);
        rmst = new RangeMaxSparseTable(nums);
        int n = nums.length, l = 1, r = n;
        while (l < r - 1) {
            int mid = l + (r - l) / 2;
            if (check(nums, mid, k)) {
                l = mid;
            } else {
                r = mid - 1;
            }
        }
        if (check(nums, r, k)) return r;
        return l;
    }

    // ps[i] = nums[0] + ... + nums[i]; long avoids overflow.
    private long[] computePrefixSum(int[] nums) {
        long[] ps = new long[nums.length];
        ps[0] = nums[0];
        for (int i = 1; i < ps.length; i++) {
            ps[i] = ps[i - 1] + nums[i];
        }
        return ps;
    }

    // True if some window of length len needs at most op increments.
    private boolean check(int[] nums, int len, int op) {
        long cost = (long) rmst.query(0, len - 1) * len - ps[len - 1];
        if (cost <= op) return true;
        for (int i = len; i < nums.length; i++) {
            cost = (long) rmst.query(i - len + 1, i) * len - (ps[i] - ps[i - len]);
            if (cost <= op) return true;
        }
        return false;
    }
}
Solution 3. O(N * logN), sliding window + range max query sparse table
We can further simplify Solution 2 by replacing the binary search with a greedy sliding window. Extend the window's right bound one element at a time; whenever the cost exceeds k, advance the left bound until the cost fits within the limit again, then update the answer. After these two steps the window is the longest valid window whose right bound is the element currently processed, and because shrinking a window never increases its cost, the left bound never needs to move backward. The window scan itself is O(N), but this does not change the asymptotic runtime, as the sparse table preprocessing alone already takes O(N * logN) time.
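The claim that the left bound never moves backward follows from observation 1. Writing M(l, r) for the maximum of nums[l..r] (maxV in the code), the cost is a sum of non-negative terms, and dropping the leftmost element can only shrink each term and removes one of them:

```latex
\operatorname{cost}(l, r) = M(l, r)\,(r - l + 1) - \sum_{i=l}^{r} \mathit{nums}_i
                          = \sum_{i=l}^{r} \bigl(M(l, r) - \mathit{nums}_i\bigr)

\operatorname{cost}(l + 1, r) = \sum_{i=l+1}^{r} \bigl(M(l + 1, r) - \mathit{nums}_i\bigr)
  \le \sum_{i=l+1}^{r} \bigl(M(l, r) - \mathit{nums}_i\bigr)
  \le \operatorname{cost}(l, r)
```

By the same argument at the right end, cost(l, r + 1) >= cost(l, r). So every left bound that was too expensive for right bound r is still too expensive for r + 1, which is why each pointer advances at most N times.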
class Solution {
    class RangeMaxSparseTable {
        int n, k;
        int[] log;
        int[][] rangeMax;

        RangeMaxSparseTable(int[] nums) {
            n = nums.length;
            log = new int[n + 1];
            log[1] = 0;
            for (int i = 2; i <= n; i++) {
                log[i] = log[i / 2] + 1;
            }
            k = log[n];
            rangeMax = new int[n][k + 1];
            for (int i = 0; i < n; i++) {
                rangeMax[i][0] = nums[i];
            }
            for (int j = 1; j <= k; j++) {
                for (int i = 0; i + (1 << j) <= n; i++) {
                    rangeMax[i][j] = Math.max(rangeMax[i][j - 1], rangeMax[i + (1 << (j - 1))][j - 1]);
                }
            }
        }

        int query(int L, int R) {
            int j = log[R - L + 1];
            return Math.max(rangeMax[L][j], rangeMax[R - (1 << j) + 1][j]);
        }
    }

    private long[] ps;
    private RangeMaxSparseTable rmst;

    public int solve(int[] nums, int k) {
        int n = nums.length;
        if (n == 0) return 0;
        // ps[i] = nums[0] + ... + nums[i]; long avoids overflow.
        ps = new long[n];
        ps[0] = nums[0];
        for (int i = 1; i < n; i++) {
            ps[i] = ps[i - 1] + nums[i];
        }
        rmst = new RangeMaxSparseTable(nums);
        int left = 0, right = 0, ans = 0;
        while (right < n) {
            // Extend the window to include nums[right].
            int len = right - left + 1;
            int maxV = rmst.query(left, right);
            long cost = (long) maxV * len - ps[right] + (left > 0 ? ps[left - 1] : 0);
            // Shrink from the left until the cost fits within k.
            while (cost > k) {
                left++;
                len--;
                cost = (long) rmst.query(left, right) * len - ps[right] + (left > 0 ? ps[left - 1] : 0);
            }
            // Longest valid window ending at right.
            ans = Math.max(ans, len);
            right++;
        }
        return ans;
    }
}
Solution 4. O(N), Double-Ended Queue + Two Pointers + Sliding Window + Running Sum
Recall that a double-ended queue (deque) can find the maximum of every fixed-length sliding window in O(N) total time. The idea is that when a new number v arrives, we can safely discard all earlier numbers <= v, because they can never be the max of the current window or of any later window, since every such window also contains v. If we do this at the tail of the deque, its values stay strictly decreasing, so the max is always the head. When sliding the window, if the head's index is no longer inside the window, it is out of the current window and we remove it from the deque.
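For reference, here is a minimal sketch of that classic fixed-length sliding window maximum, using the same tail-popping rule as Solution 4. SlidingWindowMax, maxPerWindow and the window size w are hypothetical names; the sketch assumes 1 <= w <= n and returns an empty array otherwise.

```java
import java.util.ArrayDeque;

// Sketch of the fixed-length sliding window maximum (window size w).
// The deque holds indices whose values are strictly decreasing from head to tail.
class SlidingWindowMax {
    static int[] maxPerWindow(int[] nums, int w) {
        int n = nums.length;
        if (n == 0 || w <= 0 || w > n) return new int[0];
        int[] result = new int[n - w + 1];
        ArrayDeque<Integer> dq = new ArrayDeque<>();
        for (int i = 0; i < n; i++) {
            // Earlier values <= nums[i] can never be a window max again.
            while (!dq.isEmpty() && nums[dq.peekLast()] <= nums[i]) {
                dq.removeLast();
            }
            dq.addLast(i);
            // Only index i - w can have just left the window [i - w + 1, i].
            if (dq.peekFirst() <= i - w) {
                dq.removeFirst();
            }
            if (i >= w - 1) {
                result[i - w + 1] = nums[dq.peekFirst()];
            }
        }
        return result;
    }
}
```

With Example 1's array and w = 3, maxPerWindow returns [8, 8, 9, 9]. Each index enters and leaves the deque at most once, which is where the O(N) bound comes from.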
Here we leverage the same idea. The only difference is the removal criterion. In sliding window max, we remove the deque's head once its index falls outside the fixed-size window ending at the current element. Here we keep two pointers, left and right, marking the current subarray. After adding the current element to the window, we check whether the cost is too big; if it is, we keep shrinking the window's left bound until the cost fits within the limit. Because the deque does not hold every element of the window, while shrinking the left bound we must also keep removing the deque's head whenever its index is smaller than the left bound. The deque must never contain elements that are already outside the current valid window.
import java.util.ArrayDeque;

class Solution {
    public int solve(int[] nums, int k) {
        int n = nums.length;
        if (n == 0) return 0;
        // Indices with strictly decreasing values; the head is the window max.
        ArrayDeque<Integer> dq = new ArrayDeque<>();
        int left = 0, right = 0, ans = 0;
        long sum = 0; // sum of nums[left..right]; long avoids overflow
        while (right < n) {
            while (dq.size() > 0 && nums[dq.peekLast()] <= nums[right]) {
                dq.removeLast();
            }
            dq.addLast(right);
            sum += nums[right];
            // Shrink until max * len - sum fits within k.
            while ((long) nums[dq.peekFirst()] * (right - left + 1) - sum > k) {
                sum -= nums[left];
                left++;
                while (dq.size() > 0 && dq.peekFirst() < left) {
                    dq.removeFirst();
                }
            }
            ans = Math.max(ans, right - left + 1);
            right++;
        }
        return ans;
    }
}