Nodes at k distance from Root

Last Updated : 2 May, 2026

Given a binary tree and an integer k, the task is to return all the nodes that are at distance k from the root.

Examples:

Input:

Print-nodes-at-k-distance-from-root

Output: 2 9 13
Explanation: In the above tree 2, 9 & 13 are at distance 2 from root. 

Input:

Print-nodes-at-k-distance-from-root-2

Output: 5 11
Explanation: In the above tree 5 & 11 are at distance 1 from root.

Try It Yourself
redirect icon

The idea is to use recursion to find all nodes at a given distance k from the root of a binary tree. If k is 0, the current node is at the required distance, so its data is added to the result. For nodes at greater distances, the function recursively explores both the left and right children, decrementing k with each call.

Algorithm:

  • If root is NULL or k < 0, return
  • If k == 0, add current node’s value to result
  • Recur for left subtree with k - 1
  • Recur for right subtree with k - 1
C++
// C++ program to find all nodes that are at distance
// k from the root of a binary tree using recursion.

#include<bits/stdc++.h> 

using namespace std;

// A binary tree node holding an integer payload and
// pointers to its left and right children.
class Node {
public:
    int data;
    Node* left;
    Node* right;

    // Creates a leaf node storing x; both child pointers start as null.
    Node(int x) : data(x), left(nullptr), right(nullptr) {}
};

// Recursive helper: appends to `result` the data of every node that lies
// exactly k edges below `root`, in left-to-right order. Recursion stops
// as soon as k reaches 0, so nodes deeper than the target are never visited.
void KdistanceUill(Node *root, int k, vector<int> &result) {
    // Nothing to collect from an empty subtree or a negative distance.
    if (root == nullptr || k < 0) {
        return;
    }

    // Reached the target depth: record this node; its descendants are deeper.
    if (k == 0) {
        result.push_back(root->data);
        return;
    }

    // Each child is one edge closer to the target depth.
    int remaining = k - 1;
    KdistanceUill(root->left, remaining, result);
    KdistanceUill(root->right, remaining, result);
}

// Function to result all nodes at kth distance from root
vector<int> Kdistance(struct Node *root, int k) {
	vector<int> result;
  	KdistanceUill(root, k, result);
  	return result;
}

//Driver Code
int main() {

    // Constructing the tree:
    //        8
    //       / \
    //      7   10
    //     /    / \
    //    2    9  13

    Node* root = new Node(8);
    root->left = new Node(7);
    root->right = new Node(10);
    root->left->left = new Node(2);
    root->right->left = new Node(9);
    root->right->right = new Node(13);

    int k = 2;

    vector<int> res = Kdistance(root, k);

    cout << "Nodes at distance " << k << ": ";
    for (int x : res) {
        cout << x << " ";
    }
    cout << endl;

    return 0;
}
Java
import java.util.ArrayList;
import java.util.List;

// Binary tree node: an int payload plus references to the two children.
class Node {
    int data;
    Node left;
    Node right;

    // Builds a leaf node; the child references keep their default null value.
    Node(int x) {
        this.data = x;
    }
}

public class GfG {

    // Recursive helper: appends to `result` the data of every node exactly
    // k edges below `root`, visiting the left subtree before the right one.
    // Recursion stops at depth k, so deeper nodes are never visited.
    static void KdistanceUill(Node root, int k, List<Integer> result) {
        // An empty subtree or a negative distance contributes nothing.
        if (root == null || k < 0) {
            return;
        }

        if (k == 0) {
            // Target depth reached.
            result.add(root.data);
        } else {
            KdistanceUill(root.left, k - 1, result);
            KdistanceUill(root.right, k - 1, result);
        }
    }

    // Returns the data of all nodes at distance k from the root, left to right.
    static List<Integer> Kdistance(Node root, int k) {
        List<Integer> found = new ArrayList<>();
        KdistanceUill(root, k, found);
        return found;
    }

