Consider a matrix M with dimensions width * height, such that every cell has value 0 or 1, and any square sub-matrix of M of size sideLength * sideLength has at most maxOnes ones.
Return the maximum possible number of ones that the matrix M can have.
Example 1:
Input: width = 3, height = 3, sideLength = 2, maxOnes = 1
Output: 4
Explanation:
In a 3*3 matrix, no 2*2 sub-matrix can have more than 1 one.
The best solution that has 4 ones is:
[1,0,1]
[0,0,0]
[1,0,1]
Example 2:
Input: width = 3, height = 3, sideLength = 2, maxOnes = 2
Output: 6
Explanation:
[1,0,1]
[1,0,1]
[1,0,1]
Constraints:
1 <= width, height <= 100
1 <= sideLength <= width, height
0 <= maxOnes <= sideLength * sideLength
Key Observation:
If we consider only the top-left sideLength * sideLength sub-matrix and place the maximum allowed number of 1s in it, every other non-overlapping sub-matrix is basically a replica of the top-left one. Thus we can divide the entire matrix into non-overlapping sub-matrices; the rightmost and bottom sub-matrices may have smaller dimensions if width or height is not divisible by sideLength.
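The replication idea can be checked on small inputs. The sketch below is illustrative only: the class TilingSketch and its two helper methods are hypothetical names, not part of the solution. It repeats a given top-left pattern across the matrix and then scans every sideLength * sideLength window by brute force. Because any such window contains each relative position (i % sideLength, j % sideLength) exactly once, every window should hold the same number of 1s as the pattern itself.

```java
class TilingSketch {
    // Builds a height x width matrix by repeating the square pattern
    // (assumed to be sideLength x sideLength) in both directions.
    static int[][] tile(int[][] pattern, int width, int height) {
        int s = pattern.length;
        int[][] m = new int[height][width];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                m[i][j] = pattern[i % s][j % s];
            }
        }
        return m;
    }

    // Brute force: returns the largest number of 1s in any s x s window of m.
    static int maxOnesInAnyWindow(int[][] m, int s) {
        int best = 0;
        for (int top = 0; top + s <= m.length; top++) {
            for (int left = 0; left + s <= m[0].length; left++) {
                int ones = 0;
                for (int i = top; i < top + s; i++) {
                    for (int j = left; j < left + s; j++) {
                        ones += m[i][j];
                    }
                }
                best = Math.max(best, ones);
            }
        }
        return best;
    }
}
```

Tiling the pattern {{1,0},{0,0}} over a 3*3 matrix reproduces the matrix of Example 1, and tiling {{1,0},{1,0}} reproduces the matrix of Example 2.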
Based on the above observation, we derive the following algorithm.
1. Divide the input matrix into non-overlapping sub-matrices.
2. For each cell (i, j) in the input matrix, map it to the same relative position in the top-left sub-matrix, (i % sideLength, j % sideLength). If we put a 1 in such a cell, we can also put a 1 in every other cell with the same relative position in the other sub-matrices without increasing the number of 1s in any sideLength * sideLength square, because every such square contains each relative position exactly once. Increment the contribution of 1s from the mapped position (i % sideLength, j % sideLength) by 1.
3. After iterating over all cells of the input matrix in step 2, we have, for each position of the top-left sub-matrix, the number of cells in M that map to it. Since we want to maximize the total number of 1s in M, we sort these counts in descending order and add up the first maxOnes counts as the final answer.
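As a worked example of steps 2 and 3, take Example 1 (width = 3, height = 3, sideLength = 2). The four relative positions (0, 0), (0, 1), (1, 0) and (1, 1) receive 4, 2, 2 and 1 cells respectively. Sorted in descending order the counts are 4, 2, 2, 1, so maxOnes = 1 gives 4 and maxOnes = 2 gives 4 + 2 = 6, matching both examples. The small helper below (CountTableSketch is a hypothetical name used only for illustration) builds that table as a 2D array so it is easy to inspect.

```java
class CountTableSketch {
    // cnt[x][y] = number of cells (i, j) of the matrix with
    // i % sideLength == x and j % sideLength == y.
    static int[][] countTable(int width, int height, int sideLength) {
        int[][] cnt = new int[sideLength][sideLength];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                cnt[i % sideLength][j % sideLength]++;
            }
        }
        return cnt;
    }
}
```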
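import java.util.Arrays;
import java.util.Comparator;

class Solution {
    public int maximumNumberOfOnes(int width, int height, int sideLength, int maxOnes) {
        // cnt[x * sideLength + y] counts the cells of M that map to (x, y).
        // Integer[] rather than int[] so the array can be sorted with a Comparator.
        Integer[] cnt = new Integer[sideLength * sideLength];
        Arrays.fill(cnt, 0);
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                int x = i % sideLength;
                int y = j % sideLength;
                cnt[x * sideLength + y]++;
            }
        }
        // Take the maxOnes positions with the largest counts.
        Arrays.sort(cnt, Comparator.reverseOrder());
        int total = 0;
        for (int i = 0; i < maxOnes; i++) {
            total += cnt[i];
        }
        return total;
    }
}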
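The counting loop visits every cell, which is fine for the stated limits (at most 100 * 100 cells). As a possible variant, not the approach taken above, the count for each relative position can be computed directly: position (x, y) is hit by ceil((height - x) / sideLength) rows and ceil((width - y) / sideLength) columns, and its count is the product of the two. The sketch below assumes the problem constraints (sideLength <= width, height); the class name ClosedFormSketch is hypothetical.

```java
import java.util.Arrays;

class ClosedFormSketch {
    static int maximumNumberOfOnes(int width, int height, int sideLength, int maxOnes) {
        int[] cnt = new int[sideLength * sideLength];
        for (int x = 0; x < sideLength; x++) {
            // Number of rows i in [0, height) with i % sideLength == x.
            int rows = (height - x + sideLength - 1) / sideLength;
            for (int y = 0; y < sideLength; y++) {
                // Number of columns j in [0, width) with j % sideLength == y.
                int cols = (width - y + sideLength - 1) / sideLength;
                cnt[x * sideLength + y] = rows * cols;
            }
        }
        // Sort ascending, then sum the maxOnes largest counts from the end.
        Arrays.sort(cnt);
        int total = 0;
        for (int k = 0; k < maxOnes; k++) {
            total += cnt[cnt.length - 1 - k];
        }
        return total;
    }
}
```

Using a primitive int[] sorted in ascending order avoids boxing; the largest counts are then read from the end of the array.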