Given an array nums of integers, we need to find the maximum possible sum of a subset of its elements such that the sum is divisible by three.
Example 1:
Input: nums = [3,6,5,1,8]
Output: 18
Explanation: Pick numbers 3, 6, 1 and 8; their sum is 18 (maximum sum divisible by 3).
Example 2:
Input: nums = [4]
Output: 0
Explanation: Since 4 is not divisible by 3, do not pick any number.
Example 3:
Input: nums = [1,2,3,4,4]
Output: 12
Explanation: Pick numbers 1, 3, 4 and 4; their sum is 12 (maximum sum divisible by 3).
Constraints:
1 <= nums.length <= 4 * 10^4
1 <= nums[i] <= 10^4
Incorrect greedy solution
If a number is divisible by 3, always add it to the final sum. Every other number leaves a remainder of 1 or 2 when divided by 3, and these numbers have to be combined so that their total is divisible by 3. One way to do this is to store the remainder-1 numbers and the remainder-2 numbers in two separate lists. Then, going from larger to smaller values, greedily take remainder-1 numbers three at a time and remainder-2 numbers three at a time, since any three numbers with the same remainder sum to a multiple of 3. However, this is incorrect. Consider the counterexample [2,6,2,2,7].
We would have {7} in the remainder-1 list and {2,2,2} in the remainder-2 list. The greedy solution picks all three 2s, which contribute 6 (12 in total together with the 6). Picking 7 and one 2 instead contributes 9 (15 in total), which is better. So the combination step cannot be limited to triples taken from a single list: it must also consider cross-list picks, such as one remainder-1 number together with one remainder-2 number.
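To make the counterexample concrete, here is a small Java sketch that implements the greedy exactly as described (keep multiples of 3, then take triples from each remainder list, largest first, never mixing the lists) and compares it with an exhaustive search over all subsets. The class GreedyCounterexample and its methods are hypothetical, written only for this illustration; the exhaustive search is exponential and only meant as a checker for tiny inputs.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical illustration class: flawed greedy vs. exhaustive search.
class GreedyCounterexample {

    // The greedy described above: add every multiple of 3, then take
    // remainder-1 and remainder-2 numbers in groups of three, largest
    // first, never combining numbers across the two lists.
    static int greedy(int[] nums) {
        int sum = 0;
        List<Integer> r1 = new ArrayList<>();
        List<Integer> r2 = new ArrayList<>();
        for (int x : nums) {
            if (x % 3 == 0) sum += x;
            else if (x % 3 == 1) r1.add(x);
            else r2.add(x);
        }
        r1.sort(Collections.reverseOrder());
        r2.sort(Collections.reverseOrder());
        return sum + sumOfTriples(r1) + sumOfTriples(r2);
    }

    // Sums the longest prefix whose length is a multiple of 3
    // (i.e. drops the smallest size % 3 elements of a descending list).
    static int sumOfTriples(List<Integer> sortedDesc) {
        int usable = sortedDesc.size() - sortedDesc.size() % 3;
        int s = 0;
        for (int k = 0; k < usable; k++) s += sortedDesc.get(k);
        return s;
    }

    // Tries all 2^n subsets; only feasible for very small n.
    static int bruteForce(int[] nums) {
        int n = nums.length;
        int best = 0; // the empty subset has sum 0, which is divisible by 3
        for (int mask = 0; mask < (1 << n); mask++) {
            int s = 0;
            for (int k = 0; k < n; k++) {
                if ((mask & (1 << k)) != 0) s += nums[k];
            }
            if (s % 3 == 0) best = Math.max(best, s);
        }
        return best;
    }

    public static void main(String[] args) {
        int[] nums = {2, 6, 2, 2, 7};
        System.out.println("greedy = " + greedy(nums));          // greedy = 12
        System.out.println("brute force = " + bruteForce(nums)); // brute force = 15
    }
}
```

On [2,6,2,2,7] the greedy stops at 12, while the best achievable sum is 15 (6 + 7 + 2).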
When greedy does not work, we should consider a dynamic programming solution.
dp[i][j]: the maximum sum of a subset of nums[0, i] whose sum % 3 == j.
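Depending on nums[i] % 3, we update dp[i][j] using dp[i - 1]: either skip nums[i] and keep dp[i - 1][j], or add nums[i] to the best subset of nums[0, i - 1] whose remainder is (j - nums[i] % 3 + 3) % 3. The final answer is dp[n - 1][0].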
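Written as a formula, the update can be sketched as follows, with r_i = nums[i] % 3 and diff named as in the Java code further down. This assumes the same convention as that code: for remainders 1 and 2, a value of 0 means that no subset with that remainder has been found yet.

```latex
r_i = \mathrm{nums}[i] \bmod 3, \qquad \mathit{diff} = (j + 3 - r_i) \bmod 3

dp[i][j] =
\begin{cases}
\max\bigl(dp[i-1][j],\; dp[i-1][\mathit{diff}] + \mathrm{nums}[i]\bigr)
  & \text{if } \mathit{diff} = 0 \text{ or } dp[i-1][\mathit{diff}] > 0,\\
dp[i-1][j] & \text{otherwise,}
\end{cases}

dp[0][r_0] = \mathrm{nums}[0], \qquad dp[0][j] = 0 \ \text{for } j \neq r_0
```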
The pitfall here is that, depending on nums[i] % 3 and the current remainder j, nums[i] may or may not be pickable. For example, if nums[i] % 3 == 1 and j == 2, we can pick nums[i] to get a bigger sum S' with S' % 3 == 2 only if dp[i - 1][1] > 0, meaning that in nums[0, i - 1] we have picked at least one number to get a sum S with S % 3 == 1. Since every nums[i] >= 1, a value of 0 in dp[i - 1][1] or dp[i - 1][2] can only mean that no such subset exists.
The state transition needs to correctly handle this logic.
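class Solution {
    public int maxSumDivThree(int[] nums) {
        int n = nums.length;
        // dp[i][j]: max sum of a subset of nums[0, i] with sum % 3 == j.
        // For j = 1 or 2, a value of 0 means "no such subset yet".
        int[][] dp = new int[n][3];
        // Base case: the only non-empty choice is nums[0] itself.
        dp[0][nums[0] % 3] = nums[0];
        for (int i = 1; i < n; i++) {
            int r = nums[i] % 3;
            for (int j = 0; j < 3; j++) {
                if (j == r) {
                    // Skip nums[i], or add it to a remainder-0 subset (possibly empty).
                    dp[i][j] = Math.max(dp[i - 1][j], dp[i - 1][0] + nums[i]);
                } else {
                    // Remainder the previous subset needs so that adding nums[i] gives j.
                    int diff = (j + 3 - r) % 3;
                    // diff != 0 here, so dp[i - 1][diff] > 0 means such a subset exists.
                    dp[i][j] = Math.max(dp[i - 1][j],
                            dp[i - 1][diff] > 0 ? dp[i - 1][diff] + nums[i] : 0);
                }
            }
        }
        return dp[n - 1][0];
    }
}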
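Because row i only reads row i - 1, the table can be reduced to a single row of three values. The sketch below is a hypothetical variant (the class name RollingMaxSumDivThree and the UNREACHABLE constant are illustrative, not from the original) that keeps the same skip-or-take logic but marks unreachable remainders with Integer.MIN_VALUE instead of relying on 0, which makes the pitfall described above explicit. It is written in "push" form: each reachable remainder j is extended by x into (j + x) % 3, which is equivalent to computing diff for every target remainder.

```java
import java.util.Arrays;

// Hypothetical space-optimized variant of the same DP.
class RollingMaxSumDivThree {

    // Explicit marker for "no subset with this remainder exists yet".
    static final int UNREACHABLE = Integer.MIN_VALUE;

    static int maxSumDivThree(int[] nums) {
        // Before any element: the empty subset gives remainder 0 with sum 0;
        // remainders 1 and 2 cannot be formed yet.
        int[] dp = {0, UNREACHABLE, UNREACHABLE};
        for (int x : nums) {
            int[] next = dp.clone(); // option 1: skip x
            for (int j = 0; j < 3; j++) {
                if (dp[j] == UNREACHABLE) continue; // nothing to extend
                int target = (j + x) % 3;           // remainder after adding x
                // option 2: take x on top of the best subset with remainder j
                next[target] = Math.max(next[target], dp[j] + x);
            }
            dp = next;
        }
        return dp[0];
    }

    public static void main(String[] args) {
        int[][] inputs = {{3, 6, 5, 1, 8}, {4}, {1, 2, 3, 4, 4}, {2, 6, 2, 2, 7}};
        for (int[] nums : inputs) {
            System.out.println(Arrays.toString(nums) + " -> " + maxSumDivThree(nums));
        }
    }
}
```

For the three examples and the counterexample [2,6,2,2,7] it prints 18, 0, 12 and 15. Time stays O(n) and extra memory drops from O(n) to O(1); with the given constraints the largest possible sum is 4 * 10^8, which still fits in an int.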