    public static void main(String[] args) {
        /*
         * Constructing the tree:
         *        8
         *       / \
         *      7   10
         *     /    / \
         *    2    9  13
         */
        Node root = new Node(8);
        root.left = new Node(7);
        root.right = new Node(10);
        root.left.left = new Node(2);
        root.right.left = new Node(9);
        root.right.right = new Node(13);

        int k = 2;
        List<Integer> res = Kdistance(root, k);

        StringBuilder line = new StringBuilder("Nodes at distance " + k + ": ");
        for (int x : res) {
            line.append(x).append(' ');
        }
        System.out.println(line);
    }
}
Python
class Node:
    """Binary tree node holding ``data`` and optional ``left``/``right`` children."""

    def __init__(self, x):
        # A freshly created node is a leaf; the caller attaches children.
        self.data, self.left, self.right = x, None, None

# Function to collect nodes at distance
# k from the root in a list
def KdistanceUill(root, k, result):
    """Append to ``result`` the data of every node exactly ``k`` edges below ``root``.

    Nodes are appended left to right. Recursion stops at depth ``k`` (or at an
    empty subtree), so deeper nodes are never visited. A negative ``k`` adds
    nothing.
    """
    if root is None or k < 0:
        return
    if k == 0:
        # Target depth reached.
        result.append(root.data)
        return
    for child in (root.left, root.right):
        # Each child is one edge closer to the target depth.
        KdistanceUill(child, k - 1, result)

# Function to return all nodes at kth distance from root
def Kdistance(root, k):
    """Return a list with the data of all nodes at distance ``k`` from ``root``, left to right."""
    collected = []
    KdistanceUill(root, k, collected)
    return collected

# Driver Code
if __name__ == '__main__':

    # Constructing the tree:
    #        8
    #       / \
    #      7   10
    #     /    / \
    #    2    9  13

    root = Node(8)
    root.left = Node(7)
    root.right = Node(10)
    root.left.left = Node(2)
    root.right.left = Node(9)
    root.right.right = Node(13)

    k = 2

    res = Kdistance(root, k)

    # Build the line explicitly: print's default separator would otherwise
    # put a space before the colon ("Nodes at distance 2 : ...").
    print('Nodes at distance ' + str(k) + ': ' + ' '.join(map(str, res)))
C#
using System;
using System.Collections.Generic;

// Binary tree node: an integer payload and links to the two children.
public class Node
{
    public int data;
    public Node left, right;

    // Creates a leaf node; both child links start out null.
    public Node(int x)
    {
        this.data = x;
        this.left = null;
        this.right = null;
    }
}

public class GfG
{
    // Recursive helper: adds to `result` the data of every node exactly
    // k edges below `root`, visiting the left subtree before the right one.
    // Recursion stops at depth k, so deeper nodes are never visited.
    static void KdistanceUill(Node root, int k, List<int> result)
    {
        // Empty subtree or negative distance: nothing to collect.
        if (root == null || k < 0)
        {
            return;
        }

        if (k == 0)
        {
            // This node is at the target distance.
            result.Add(root.data);
        }
        else
        {
            int childDistance = k - 1;
            KdistanceUill(root.left, childDistance, result);
            KdistanceUill(root.right, childDistance, result);
        }
    }

    // Returns the data of all nodes at distance k from root, left to right.
    static List<int> Kdistance(Node root, int k)
    {
        var found = new List<int>();
        KdistanceUill(root, k, found);
        return found;
    }

    static void Main(string[] args)
    {
        // Constructing the tree:
        //        8
        //       / \
        //      7   10
        //     /    / \
        //    2    9  13
        Node root = new Node(8);
        root.left = new Node(7);
        root.left.left = new Node(2);
        root.right = new Node(10);
        root.right.left = new Node(9);
        root.right.right = new Node(13);

        int k = 2;
        List<int> res = Kdistance(root, k);

        Console.Write("Nodes at distance " + k + ": ");
        res.ForEach(x => Console.Write(x + " "));
        Console.WriteLine();
    }
}
JavaScript
// Binary tree node: a payload plus left/right child links (null when absent).
class Node {
    constructor(x) {
        Object.assign(this, { data: x, left: null, right: null });
    }
}

