Open In App

K Closest Points to the Origin

Last Updated : 18 Oct, 2025
Comments
Improve
Suggest changes
17 Likes
Like
Report

Given a 2D array points[][] and an integer k, where each element of points represents a point [xi, yi] on the X-Y plane, find the k points that are closest to the origin (0,0) in any order.
Note: The distance between two points on a plane is the Euclidean distance; for a point $(x_i, y_i)$, its distance from the origin is $\sqrt{x_i^2 + y_i^2}$.

Examples: 

Input: k = 2, points = [[1, 3], [-2, 2], [5, 8], [0, 1]]
Output: [[-2, 2], [0, 1]]
Explanation: The Euclidean distances from the origin are:
Point (1, 3) = sqrt(10)
Point (-2, 2) = sqrt(8)
Point (5, 8) = sqrt(89)
Point (0, 1) = sqrt(1)
The two closest points to the origin are [-2, 2] and [0, 1].

Input: k = 1, points = [[2, 4], [-1, -1], [0, 0]]
Output: [[0, 0]]
Explanation: The Euclidean distances from the origin are:
Point (2, 4) = sqrt(20)
Point (-1, -1) = sqrt(2)
Point (0, 0) = sqrt(0)
The closest point to origin is [0, 0].

[Naive Approach] Using Sorting - O(n*log n) Time and O(1) Space

The idea is to calculate the squared distance of every point from the origin and then sort the points by these distances. Since comparing squared distances gives the same order as comparing actual distances, we don’t need to compute square roots. After sorting, we simply take the first k points, which are the ones closest to the origin.

C++
//Driver Code Starts
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
//Driver Code Ends


// Function to calculate squared distance from the origin.
// Comparing squared distances orders points exactly like comparing
// Euclidean distances, so no sqrt is needed. The point is taken by const
// reference: it is only read, and const or temporary points can be passed.
// Assumes coordinates are small enough that x*x + y*y fits in an int.
static int squaredDis(const vector<int>& point) {
    return point[0] * point[0] + point[1] * point[1];
}

// Comparator for std::sort: true when p1 is strictly closer to the origin
// than p2. The standard requires a comparator to accept (possibly const)
// elements, so both parameters are const references.
static bool cmp(const vector<int>& p1, const vector<int>& p2) {
    return squaredDis(p1) < squaredDis(p2);
}

// Function to find k closest points.
// Sorts `points` in place by distance and returns a copy of the first k.
// k is clamped to [0, points.size()] so an out-of-range k can never build
// an iterator past the end of the vector.
vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
    sort(points.begin(), points.end(), cmp);

    int count = max(0, min(k, static_cast<int>(points.size())));
    return vector<vector<int>>(points.begin(), points.begin() + count);
}


//Driver Code Starts
int main() {
    vector<vector<int>> points = {{1, 3}, {-2, 2}, {5, 8}, {0, 1}};
    int k = 2;
    
    vector<vector<int>> res = kClosest(points, k);
    
    for (vector<int> point : res) {
        cout << point[0] << ", " << point[1];
        cout<<endl;
    }

    return 0;
}

//Driver Code Ends
Java
//Driver Code Starts
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;

class GFG {
//Driver Code Ends


    // Function to calculate squared distance from the origin
    static int squaredDis(int[] point) {
        return point[0] * point[0] + point[1] * point[1];
    }

    // Function to find k closest points
    static ArrayList<ArrayList<Integer>> kClosest(int[][] points, int k) {
        Arrays.sort(points, new Comparator<int[]>() {
            public int compare(int[] p1, int[] p2) {
                return squaredDis(p1) - squaredDis(p2);
            }
        });

        // Convert first k points to ArrayList<ArrayList<Integer>>
        ArrayList<ArrayList<Integer>> res = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            ArrayList<Integer> temp = new ArrayList<>();
            temp.add(points[i][0]);
            temp.add(points[i][1]);
            res.add(temp);
        }
        return res;
    }


//Driver Code Starts
    public static void main(String[] args) {
        int[][] points = {{1, 3}, {-2, 2}, {5, 8}, {0, 1}};
        int k = 2;
        
        ArrayList<ArrayList<Integer>> res = kClosest(points, k);
        
        for (ArrayList<Integer> point : res) {
            System.out.println(point.get(0) + ", " + point.get(1));
        }
    }
}

//Driver Code Ends
Python
def squaredDis(point):
    """Return x*x + y*y for a point given as a sequence [x, y].

    This is the squared Euclidean distance from the origin. It orders
    points exactly like the true distance, so callers can skip the sqrt.
    """
    x, y = point[0], point[1]
    return x * x + y * y

def cmp(p1, p2):
    """Three-way comparator on squared distance from the origin.

    Returns a negative number when p1 is closer than p2, zero on a tie and
    a positive number otherwise (the shape functools.cmp_to_key expects).
    """
    d1 = squaredDis(p1)
    d2 = squaredDis(p2)
    return d1 - d2

def kClosest(points, k):
    """Return the k points of `points` closest to the origin.

    Sorts `points` in place by squared distance (the caller's list is
    reordered) and returns the first k entries as a new list. A k larger
    than len(points) returns every point. k <= 0 returns an empty list,
    whereas a bare ``points[:k]`` with negative k would instead return all
    but the last -k points.
    """
    # squaredDis already is a one-argument key function; no lambda needed.
    points.sort(key=squaredDis)
    return points[:max(k, 0)]


if __name__ == "__main__":
#Driver Code Starts
    # Sample input: four points, ask for the two nearest the origin.
    points, k = [[1, 3], [-2, 2], [5, 8], [0, 1]], 2

    res = kClosest(points, k)

    # One "x, y" line per selected point.
    for p in res:
        print(p[0], p[1], sep=", ")

#Driver Code Ends
C#
//Driver Code Starts
using System;
using System.Collections.Generic;

class GFG {
//Driver Code Ends

    
    // Function to calculate squared distance from the origin.
    // Squared distances order points exactly like Euclidean distances,
    // so no square root is needed. Assumes x*x + y*y fits in an int.
    private static int squaredDis(List<int> point) {
        return point[0] * point[0] + point[1] * point[1];
    }

    // Comparator function: negative when p1 is closer to the origin than
    // p2, zero on a tie, positive otherwise. CompareTo is the idiomatic
    // three-way comparison and, unlike subtraction, cannot overflow.
    private static int cmp(List<int> p1, List<int> p2) {
        return squaredDis(p1).CompareTo(squaredDis(p2));
    }

    // Function to find k closest points.
    // Copies each row of `points` into a list, sorts that list by distance
    // and returns its first k entries. k is clamped to [0, number of points]
    // because List.GetRange throws when asked for a range outside the list.
    internal static List<List<int>> kClosest(int[,] points, int k) {
        int n = points.GetLength(0);
        List<List<int>> pointList = new List<List<int>>(n);
        for (int i = 0; i < n; i++)
        {
            pointList.Add(new List<int> { points[i, 0], points[i, 1] });
        }

        pointList.Sort(cmp);

        int count = Math.Max(0, Math.Min(k, pointList.Count));
        return pointList.GetRange(0, count);
    }


//Driver Code Starts
    static void Main() {
        int[,] points = new int[,]
        {
            {1, 3},
            {-2, 2},
            {5, 8},
            {0, 1}
        };

        int k = 2;
        
        List<List<int>> res = kClosest(points, k);
        
        foreach (List<int> point in res)
        {
            Console.WriteLine(point[0] + ", " + point[1]);
        }
    }
}

//Driver Code Ends
JavaScript
/**
 * Squared Euclidean distance of a point [x, y] from the origin.
 * Ordering by this value matches ordering by the true distance.
 */
function squaredDis(point) {
    const x = point[0];
    const y = point[1];
    return x * x + y * y;
}

/**
 * Comparator for Array.prototype.sort: negative when p1 is closer to the
 * origin than p2, zero on a tie, positive otherwise.
 */
function cmp(p1, p2) {
    const d1 = squaredDis(p1);
    const d2 = squaredDis(p2);
    return d1 - d2;
}

/**
 * Returns the k points closest to the origin.
 * Sorts `points` in place by squared distance and returns a new array with
 * the first k entries. k is clamped at 0: slice(0, k) with a negative k
 * would otherwise drop points from the end instead of returning none.
 */
function kClosest(points, k) {
    points.sort(cmp);
    return points.slice(0, Math.max(k, 0));
}


// Driver Code
//Driver Code Starts
const points = [[1, 3], [-2, 2], [5, 8], [0, 1]];
const k = 2;

let res = kClosest(points, k);

// Print each selected point as "x, y", one per line.
res.forEach((point) => console.log(`${point[0]}, ${point[1]}`));

//Driver Code Ends

Output
0, 1
-2, 2

[Alternate Approach] Using Quick Select - O(n^2) Worst-Case Time (O(n) on Average) and O(n) Space

The idea is to use QuickSort’s partitioning. We pick one point as the pivot (usually the last one), move every point that is not farther than the pivot to its left and every farther point to its right, which places the pivot in its correct position. Then we count the points on the left side, including the pivot. If there are exactly k, they are the k closest. If there are more than k, we continue searching only in the left side. If there are fewer, we keep all of them and search for the remaining points in the right side, reducing k by the number of points already taken.

C++
//Driver Code Starts
#include <iostream>
#include <vector>
using namespace std;
//Driver Code Ends


// Function to calculate squared distance from the origin.
// The point is taken by const reference: it is only read, and const or
// temporary points can be passed. Assumes x*x + y*y fits in an int.
int squaredDis(const vector<int>& point) {
    return point[0] * point[0] + point[1] * point[1];
}

// Lomuto partition of points[left..right] by squared distance, using
// points[right] as the pivot. Afterwards every point before the returned
// index has squared distance <= the pivot's, every point after it has a
// strictly larger one, and the pivot itself sits at the returned index.
int partition(vector<vector<int>>& points, int left, int right) {
    
    // Last element is chosen as a pivot. It is never moved inside the
    // loop (j < right and i <= j), so its distance is computed once and
    // the pivot vector is not copied.
    int pivotDist = squaredDis(points[right]);
    int i = left;

    for (int j = left; j < right; j++) {
      
        // Elements less than or equal to the pivot
        // are placed on the left side of the pivot
        if (squaredDis(points[j]) <= pivotDist) {
            swap(points[i], points[j]);
            i++;
        }
    }
    swap(points[i], points[right]);

    return i;
}

// Rearranges points[left..right] so that, for 1 <= k <= range size, its
// k closest points occupy the first k slots of that range (in no
// particular order). Each pass partitions the current range and keeps
// only the side that still contains the selection boundary.
void quickSelect(vector<vector<int>>& points, int left, int right, int k) {
    while (left <= right) {
        int pivotIdx = partition(points, left, right);

        // Number of points from `left` up to and including the pivot.
        int leftCnt = pivotIdx - left + 1;

        if (leftCnt == k) {
            return;
        }

        if (leftCnt > k) {
            // Too many points: the boundary lies left of the pivot.
            right = pivotIdx - 1;
        } else {
            // Too few: keep them all and find the rest on the right.
            left = pivotIdx + 1;
            k -= leftCnt;
        }
    }
}

// Function to return k closest points to the origin.
// Reorders `points` in place with quick select and returns a copy of the
// first k entries (in no particular order). k <= 0 yields an empty result
// and k is capped at points.size(), so the returned range never extends
// past the end of the vector (which would be undefined behaviour).
vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
    int n = static_cast<int>(points.size());
    if (k <= 0) {
        return {};
    }
    if (k > n) {
        k = n;
    }

    quickSelect(points, 0, n - 1, k);

    return vector<vector<int>>(points.begin(), points.begin() + k);
}


//Driver Code Starts
int main() {
    vector<vector<int>> input = {{1, 3}, {-2, 2}, {5, 8}, {0, 1}};

    // Ask for the two points nearest the origin and print each as "x, y".
    for (const vector<int>& p : kClosest(input, 2)) {
        cout << p[0] << ", " << p[1] << endl;
    }

    return 0;
}
//Driver Code Ends
Java
//Driver Code Starts
import java.util.ArrayList;

class GFG {
//Driver Code Ends

    
    // Function to return k closest points to the origin
    static int squaredDis(ArrayList<Integer> point) {
        return point.get(0) * point.get(0) + point.get(1) * point.get(1);
    }

    static int partition(int[][] points, int left, int right) {

        // Last element is chosen as a pivot.
        int[] pivot = points[right];
        int i = left;

        for (int j = left; j < right; j++) {

            // Elements greater than or equal to pivot
            // are placed in the left side of pivot
            int[] current = points[j];
            if ((current[0]*current[0] + current[1]*current[1]) <= (pivot[0]*pivot[0] + pivot[1]*pivot[1])) {
                int[] temp = points[i];
                points[i] = points[j];
                points[j] = temp;
                i++;
            }
        }
        int[] temp = points[i];
        points[i] = points[right];
        points[right] = temp;

        return i;
    }

    static void quickSelect(int[][] points, int left, int right, int k) {
        if (left <= right) {
            int pivotIdx = partition(points, left, right);

            // Count of all elements in the left part
            int leftCnt = pivotIdx - left + 1;

            if (leftCnt == k)
                return;

            // Search in the left subarray
            if (leftCnt > k)
                quickSelect(points, left, pivotIdx - 1, k);

            // Reduce the k by number of elements already covered
            // and search in the right subarray
            else
                quickSelect(points, pivotIdx + 1, right, k - leftCnt);
        }
    }

    // Function to return k closest points to the origin
    static ArrayList<ArrayList<Integer>> kClosest(int[][] points, int k) {

        quickSelect(points, 0, points.length - 1, k);

        ArrayList<ArrayList<Integer>> res = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            ArrayList<Integer> temp = new ArrayList<>();
            temp.add(points[i][0]);
            temp.add(points[i][1]);
            res.add(temp);
        }
        return res;
    }


//Driver Code Starts
    public static void main(String[] args) {
        int[][] points = {{1, 3}, {-2, 2}, {5, 8}, {0, 1}};
        int k = 2;

        ArrayList<ArrayList<Integer>> res = kClosest(points, k);

        for (ArrayList<Integer> point : res) {
            System.out.println(point.get(0) + ", " + point.get(1));
        }
    }
}

//Driver Code Ends
Python
def squaredDis(point):
    """Squared distance of point [x, y] from the origin, i.e. x*x + y*y."""
    x = point[0]
    y = point[1]
    return x * x + y * y


def partition(points, left, right):
    """Lomuto-partition points[left..right] around points[right].

    Reorders that range in place. Points whose squared distance is <= the
    pivot's come first, then the pivot, then points with a strictly larger
    distance. Returns the pivot's final index.
    """
    # Last element is chosen as the pivot. It is never moved inside the
    # loop, so its distance is computed once instead of every iteration.
    pivotDist = squaredDis(points[right])
    i = left

    for j in range(left, right):

        # Elements less than or equal to the pivot
        # are placed on the left side of the pivot
        if squaredDis(points[j]) <= pivotDist:
            points[i], points[j] = points[j], points[i]
            i += 1

    points[i], points[right] = points[right], points[i]

    return i


def quickSelect(points, left, right, k):
    """Move the k closest points of points[left..right] to its front.

    Repeatedly partitions the current range and narrows to the side that
    still contains the boundary between selected and unselected points.
    For 1 <= k <= right - left + 1, the selected points end up in
    points[left:left + k] in no particular order.
    """
    while left <= right:
        pivotIdx = partition(points, left, right)

        # Points from `left` up to and including the pivot.
        leftCnt = pivotIdx - left + 1

        if leftCnt == k:
            return
        if leftCnt > k:
            # Boundary lies left of the pivot.
            right = pivotIdx - 1
        else:
            # Keep these leftCnt points; find the rest on the right.
            left = pivotIdx + 1
            k -= leftCnt

def kClosest(points, k):
    """Return the k points closest to the origin, in no particular order.

    Reorders `points` in place. A k larger than len(points) returns every
    point. k <= 0 returns an empty list, whereas plain slicing with a
    negative k would return all but the last -k points.
    """
    if k <= 0:
        return []

    quickSelect(points, 0, len(points) - 1, k)

    return points[:k]



if __name__ == "__main__":
#Driver Code Starts
    points = [[1, 3], [-2, 2], [5, 8], [0, 1]]
    k = 2

    res = kClosest(points, k)

    # Print each selected point as "x, y", one per line.
    for x, y in ((p[0], p[1]) for p in res):
        print(f"{x}, {y}")

#Driver Code Ends
C#
//Driver Code Starts
using System;
using System.Collections.Generic;

class GFG {
//Driver Code Ends

    
    // Function to calculate squared distance from the origin
    // for row `idx` of `point`. Assumes x*x + y*y fits in an int.
    private static int squaredDis(int[,] point, int idx)
    {
        return point[idx, 0] * point[idx, 0] + point[idx, 1] * point[idx, 1];
    }

    // Swaps rows i and j of the n x 2 array `points`.
    private static void Swap(int[,] points, int i, int j)
    {
        int tempX = points[i, 0], tempY = points[i, 1];
        points[i, 0] = points[j, 0];
        points[i, 1] = points[j, 1];
        points[j, 0] = tempX;
        points[j, 1] = tempY;
    }

    // Lomuto partition of rows left..right using row `right` as the pivot:
    // afterwards rows before the returned index are no farther than the
    // pivot, rows after it are strictly farther, and the pivot sits at the
    // returned index.
    private static int partition(int[,] points, int left, int right)
    {
        // Last element is chosen as a pivot. It is not moved inside the
        // loop, so its distance is computed only once.
        int distPivot = squaredDis(points, right);
        int i = left;

        for (int j = left; j < right; j++)
        {
            // Points no farther than the pivot move to the left side.
            if (squaredDis(points, j) <= distPivot)
            {
                Swap(points, i, j);
                i++;
            }
        }

        // Swap pivot to its correct position
        Swap(points, i, right);

        return i;
    }

    // Moves the k closest rows of points[left..right] to the front of that
    // range (in no particular order).
    private static void quickSelect(int[,] points, int left, int right, int k)
    {
        if (left <= right)
        {
            int pivotIdx = partition(points, left, right);
            
            // Count of all elements in the left part
            int leftCnt = pivotIdx - left + 1;

            if (leftCnt == k)
                return;
                            
            // Search in the left subarray
            if (leftCnt > k)
                quickSelect(points, left, pivotIdx - 1, k);
                            
            // Reduce the k by number of elements already covered
            // and search in the right subarray
            else
                quickSelect(points, pivotIdx + 1, right, k - leftCnt);
        }
    }
    
    // Function to return k closest points to the origin.
    // Reorders the rows of `points` in place. k is clamped to
    // [0, number of rows] so a k larger than the input no longer
    // throws IndexOutOfRangeException while copying the result.
    internal static List<List<int>> kClosest(int[,] points, int k)
    {
        int n = points.GetLength(0);
        int count = Math.Max(0, Math.Min(k, n));

        if (count > 0)
            quickSelect(points, 0, n - 1, count);

        List<List<int>> res = new List<List<int>>(count);
        for (int i = 0; i < count; i++)
        {
            res.Add(new List<int> { points[i, 0], points[i, 1] });
        }
        return res;
    }


//Driver Code Starts
    static void Main()
    {
        int[,] points = new int[,]
        {
            { 1, 3 },
            { -2, 2 },
            { 5, 8 },
            { 0, 1 }
        };
        int k = 2;

        List<List<int>> res = kClosest(points, k);

        foreach (List<int> point in res)
        {
            Console.WriteLine(point[0] + ", " + point[1]);
        }
    }
}

//Driver Code Ends
JavaScript
// Squared Euclidean distance of point [x, y] from the origin. It orders
// points the same way as the true distance, without a square root.
function squaredDis(point) {
    const [x, y] = [point[0], point[1]];
    return x * x + y * y;
}

/**
 * Lomuto partition of points[left..right] around points[right].
 * Afterwards points before the returned index are no farther from the
 * origin than the pivot, points after it are strictly farther, and the
 * pivot sits at the returned index.
 */
function partition(points, left, right) {

    // Last element is chosen as the pivot. It is not moved inside the
    // loop, so its distance is computed only once.
    const pivotDist = squaredDis(points[right]);
    let i = left;

    for (let j = left; j < right; j++) {

        // Elements less than or equal to the pivot
        // are placed on the left side of the pivot
        if (squaredDis(points[j]) <= pivotDist) {
            [points[i], points[j]] = [points[j], points[i]];
            i++;
        }
    }

    [points[i], points[right]] = [points[right], points[i]];

    return i;
}

// Moves the k closest points of points[left..right] to the front of that
// range (in no particular order), narrowing the range after each partition
// to the side that still contains the selection boundary.
function quickSelect(points, left, right, k) {
    while (left <= right) {
        const pivotIdx = partition(points, left, right);

        // Points from `left` up to and including the pivot
        const leftCnt = pivotIdx - left + 1;

        if (leftCnt === k) {
            return;
        }

        if (leftCnt > k) {
            // Boundary lies left of the pivot
            right = pivotIdx - 1;
        } else {
            // Keep these points and look for the rest on the right
            left = pivotIdx + 1;
            k -= leftCnt;
        }
    }
}

/**
 * Returns the k points closest to the origin (in no particular order).
 * Reorders `points` in place. k <= 0 returns [], because slice(0, k) with a
 * negative k would instead return all but the last -k points.
 */
function kClosest(points, k) {
    if (k <= 0) {
        return [];
    }
    quickSelect(points, 0, points.length - 1, k);
    return points.slice(0, k);
}


// Driver Code
//Driver Code Starts
let points = [[1, 3], [-2, 2], [5, 8], [0, 1]];
let k = 2;

let res = kClosest(points, k);

// Output one "x, y" line per selected point.
for (const [x, y] of res.map((p) => [p[0], p[1]])) {
    console.log(`${x}, ${y}`);
}

//Driver Code Ends

Output
0, 1
-2, 2

[Expected Approach] Using Priority Queue (Max-Heap) - O(n*log k) Time and O(k) Space

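The idea is to use a priority queue (max-heap) that holds at most k points, keyed by their squared distance from the origin, so the farthest of the kept points is always at the top. While iterating over the array, we insert each point until the heap holds k points. After that, a new point replaces the top only if it is closer than the top point. At the end, the heap contains the k closest points.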

C++
//Driver Code Starts
#include <iostream>
#include <vector>
#include <queue>
#include <cmath>

using namespace std;
//Driver Code Ends


// Function to calculate squared distance from the origin.
// The point is taken by const reference: it is only read, and const or
// temporary points can be passed. Assumes x*x + y*y fits in an int.
int squaredDis(const vector<int>& point) {
    return point[0] * point[0] + 
      point[1] * point[1];
}

// Function to find k closest points to
// the origin.
// Keeps a max-heap of at most k (squared distance, point) pairs, so the
// top is always the farthest point kept so far; a new point replaces it
// only when strictly closer. The result lists the kept points in pop
// order, i.e. in decreasing order of distance.
// For k <= 0 the result is empty; the guard also prevents calling top()
// on an empty heap, which is undefined behaviour.
vector<vector<int>> kClosest(
        vector<vector<int>>& points, int k) {

    vector<vector<int>> res;
    if (k <= 0) {
        return res;
    }

    // Max heap to store points with their 
    // squared distances
    priority_queue<pair<int, vector<int>>> maxHeap;

    // Iterate through each point
    for (vector<int>& point : points) {
        int dist = squaredDis(point);

        // Cast the size so the comparison with k is signed vs signed.
        if (static_cast<int>(maxHeap.size()) < k) {
            
            // If the heap size is less than k, 
            // insert the point
            maxHeap.push({dist, point});
        } else if (dist < maxHeap.top().first) {
            
            // Heap is full: replace the top element if the
            // current point is closer
            maxHeap.pop();
            maxHeap.push({dist, point});
        }
    }

    // Take the k closest points from the heap
    while (!maxHeap.empty()) {
        res.push_back(maxHeap.top().second);
        maxHeap.pop();
    }

    return res;
}


//Driver Code Starts
int main() {
    vector<vector<int>> points{{1, 3}, {-2, 2}, {5, 8}, {0, 1}};
    const int k = 2;

    const vector<vector<int>> res = kClosest(points, k);

    // One "x, y" line per returned point.
    for (auto it = res.begin(); it != res.end(); ++it) {
        cout << (*it)[0] << ", " << (*it)[1] << endl;
    }

    return 0;
}
//Driver Code Ends
Java
//Driver Code Starts
import java.util.ArrayList;
import java.util.PriorityQueue;

class GFG {
//Driver Code Ends


    // Function to calculate squared distance from the origin
    static int squaredDis(int[] point) {
        return point[0] * point[0] + point[1] * point[1];
    }

    // Function to find k closest points to
    // the origin
    static ArrayList<ArrayList<Integer>> kClosest(int[][] points, int k) {

        // Max heap to store points with their 
        // squared distances
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>(
            (a, b) -> b[0] - a[0]
        );

        // Iterate through each point
        for (int i = 0; i < points.length; i++) {
            int dist = squaredDis(points[i]);
            int[] entry = new int[]{dist, i}; // store index to retrieve point

            if (maxHeap.size() < k) {

                // If the heap size is less than k, 
                // insert the point
                maxHeap.add(entry);
            } else {

                // If the heap size is k, compare with
                // the top element
                if (dist < maxHeap.peek()[0]) {

                    // Replace the top element if the
                    // current point is closer
                    maxHeap.poll();
                    maxHeap.add(entry);
                }
            }
        }

        // Take the k closest points from the heap
        ArrayList<ArrayList<Integer>> res = new ArrayList<>();
        while (!maxHeap.isEmpty()) {
            int idx = maxHeap.poll()[1];
            ArrayList<Integer> temp = new ArrayList<>();
            temp.add(points[idx][0]);
            temp.add(points[idx][1]);
            res.add(temp);
        }

        return res;
    }


//Driver Code Starts
    public static void main(String[] args) {
        int[][] points = {{1, 3}, {-2, 2}, {5, 8}, {0, 1}};
        int k = 2;

        ArrayList<ArrayList<Integer>> res = kClosest(points, k);

        for (ArrayList<Integer> point : res) {
            System.out.println(point.get(0) + ", " + point.get(1));
        }
    }
}

//Driver Code Ends
Python
#Driver Code Starts
import heapq
#Driver Code Ends


def squaredDis(point):
    """Return the squared Euclidean distance of [x, y] from the origin."""
    dx, dy = point[0], point[1]
    return dx * dx + dy * dy

def kClosest(points, k):
    """Return the k points closest to the origin.

    Keeps a heap of at most k entries (-distance, point). heapq is a
    min-heap, so negating the distance keeps the farthest kept point at
    maxHeap[0]. Points are returned from farthest to closest. For k <= 0 an
    empty list is returned; without the guard, reading maxHeap[0] on an
    empty heap would raise IndexError.
    """
    if k <= 0:
        return []

    # Max heap (via negated distances) of the closest points seen so far
    maxHeap = []

    for point in points:
        dist = squaredDis(point)

        if len(maxHeap) < k:
            # Heap not full yet: always keep the point
            heapq.heappush(maxHeap, (-dist, point))
        elif dist < -maxHeap[0][0]:
            # Closer than the farthest kept point: pop it and push this
            # point in a single heap operation
            heapq.heapreplace(maxHeap, (-dist, point))

    # Popping yields the farthest kept point first
    res = []
    while maxHeap:
        res.append(heapq.heappop(maxHeap)[1])

    return res


#Driver Code Starts
if __name__ == "__main__":
    points = [[1, 3], [-2, 2], [5, 8], [0, 1]]
    k = 2

    res = kClosest(points, k)

    # Emit each selected point as "x, y" on its own line.
    for idx in range(len(res)):
        print(f"{res[idx][0]}, {res[idx][1]}")

#Driver Code Ends
C#
//Driver Code Starts
using System;
using System.Collections.Generic;

// Binary max-heap of (dist, point) entries ordered by dist: Peek and Pop
// return the entry with the largest dist. Entries are stored in a List in
// the usual array layout (children of slot i at 2i+1 and 2i+2). Calling
// Pop or Peek on an empty heap throws, because list index 0 is then out
// of range.
class MaxHeap
{
    private readonly List<(int dist, List<int> point)> _entries = new List<(int, List<int>)>();

    // Exchanges the entries stored at slots a and b.
    private void Swap(int a, int b)
    {
        (int dist, List<int> point) saved = _entries[a];
        _entries[a] = _entries[b];
        _entries[b] = saved;
    }

    // Moves the entry at `pos` upward while its dist exceeds its parent's.
    private void HeapifyUp(int pos)
    {
        while (pos > 0)
        {
            int parent = (pos - 1) / 2;
            if (_entries[pos].dist <= _entries[parent].dist)
                return;

            Swap(pos, parent);
            pos = parent;
        }
    }

    // Moves the entry at `pos` downward, always toward the child with the
    // larger dist, until neither child is larger.
    private void HeapifyDown(int pos)
    {
        int count = _entries.Count;
        while (true)
        {
            int largest = pos;
            int leftChild = 2 * pos + 1;
            int rightChild = leftChild + 1;

            if (leftChild < count && _entries[leftChild].dist > _entries[largest].dist)
                largest = leftChild;
            if (rightChild < count && _entries[rightChild].dist > _entries[largest].dist)
                largest = rightChild;

            if (largest == pos)
                return;

            Swap(pos, largest);
            pos = largest;
        }
    }

    // Inserts `point` with priority `dist`.
    public void Add(List<int> point, int dist)
    {
        _entries.Add((dist, point));
        HeapifyUp(_entries.Count - 1);
    }

    // Removes and returns the entry with the largest dist.
    public (int dist, List<int> point) Pop()
    {
        (int dist, List<int> point) root = _entries[0];
        int lastIndex = _entries.Count - 1;
        _entries[0] = _entries[lastIndex];
        _entries.RemoveAt(lastIndex);
        if (_entries.Count > 0)
            HeapifyDown(0);
        return root;
    }

    // Returns, without removing, the entry with the largest dist.
    public (int dist, List<int> point) Peek() => _entries[0];

    // Number of entries currently stored.
    public int Count() => _entries.Count;
}

class GFG {
//Driver Code Ends

    
    // Function to calculate squared distance from the origin
    // for row `idx` of `points`. Assumes x*x + y*y fits in an int.
    private static int squaredDis(int[,] points, int idx)
    {
        return points[idx, 0] * points[idx, 0] + points[idx, 1] * points[idx, 1];
    }

    // Function to find k closest points to the origin.
    // Keeps at most k points in a max-heap keyed by squared distance; a new
    // point replaces the heap's top (the farthest kept point) only when it is
    // strictly closer. The result is in pop order (decreasing distance).
    // Returns an empty list for k <= 0, which also avoids calling Peek on an
    // empty heap (that would throw).
    internal static List<List<int>> kClosest(int[,] points, int k)
    {
        List<List<int>> res = new List<List<int>>();
        if (k <= 0)
            return res;

        int n = points.GetLength(0);

        // Max-heap holding the k closest points seen so far
        MaxHeap closest = new MaxHeap();

        for (int i = 0; i < n; i++)
        {
            int dist = squaredDis(points, i);

            if (closest.Count() < k)
            {
                closest.Add(new List<int> { points[i, 0], points[i, 1] }, dist);
            }
            else if (dist < closest.Peek().dist)
            {
                // Closer than the farthest kept point: replace it.
                // The point's list is only allocated when it is kept.
                closest.Pop();
                closest.Add(new List<int> { points[i, 0], points[i, 1] }, dist);
            }
        }

        // Collect results
        while (closest.Count() > 0)
        {
            res.Add(closest.Pop().point);
        }

        return res;
    }


//Driver Code Starts
    static void Main()
    {
        int[,] points = new int[,]
        {
            { 1, 3 },
            { -2, 2 },
            { 5, 8 },
            { 0, 1 }
        };

        int k = 2;

        List<List<int>> res = kClosest(points, k);

        foreach (List<int> point in res)
        {
            Console.WriteLine(point[0] + ", " + point[1]);
        }
    }
}

//Driver Code Ends
JavaScript
//Driver Code Starts
// Array-backed binary max-heap. Each entry is indexable, with its
// comparison key at index 0; the entry with the largest key is on top.
class MaxHeap {
    constructor() {
        this.heap = [];
    }

    // Insert `item`, then sift it up while its key exceeds its parent's
    push(item) {
        const h = this.heap;
        h.push(item);

        let child = h.length - 1;
        while (child > 0) {
            const parent = Math.floor((child - 1) / 2);
            if (!(h[child][0] > h[parent][0])) {
                break;
            }
            [h[child], h[parent]] = [h[parent], h[child]];
            child = parent;
        }
    }

    // Remove and return the entry with the largest key (undefined if empty)
    pop() {
        const h = this.heap;
        const root = h[0];
        const tail = h.pop();

        if (h.length === 0) {
            return root;
        }

        // Move the former last entry to the root and sift it down
        h[0] = tail;
        let pos = 0;
        for (;;) {
            const l = 2 * pos + 1;
            const r = l + 1;
            let best = pos;

            if (l < h.length && h[l][0] > h[best][0]) best = l;
            if (r < h.length && h[r][0] > h[best][0]) best = r;

            if (best === pos) {
                break;
            }
            [h[pos], h[best]] = [h[best], h[pos]];
            pos = best;
        }

        return root;
    }

    // Entry with the largest key, without removing it
    top() {
        return this.heap[0];
    }

    // Number of stored entries
    size() {
        return this.heap.length;
    }

    // True when no entries are stored
    isEmpty() {
        return this.heap.length === 0;
    }
}
//Driver Code Ends


/** Returns x*x + y*y for a point [x, y]: its squared distance from the origin. */
function squaredDis(point) {
    const x = point[0], y = point[1];
    return x * x + y * y;
}

// Function to find k closest points to the origin.
// Keeps at most k [distance, point] entries in a max-heap whose top is the
// farthest point kept so far; a new point replaces it only when strictly
// closer. Points are returned in pop order (decreasing distance).
// Returns [] for k <= 0: otherwise maxHeap.top() would be undefined on the
// first comparison and reading [0] from it would throw a TypeError.
function kClosest(points, k) {
    if (k <= 0) {
        return [];
    }

    // Create a max heap
    const maxHeap = new MaxHeap();

    // Iterate through each point
    for (let i = 0; i < points.length; i++) {
        const dist = squaredDis(points[i]);

        if (maxHeap.size() < k) {

            // If heap size is less than k, insert the point
            maxHeap.push([dist, points[i]]);
        } else if (dist < maxHeap.top()[0]) {

            // Replace the top element if current point is closer
            maxHeap.pop();
            maxHeap.push([dist, points[i]]);
        }
    }

    // Extract k closest points
    const res = [];
    while (!maxHeap.isEmpty()) {
        res.push(maxHeap.pop()[1]);
    }

    return res;
}


//Driver Code Starts
// Driver code
const points = [[1, 3], [-2, 2], [5, 8], [0, 1]];
const k = 2;

let res = kClosest(points, k);

// Output one "x, y" line per selected point.
for (let idx = 0; idx < res.length; idx++) {
    const point = res[idx];
    console.log(`${point[0]}, ${point[1]}`);
}

//Driver Code Ends

Output
-2, 2
0, 1


Explore