Today let's discuss LeetCode problem 378, Kth Smallest Element in a Sorted Matrix. Problem link: https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/
The example from the problem statement is as follows:
Input: matrix = [
[1,5,9],
[10,11,13],
[12,13,15]
], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13
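Before looking for something faster, it helps to have a baseline that does exactly what the explanation says: flatten the matrix, sort it, and take the kth entry. The sketch below is only a reference check, not the approach this post builds toward, and the class name BruteForceKth is made up for illustration.

```java
import java.util.Arrays;

class BruteForceKth {
    // Copies every element into one flat array, sorts it, and returns the kth smallest (k is 1-based).
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int[] flat = new int[n * n];
        int idx = 0;
        for (int[] row : matrix) {
            for (int value : row) {
                flat[idx++] = value;
            }
        }
        Arrays.sort(flat);
        return flat[k - 1];
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallest(matrix, 8)); // prints 13
    }
}
```

This costs O(n^2 log n) time and O(n^2) extra space, and it ignores the fact that the rows and columns are already sorted.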
A simple observation: the first element, matrix[0][0], is the smallest number in the matrix (low), and the last element, matrix[n-1][n-1], where n is the length of the matrix, is the largest (high).
The question asks for the number at position k in sorted order, i.e. the one with k-1 elements ahead of it (duplicates counted separately). Since the matrix is sorted row-wise and column-wise, for any candidate value we can count how many numbers in each row are less than or equal to it.
Summing those per-row counts tells us whether the candidate falls before or after the kth position, and that is what leads us to the answer.
In any row, the smallest element is in column 0 and the largest is in column n-1.
In any column, the smallest element is in row 0 and the largest is in row n-1.
Within a row, the number of elements that come before the current element (and so are no larger than it) equals its 0-based column index; counting columns from 1, that is the column number minus 1.
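To make the per-row counting concrete, here is a minimal sketch that counts how many elements are less than or equal to a candidate value by running a binary search on each sorted row and summing the results. The names RowCounts, countInRow and countAtMost are made up for illustration, and this is a different counting technique from the bottom-left walk used in the solution further down.

```java
class RowCounts {
    // Number of elements in one sorted row that are <= target.
    // The index of the first element greater than target equals that count.
    static int countInRow(int[] row, int target) {
        int lo = 0;
        int hi = row.length;
        while (lo < hi) {
            int mid = lo + (hi - lo) / 2;
            if (row[mid] <= target) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    // Sum of the per-row counts: how many elements of the whole matrix are <= target.
    static int countAtMost(int[][] matrix, int target) {
        int total = 0;
        for (int[] row : matrix) {
            total += countInRow(row, target);
        }
        return total;
    }
}
```

For the example matrix, countAtMost(matrix, 12) gives 6 and countAtMost(matrix, 13) gives 8, so 13 is the first value that reaches k = 8.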
Since the array is sorted, can we try binary search? Let's see.
If we take the midpoint of low and high, mid = low + (high - low) / 2, it may or may not be present in the matrix, but we can still count how many elements are less than or equal to mid.
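A small aside on why the midpoint is written this way. The problem allows negative values, and Java's integer division truncates toward zero, so the two common midpoint formulas can disagree when the bounds are negative. The sketch below compares them; MidpointDemo and its method names are hypothetical helpers for illustration only.

```java
class MidpointDemo {
    // Midpoint as written above: always rounds toward low,
    // so mid < high whenever low < high.
    static int midTowardLow(int low, int high) {
        return low + (high - low) / 2;
    }

    // Naive midpoint: for a negative sum, truncation toward zero
    // rounds the result up toward high instead.
    static int midNaive(int low, int high) {
        return (low + high) / 2;
    }

    public static void main(String[] args) {
        System.out.println(midTowardLow(-3, -2)); // prints -3
        System.out.println(midNaive(-3, -2));     // prints -2
    }
}
```

With low = -3 and high = -2, the naive formula returns -2, which equals high. A loop that then sets high = mid would make no progress, whereas the first form always leaves room to shrink the range.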
If fewer than k elements are less than or equal to mid, the answer must be larger than mid, so we raise the lower bound to mid + 1 to take in more numbers; otherwise we lower the upper bound to mid to drop a few.
As the binary search keeps narrowing the range, low and high meet at the smallest value that has at least k elements less than or equal to it. That value is the required answer, and it is always a number from the grid, because the count only changes at values that actually occur in the matrix.
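To see the narrowing happen on the example (k = 8), here is a small tracing sketch. It uses a plain double loop for the count so the trace is easy to follow, which is an assumption made for readability and not the counting method of the final solution; the class name BinarySearchTrace is also made up.

```java
import java.util.ArrayList;
import java.util.List;

class BinarySearchTrace {
    // Plain O(n^2) count of elements <= target, kept simple on purpose.
    static int countAtMost(int[][] matrix, int target) {
        int count = 0;
        for (int[] row : matrix) {
            for (int value : row) {
                if (value <= target) {
                    count++;
                }
            }
        }
        return count;
    }

    // Runs the value-range binary search and records each step as a line of text.
    static List<String> trace(int[][] matrix, int k) {
        int n = matrix.length;
        int low = matrix[0][0];
        int high = matrix[n - 1][n - 1];
        List<String> steps = new ArrayList<>();
        while (low < high) {
            int mid = low + (high - low) / 2;
            int count = countAtMost(matrix, mid);
            steps.add("low=" + low + " high=" + high + " mid=" + mid + " count=" + count);
            if (count < k) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        steps.add("answer=" + low);
        return steps;
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        for (String step : trace(matrix, 8)) {
            System.out.println(step);
        }
        // prints:
        // low=1 high=15 mid=8 count=2
        // low=9 high=15 mid=12 count=6
        // low=13 high=15 mid=14 count=8
        // low=13 high=14 mid=13 count=8
        // answer=13
    }
}
```

Notice that mid takes the values 8 and 14 along the way, neither of which is in the matrix, yet the search still settles on 13.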
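class Solution {
    int[][] matrix = null;

    public int kthSmallest(int[][] matrix, int k) {
        this.matrix = matrix;
        int n = matrix.length;
        int low = matrix[0][0];          // smallest value in the matrix
        int high = matrix[n - 1][n - 1]; // largest value in the matrix
        // Binary search over the value range, not over indices.
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (lowerThanMid(mid) < k) {
                // Fewer than k elements are <= mid, so the answer is above mid.
                low = mid + 1;
            } else {
                // At least k elements are <= mid, so mid may still be the answer.
                high = mid;
            }
        }
        return low;
    }

    // Counts the elements that are <= target, walking from the bottom-left corner.
    int lowerThanMid(int target) {
        int n = matrix.length;
        int i = matrix.length - 1;
        int j = 0;
        int lower = 0;
        while (i >= 0 && j < n) {
            if (matrix[i][j] > target) {
                i--; // everything to the right in this row is larger too
            } else {
                // Rows 0..i of column j are all <= target.
                j++;
                lower += i + 1;
            }
        }
        return lower;
    }
}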