// Recursive helper: pushes onto `result` the data of every node exactly
// k edges below `root`, left subtree before right subtree. Recursion
// stops at depth k, so deeper nodes are never visited.
function KdistanceUill(root, k, result) {
    // Empty subtree or negative distance: nothing to collect.
    if (root === null || k < 0) return;

    if (k === 0) {
        // Target depth reached.
        result.push(root.data);
        return;
    }

    for (const child of [root.left, root.right]) {
        KdistanceUill(child, k - 1, result);
    }
}

// Returns the data of all nodes at distance k from root, left to right.
function Kdistance(root, k) {
    const collected = [];
    KdistanceUill(root, k, collected);
    return collected;
}

// Driver Code
// Constructing the tree:
//        8
//       / \
//      7   10
//     /    / \
//    2    9  13
const root = new Node(8);
root.left = new Node(7);
root.left.left = new Node(2);
root.right = new Node(10);
root.right.left = new Node(9);
root.right.right = new Node(13);

const k = 2;
const res = Kdistance(root, k);

const line = 'Nodes at distance ' + k + ': ' + res.join(' ');
console.log(line);

Output
Nodes at distance 2: 2 9 13 

Time Complexity: O(n) where n is number of nodes in the given binary tree.
Space Complexity : O(h) where h is the height of binary tree.

[Better Approach] Using Queue - O(n) time and O(n) Space

The idea is based on line-by-line level-order traversal. Start with a queue containing only the root and set the level to 0. While the queue isn't empty, check whether the level equals k; if it does, the queue holds exactly the nodes at distance k, so collect their values and return them. Otherwise, dequeue every node of the current level, enqueue its children, and increment the level. If the traversal runs out of nodes before reaching level k, return an empty list.

Algorithm:

  • Push the root into the queue, set the level counter lvl = 0, and start the level-order traversal.
  • At each step, get current level size and check if lvl == k; if yes, store all node values and return.
  • Otherwise, remove each node and push its left and right children into the queue.
  • After processing the level, increment lvl and continue.
  • If the queue becomes empty before reaching k, return the empty result.
C++
// C++ code to implement the iterative 
// approach using a Queue
#include <bits/stdc++.h>
using namespace std;

// Binary tree node with an int payload and two child pointers.
class Node {
public:
    int data;
    Node *left, *right;

    // Leaf constructor: both children start as nullptr.
    Node(int x) : data(x), left(nullptr), right(nullptr) {}
};

// Level-order (BFS) search for the nodes at distance k from root.
// At the start of every round the queue holds exactly one level; when that
// level is k, its nodes are copied out left to right and returned.
// Returns an empty vector if root is null or the tree is shallower than k.
vector<int> Kdistance(Node* root, int k) {
    vector<int> result;
    if (root == nullptr) {
        return result;
    }

    queue<Node*> pending;
    pending.push(root);

    for (int lvl = 0; !pending.empty(); ++lvl) {
        // Everything currently queued belongs to level lvl.
        int width = pending.size();

        if (lvl == k) {
            while (width-- > 0) {
                result.push_back(pending.front()->data);
                pending.pop();
            }
            return result;
        }

        // Replace this level with the next one.
        while (width-- > 0) {
            Node* node = pending.front();
            pending.pop();
            if (node->left != nullptr) {
                pending.push(node->left);
            }
            if (node->right != nullptr) {
                pending.push(node->right);
            }
        }
    }
    return result;
}

//Driver Code
int main() {

    // Constructing the tree:
    //        8
    //       / \
    //      7   10
    //     /    / \
    //    2    9  13

    Node* root = new Node(8);
    root->left = new Node(7);
    root->right = new Node(10);
    root->left->left = new Node(2);
    root->right->left = new Node(9);
    root->right->right = new Node(13);

    int k = 2;

    vector<int> res = Kdistance(root, k);

    cout << "Nodes at distance " << k << ": ";
    for (int x : res) {
        cout << x << " ";
    }
    cout << endl;

    return 0;
}
Java
import java.util.LinkedList;
import java.util.Queue;
import java.util.ArrayList;

// Binary tree node with an int payload and two child links.
class Node {
    public int data;
    public Node left, right;

    // Creates a leaf node storing x.
    public Node(int x) {
        this.data = x;
        this.left = null;
        this.right = null;
    }
}

public class GfG {

    /**
     * Level-order search for the nodes at distance k from the root.
     * At the start of each round the queue holds exactly one level; when
     * that level is k, its nodes are drained into the result and returned.
     * Returns an empty list if root is null or the tree has fewer than k+1 levels.
     */
    public static ArrayList<Integer> Kdistance(Node root, int k) {
        ArrayList<Integer> result = new ArrayList<>();
        if (root == null) {
            return result;
        }

        Queue<Node> frontier = new LinkedList<>();
        frontier.add(root);

        for (int lvl = 0; !frontier.isEmpty(); lvl++) {
            int width = frontier.size();

            if (lvl == k) {
                for (int i = 0; i < width; i++) {
                    result.add(frontier.poll().data);
                }
                return result;
            }

            // Replace the current level with the next one.
            for (int i = 0; i < width; i++) {
                Node node = frontier.poll();
                if (node.left != null) {
                    frontier.add(node.left);
                }
                if (node.right != null) {
                    frontier.add(node.right);
                }
            }
        }
        return result;
    }

    //Driver Code
    public static void main(String[] args) {
        /*
         * Constructing the tree:
         *        8
         *       / \
         *      7   10
         *     /    / \
         *    2    9  13
         */
        Node root = new Node(8);
        root.left = new Node(7);
        root.right = new Node(10);
        root.left.left = new Node(2);
        root.right.left = new Node(9);
        root.right.right = new Node(13);

        int k = 2;
        ArrayList<Integer> res = Kdistance(root, k);

        System.out.print("Nodes at distance " + k + ": ");
        res.forEach(x -> System.out.print(x + " "));
        System.out.println();
    }
}
Python
from collections import deque

class Node:
    """A binary tree node: ``data`` plus ``left`` and ``right`` child links."""

    def __init__(self, x):
        self.data = x
        # Children are attached by the caller after construction.
        self.left = None
        self.right = None

# Function to find all nodes at distance k from the root
def Kdistance(root, k):
    """Return the data of all nodes at distance ``k`` from ``root``.

    Uses a level-order traversal: at the start of each round the queue holds
    exactly one level. When that level is ``k`` its nodes are returned left to
    right; if the tree runs out of levels first, an empty list is returned.
    """
    if root is None:
        return []

    frontier = deque([root])
    lvl = 0
    while frontier:
        if lvl == k:
            # The queue holds exactly level k at this point.
            return [node.data for node in frontier]
        for _ in range(len(frontier)):
            node = frontier.popleft()
            if node.left is not None:
                frontier.append(node.left)
            if node.right is not None:
                frontier.append(node.right)
        lvl += 1
    return []

#Driver Code
if __name__ == '__main__':

    # Constructing the tree:
    #        8
    #       / \
    #      7   10
    #     /    / \
    #    2    9  13

    root = Node(8)
    root.left = Node(7)
    root.right = Node(10)
    root.left.left = Node(2)
    root.right.left = Node(9)
    root.right.right = Node(13)

    k = 2

    res = Kdistance(root, k)

    # Build the line explicitly: passing "Nodes at distance ", k, ": " as
    # separate print arguments doubles the spaces around k.
    print('Nodes at distance ' + str(k) + ': ' + ' '.join(map(str, res)))
C#
using System;
using System.Collections.Generic;

// Binary tree node storing an int and links to its children.
public class Node {
    public int data;
    public Node left, right;

    // Leaf constructor: children are left null.
    public Node(int x) {
        this.data = x;
        this.left = null;
        this.right = null;
    }
}

public class GfG {
    // Level-order search for the nodes at distance k from the root.
    // At the start of each round the queue holds exactly one level; when
    // that level is k, its nodes are drained into the result and returned.
    // Returns an empty list if root is null or the tree is shallower than k.
    public static List<int> Kdistance(Node root, int k) {
        var result = new List<int>();
        if (root == null) {
            return result;
        }

        var frontier = new Queue<Node>();
        frontier.Enqueue(root);

        for (int lvl = 0; frontier.Count > 0; lvl++) {
            int width = frontier.Count;

            if (lvl == k) {
                while (width-- > 0) {
                    result.Add(frontier.Dequeue().data);
                }
                return result;
            }

            // Replace the current level with the next one.
            while (width-- > 0) {
                Node node = frontier.Dequeue();
                if (node.left != null) {
                    frontier.Enqueue(node.left);
                }
                if (node.right != null) {
                    frontier.Enqueue(node.right);
                }
            }
        }
        return result;
    }

    //Driver Code
    public static void Main() {
        // Constructing the tree:
        //        8
        //       / \
        //      7   10
        //     /    / \
        //    2    9  13
        Node root = new Node(8);
        root.left = new Node(7);
        root.right = new Node(10);
        root.left.left = new Node(2);
        root.right.left = new Node(9);
        root.right.right = new Node(13);

        int k = 2;
        List<int> res = Kdistance(root, k);

        Console.Write("Nodes at distance " + k + ": ");
        foreach (int value in res) {
            Console.Write(value + " ");
        }
        Console.WriteLine();
    }
}
JavaScript
// Binary tree node: payload plus child links; null means "no child".
class Node {
    constructor(x) {
        // Every node starts life as a leaf.
        Object.assign(this, { data: x, left: null, right: null });
    }
}

// Level-order (BFS) search for the nodes at distance k from the root.
// Returns their data left to right, or [] if root is null or the tree
// has fewer than k + 1 levels.
//
// Each level is kept in its own array and the next level is built into a
// fresh array. This avoids Array.prototype.shift(), which is O(length) per
// call and would make a shift-based queue quadratic on wide trees.
function Kdistance(root, k) {

    // An empty tree has no nodes at any distance.
    if (root === null)
        return [];

    let level = [root];
    let lvl = 0;

    while (level.length > 0) {

        // `level` holds exactly the nodes at depth lvl.
        if (lvl === k)
            return level.map(node => node.data);

        // Collect the children of this level, left to right.
        const next = [];
        for (const node of level) {
            if (node.left !== null)
                next.push(node.left);
            if (node.right !== null)
                next.push(node.right);
        }

        level = next;
        lvl += 1;
    }

    return [];
}

//Driver Code
// Constructing the tree:
//        8
//       / \
//      7   10
//     /    / \
//    2    9  13

let root = new Node(8);
root.left = new Node(7);
root.right = new Node(10);
root.left.left = new Node(2);
root.right.left = new Node(9);
root.right.right = new Node(13);

let k = 2;

let res = Kdistance(root, k);

// Print label and values on one line, matching the documented output.
// console.log works in every JavaScript runtime; process.stdout is Node-only.
console.log('Nodes at distance ' + k + ': ' + res.join(' '));

Output
Nodes at distance 2: 2 9 13 

Time Complexity: O(n) where n is number of nodes in the given binary tree.
Space Complexity: O(n)

[Expected Approach] Using Stack - O(n) time and O(n) Space

The idea is based on iterative (Stack based) Preorder traversal. We use a stack of pairs where we push level along with the node.

Algorithm:

  • Push root with level 0 into stack and start traversal.
  • Pop top element, if node is NULL, continue.
  • If its level equals k, add node value to result.
  • Otherwise, push the right child with level + 1, then the left child with level + 1.
  • Repeat until the stack becomes empty and return the result.
C++
// C++ code to implement the iterative 
// approach using a stack
#include <bits/stdc++.h>
using namespace std;

// A binary tree node: integer payload plus left/right child pointers.
class Node {
public:
    int data;
    Node* left;
    Node* right;

    // Builds a leaf node holding x.
    Node(int x) : data{x}, left{nullptr}, right{nullptr} {}
};

// Function to perform iterative DFS traversal and find all
// nodes at distance K
vector<int> Kdistance(struct Node* root, int k) {
	vector<int> result;
	stack<pair<Node*, int> > stk;
	stk.push({root, 0});

	while (!stk.empty()) {
		Node* curr = stk.top().first;
		int level = stk.top().second;
		stk.pop();

		if (curr == nullptr) {
			continue;
		}

		// If the current node is at distance K from the
		// root, add its data to the result
		if (level == k) {
			result.push_back(curr->data);
		}

		// Push the right child onto the stack with its
		// level incremented by 1
		stk.push({curr->right, level + 1});

		// Push the left child onto the stack with its level
		// incremented by 1
		stk.push({curr->left, level + 1});
	}

	return result;
}

//Driver Code
int main() {

    // Constructing the tree:
    //        8
    //       / \
    //      7   10
    //     /    / \
    //    2    9  13

    Node* root = new Node(8);
    root->left = new Node(7);
    root->right = new Node(10);
    root->left->left = new Node(2);
    root->right->left = new Node(9);
    root->right->right = new Node(13);

    int k = 2;

    vector<int> res = Kdistance(root, k);

    cout << "Nodes at distance " << k << ": ";
    for (int x : res) {
        cout << x << " ";
    }
    cout << endl;

    return 0;
}
Java
import java.util.ArrayList;
import java.util.Stack;

// Binary tree node: an int payload plus left/right child references.
class Node {
    public int data;
    Node left;
    Node right;

    // Creates a leaf node; left and right keep their default null value.
    Node(int x) {
        data = x;
    }
}

public class GfG {
    // Iterative preorder DFS that collects the nodes at distance k from the
    // root, left to right. Each stack entry pairs a node with its depth.
    //
    // Following the documented algorithm, children are pushed only while the
    // depth is below k. Nodes deeper than level k are therefore never visited,
    // and a negative k visits only the root. The right child is pushed first
    // so that the left child is popped (visited) first.
    static ArrayList<Integer> Kdistance(Node root, int k) {
        ArrayList<Integer> result = new ArrayList<>();
        Stack<Pair> stk = new Stack<>();
        stk.push(new Pair(root, 0));

        while (!stk.isEmpty()) {
            Pair p = stk.pop();
            Node curr = p.first;
            int level = p.second;

            // Missing children (and a null root) are simply skipped.
            if (curr == null) {
                continue;
            }

            if (level == k) {
                // Target depth: record it; its descendants are deeper than k.
                result.add(curr.data);
            } else if (level < k) {
                // Still above the target depth: descend one more level.
                stk.push(new Pair(curr.right, level + 1));
                stk.push(new Pair(curr.left, level + 1));
            }
        }

        return result;
    }

    public static void main(String[] args) {
        // Constructing the tree:
        //        8
        //       / \
        //      7   10
        //     /    / \
        //    2    9  13

        Node root = new Node(8);
        root.left = new Node(7);
        root.right = new Node(10);
        root.left.left = new Node(2);
        root.right.left = new Node(9);
        root.right.right = new Node(13);

        int k = 2;

        ArrayList<Integer> res = Kdistance(root, k);

        System.out.print("Nodes at distance " + k + ": ");
        for (int x : res) {
            System.out.print(x + " ");
        }
        System.out.println();
    }

    // Helper class to store a node together with its depth (level)
    static class Pair {
        Node first;
        int second;
        Pair(Node first, int second) {
            this.first = first;
            this.second = second;
        }
    }
}
Python
class Node:
    """Binary tree node: a ``data`` payload and ``left``/``right`` children (``None`` if absent)."""

    def __init__(self, x):
        self.data, self.left, self.right = x, None, None

# Function to perform iterative DFS traversal and find all
# nodes at distance K
def Kdistance(root, k):
    """Return the data of all nodes at distance ``k`` from ``root``, left to right.

    Iterative preorder DFS over ``(node, depth)`` pairs. Following the
    documented algorithm, children are pushed only while the depth is below
    ``k``. Nodes deeper than level ``k`` are therefore never visited, and a
    negative ``k`` visits only the root. Returns an empty list when no node
    lies at that depth.
    """
    result = []
    stk = [(root, 0)]

    while stk:
        curr, level = stk.pop()

        # Missing children (and a None root) are simply skipped.
        if curr is None:
            continue

        if level == k:
            # Target depth: record it; its descendants are deeper than k.
            result.append(curr.data)
        elif level < k:
            # Push right first so the left child is popped (visited) first.
            stk.append((curr.right, level + 1))
            stk.append((curr.left, level + 1))

    return result

#Driver Code

# Constructing the tree:
#        8
#       / \
#      7   10
#     /    / \
#    2    9  13

root = Node(8)
root.left = Node(7)
root.right = Node(10)
root.left.left = Node(2)
root.right.left = Node(9)
root.right.right = Node(13)

k = 2

res = Kdistance(root, k)

# Build the line explicitly: passing k and ':' as separate print arguments
# would put a space before the colon ("Nodes at distance 2 : ...").
print('Nodes at distance ' + str(k) + ': ' + ' '.join(map(str, res)))
C#
using System;
using System.Collections.Generic;

// Binary tree node: integer payload plus left/right child links.
public class Node {
    public int data;
    public Node left, right;

    // Creates a leaf node holding x.
    public Node(int x) {
        this.data = x;
        this.left = null;
        this.right = null;
    }
}

public class GfG {
    // Iterative preorder DFS that collects the nodes at distance k from the
    // root, left to right. Each stack entry pairs a node with its depth.
    //
    // Following the documented algorithm, children are pushed only while the
    // depth is below k. Nodes deeper than level k are therefore never visited,
    // and a negative k visits only the root. The right child is pushed first
    // so that the left child is popped (visited) first.
    public static List<int> Kdistance(Node root, int k) {
        List<int> result = new List<int>();
        Stack<(Node, int)> stk = new Stack<(Node, int)>();
        stk.Push((root, 0));

        while (stk.Count > 0) {
            var (curr, level) = stk.Pop();

            // Missing children (and a null root) are simply skipped.
            if (curr == null) {
                continue;
            }

            if (level == k) {
                // Target depth: record it; its descendants are deeper than k.
                result.Add(curr.data);
            } else if (level < k) {
                // Still above the target depth: descend one more level.
                stk.Push((curr.right, level + 1));
                stk.Push((curr.left, level + 1));
            }
        }

        return result;
    }

    public static void Main() {
        // Constructing the tree:
        //        8
        //       / \
        //      7   10
        //     /    / \
        //    2    9  13

        Node root = new Node(8);
        root.left = new Node(7);
        root.right = new Node(10);
        root.left.left = new Node(2);
        root.right.left = new Node(9);
        root.right.right = new Node(13);

        int k = 2;

        List<int> res = Kdistance(root, k);

        Console.Write("Nodes at distance " + k + ": ");
        foreach (int x in res) {
            Console.Write(x + " ");
        }
        Console.WriteLine();
    }
}
JavaScript
// Binary tree node: payload plus left/right links (null when a child is absent).
class Node {
    constructor(x) {
        Object.assign(this, { data: x, left: null, right: null });
    }
}

// Iterative preorder DFS that collects the nodes at distance k from the
// root, left to right. Each stack entry is a [node, depth] pair.
//
// Following the documented algorithm, children are pushed only while the
// depth is below k. Nodes deeper than level k are therefore never visited,
// and a negative k visits only the root. The right child is pushed first so
// that the left child is popped (visited) first.
function Kdistance(root, k) {
    let result = [];
    let stk = [[root, 0]];

    while (stk.length > 0) {
        let [curr, level] = stk.pop();

        // Missing children (and a null root) are simply skipped.
        if (curr === null) {
            continue;
        }

        if (level === k) {
            // Target depth: record it; its descendants are deeper than k.
            result.push(curr.data);
        } else if (level < k) {
            // Still above the target depth: descend one more level.
            stk.push([curr.right, level + 1]);
            stk.push([curr.left, level + 1]);
        }
    }

    return result;
}

//Driver Code

// Constructing the tree:
//        8
//       / \
//      7   10
//     /    / \
//    2    9  13

let root = new Node(8);
root.left = new Node(7);
root.right = new Node(10);
root.left.left = new Node(2);
root.right.left = new Node(9);
root.right.right = new Node(13);

let k = 2;

let res = Kdistance(root, k);

// A space follows the colon, matching the documented output
// "Nodes at distance 2: 2 9 13".
console.log('Nodes at distance ' + k + ': ' + res.join(' '));

Output
Nodes at distance 2: 2 9 13 

Time Complexity: O(n) where n is the number of nodes in tree.
Auxiliary Space: O(h), where h is the height of the tree; this is O(n) in the worst case (a skewed tree).

Comment