Open In App

Sum of nodes within K distance from target

Last Updated : 11 Nov, 2023
Comments
Improve
Suggest changes
Like Article
Like
Report

Given a binary tree, the value of a target node in it, and a positive integer K, the task is to find the sum of all nodes within distance K of the target node (including the value of the target node itself).

Examples:

Input: target = 9, K = 1,  
Binary Tree =
                 1
               /   \
              2     9
             /     / \
            4     5   7
           / \       / \
          8   19    20  11
         /   /  \
        30  40   50
Output: 22
Explanation: The nodes within distance 1 of 9 are 9, 5, 7 and 1, so the sum is 9 + 5 + 7 + 1 = 22.

Input: target = 40,  K = 2,  
Binary Tree =
                 1
               /   \
              2     9
             /     / \
            4     5   7
           / \       / \
          8   19    20  11
         /   /  \
        30  40   50
Output: 113
Explanation: The nodes within distance 2 of 40 are 40, 19, 50 and 4, so the sum is 40 + 19 + 50 + 4 = 113.

 

Approach: This problem can be solved using hashing and Depth-First-Search based on the following idea:

Use a data structure to store the parent of each node. Now utilise that data structure to perform a DFS traversal from target and calculate the sum of all the nodes within K distance from that node.

Follow the steps mentioned below to implement the approach:

  • Create a hash table (say par) to store the parent of each node.
  • Perform a DFS and store the parent of each node.
  • Now find the target in the tree.
  • Create a hash table to mark the visited nodes.
  • Start a DFS from target:
    • If the current distance exceeds K, or the node is null or already visited, stop; otherwise add the node's value to the final sum and mark the node as visited.
    • Continue the DFS to all of the node's neighbours (its children and, through par, its parent) with the distance increased by 1.
    • When the recursion started at the target completes, the accumulated value is the sum of all nodes within distance K.
  • Return the sum of all the nodes within K distance from the target.

Below is the implementation of the above approach:

C++
// C++ code to implement above approach

#include <bits/stdc++.h>
using namespace std;

// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val)
    {
        this->data = val;
        this->left = nullptr;
        this->right = nullptr;
    }
};

// Function for marking the parent node of every node below `root`
// using a pre-order DFS. The root's own entry is set by the caller
// (sum_at_distK maps it to nullptr).
void dfs(Node* root,
         std::unordered_map<Node*, Node*>& par)
{
    if (root == nullptr)
        return;
    if (root->left != nullptr)
        par[root->left] = root;
    if (root->right != nullptr)
        par[root->right] = root;
    dfs(root->left, par);
    dfs(root->right, par);
}

// Adds to `sum` the value of every node within distance `k` of `root`,
// where `h` is the distance already travelled from the target. The walk
// moves to both children and to the parent (looked up in `par`); `vis`
// ensures each node is counted only once.
void dfs3(Node* root, int h, int& sum, int k,
          std::unordered_map<Node*, int>& vis,
          std::unordered_map<Node*, Node*>& par)
{
    // `h > k` rather than `h == k + 1`: for a negative k the equality
    // never holds and the walk would sum the whole tree.
    if (h > k || root == nullptr || vis[root])
        return;
    sum += root->data;
    vis[root] = 1;
    dfs3(root->left, h + 1, sum, k, vis, par);
    dfs3(root->right, h + 1, sum, k, vis, par);

    // find() instead of operator[] so a lookup never inserts an entry
    const auto up = par.find(root);
    dfs3(up == par.end() ? nullptr : up->second, h + 1, sum, k, vis, par);
}

// Returns the first node (in pre-order) whose data equals `target`,
// or nullptr when the tree contains no such node.
Node* dfs2(Node* root, int target)
{
    if (root == nullptr)
        return nullptr;
    if (root->data == target)
        return root;
    // Every path must return a value (flowing off the end of a non-void
    // function is undefined behaviour). The right subtree is searched
    // only when the left one has no match.
    if (Node* found = dfs2(root->left, target))
        return found;
    return dfs2(root->right, target);
}

// Function to find the sum of all nodes within distance k of the node
// whose value is `target`. Returns 0 for an empty tree, an absent
// target, or a negative k.
int sum_at_distK(Node* root, int target,
                 int k)
{
    // Hash Table to store the parent of a node; the root has none
    std::unordered_map<Node*, Node*> par;
    par[root] = nullptr;

    // Mark the parent node for all the nodes using DFS
    dfs(root, par);

    // Find the target node in the tree (nullptr if absent, in which
    // case dfs3 returns immediately and the sum stays 0)
    Node* node = dfs2(root, target);

    // Hash Table to mark the visited nodes
    std::unordered_map<Node*, int> vis;

    int sum = 0;

    // DFS call to find the sum
    dfs3(node, 0, sum, k, vis, par);
    return sum;
}

// Driver Code
int main()
{
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
    root->right->right->right
        = new Node(11);
    root->left->left->left->left
        = new Node(30);
    root->left->left->right->left
        = new Node(40);
    root->left->left->right->right
        = new Node(50);

    int target = 9, K = 1;

    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;
}
Java
// Java code to implement above approach
import java.util.*;

public class Main {
    // Structure of a tree node
    static class Node {
        int data;
        Node left;
        Node right;
        Node(int val)
        {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    // Function for marking the parent node of every node below `root`
    // in `par` using a pre-order DFS. The caller seeds the root's own
    // entry (sum_at_distK maps it to null).
    static void dfs(Node root,
            HashMap <Node, Node> par)
    {
        if (root == null)
            return;
        if (root.left != null)
            par.put(root.left, root);
        if (root.right != null)
            par.put(root.right, root);
        dfs(root.left, par);
        dfs(root.right, par);
    }

    // Running total written by dfs3 and returned by sum_at_distK
    static int sum;

    // Adds to `sum` the value of every node within distance k of `root`,
    // where h is the distance already travelled from the target. The walk
    // goes to both children and to the parent (via `par`); `vis` ensures
    // every node is counted only once.
    static void dfs3(Node root, int h, int k,
            HashMap <Node, Integer> vis,
            HashMap <Node, Node> par)
    {
        // `h > k` (rather than `h == k + 1`) also stops at once for a
        // negative k, which would otherwise sum the whole tree
        if (h > k || root == null || vis.containsKey(root))
            return;
        sum += root.data;
        vis.put(root, 1);
        dfs3(root.left, h + 1, k, vis, par);
        dfs3(root.right, h + 1, k, vis, par);
        dfs3(par.get(root), h + 1, k, vis, par);
    }

    // Returns the first node (in pre-order) whose data equals target,
    // or null when the tree has no such node
    static Node dfs2(Node root, int target)
    {
        if (root == null)
            return null;
        if (root.data == target)
            return root;
        // The right subtree is searched only if the left one has no match
        Node found = dfs2(root.left, target);
        return found != null ? found : dfs2(root.right, target);
    }

    // Returns the sum of all nodes within distance k of the node whose
    // value is target (0 when the value is absent or k is negative)
    static int sum_at_distK(Node root, int target,
                 int k)
    {
        // Hash Map to store the parent of a node; the root has none
        HashMap <Node, Node> par =  new HashMap<>();
        par.put(root, null);

        // Mark the parent node for all the nodes using DFS
        dfs(root, par);

        // Find the target node in the tree
        Node node = dfs2(root, target);

        // Hash Map to mark the visited nodes
        HashMap <Node, Integer> vis = new HashMap<>();

        sum = 0;

        // DFS call to find the sum
        dfs3(node, 0, k, vis, par);
        return sum;
    }



    public static void main(String args[]) {
        // Taking Input
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right
            = new Node(11);
        root.left.left.left.left
            = new Node(30);
        root.left.left.right.left
            = new Node(40);
        root.left.left.right.right
            = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println( sum_at_distK(root, target, K) );
        
    }
}

// This code has been contributed by Sachin Sahara (sachin801)
Python3
# python program to implement above approach
# Structure of a binary-tree node: a value plus optional left and
# right children (None when absent).
class Node:
    def __init__(self, val):
        self.data, self.left, self.right = val, None, None


# Record the parent of every node below `root` in the dict `par`
# (pre-order DFS; the root's own entry is set by the caller).
def dfs(root, par):
    if root is None:
        return
    children = (root.left, root.right)
    for child in children:
        if child is not None:
            par[child] = root
    for child in children:
        dfs(child, par)


# Running total updated by dfs3 and read by the caller after the walk.
summ = 0


# Add to the global `summ` the value of every node within distance k of
# `root`, where h is the distance already travelled from the target.
# The walk moves to both children and to the parent (via `par`); `vis`
# stops a node from being counted twice.
def dfs3(root, h, k, vis, par):
    global summ
    # `h > k` (not `h == k + 1`) also terminates for a negative k.
    # `not root` rejects both None and a falsy sentinel such as 0 that
    # may be stored as the root's parent; node objects are truthy.
    if h > k or not root or vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    # par.get: a node without a recorded parent is treated as having none
    dfs3(par.get(root), h + 1, k, vis, par)


# Return the first node (in pre-order) whose data equals `target`,
# or None when the tree has no such node.
def dfs2(root, target):
    if root is None:
        return None
    if root.data == target:
        return root
    left_hit = dfs2(root.left, target)
    right_hit = dfs2(root.right, target)
    return left_hit if left_hit is not None else right_hit
        

# Function to find the sum of all nodes within distance k of the node
# whose value is `target`. The result is stored in the global `summ`
# (read by the driver below) and is also returned.
def sum_at_distK(root, target, k):
    global summ
    # hash table mapping each node to its parent; the root has none,
    # so its entry is None (not 0, which dfs3 would try to visit)
    par = {root: None}

    # record the parent of every node using DFS
    dfs(root, par)

    # find the target node in the tree (None if absent)
    node = dfs2(root, target)

    # hash table marking the visited nodes
    vis = {}

    # reset the accumulator so that repeated calls do not add up
    summ = 0

    # dfs call to find the sum
    dfs3(node, 0, k, vis, par)
    return summ


# driver program: build the sample tree from the article, run the DFS
# and print the accumulated sum
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)
root.left.left.left, root.left.left.right = Node(8), Node(19)
root.right.right.left, root.right.right.right = Node(20), Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)

target, K = 9, 1

# function call (the result is left in the global `summ`)
sum_at_distK(root, target, K)
print(summ)

# this code is contributed by Yash Agarwal(yashagarwal2852002)
C#
// C# code to implement above approach

using System;
using System.Collections.Generic;

public class GFG {

  // Structure of a tree node
  class Node {
    public int data;
    public Node left;
    public Node right;
    public Node(int val)
    {
      this.data = val;
      this.left = null;
      this.right = null;
    }
  }

  // Function for marking the parent node of every node below `root`
  // using a pre-order DFS; the caller adds the root's own entry
  static void dfs(Node root, Dictionary<Node, Node> par)
  {
    if (root == null)
      return;
    if (root.left != null)
      par.Add(root.left, root);
    if (root.right != null)
      par.Add(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
  }

  // Running total written by dfs3 and returned by sum_at_distK
  static int sum;

  // Adds to `sum` the value of every node within distance k of `root`,
  // where h is the distance already travelled from the target. The walk
  // goes to both children and to the parent (via `par`); `vis` ensures
  // every node is counted only once.
  static void dfs3(Node root, int h, int k,
                   Dictionary<Node, int> vis,
                   Dictionary<Node, Node> par)
  {
    // `h > k` (rather than `h == k + 1`) also stops at once for a
    // negative k, which would otherwise sum the whole tree
    if (h > k || root == null || vis.ContainsKey(root))
      return;
    sum += root.data;
    vis.Add(root, 1);
    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);
    dfs3(par[root], h + 1, k, vis, par);
  }

  // Returns the first node (in pre-order) whose data equals target,
  // or null when the tree has no such node
  static Node dfs2(Node root, int target)
  {
    if (root == null)
      return null;
    if (root.data == target)
      return root;
    // the right subtree is searched only if the left one has no match
    return dfs2(root.left, target) ?? dfs2(root.right, target);
  }

  // Returns the sum of all nodes within distance k of the node whose
  // value is target (0 for an empty tree, an absent target or k < 0)
  static int sum_at_distK(Node root, int target, int k)
  {
    // Dictionary rejects a null key, so handle the empty tree before
    // seeding the parent map
    if (root == null)
      return 0;

    // Hash Map to store the parent of a node; the root has none
    Dictionary<Node, Node> par
      = new Dictionary<Node, Node>();
    par.Add(root, null);

    // Mark the parent node for all the nodes using DFS
    dfs(root, par);

    // Find the target node in the tree
    Node node = dfs2(root, target);

    // Hash Map to mark the visited nodes
    Dictionary<Node, int> vis
      = new Dictionary<Node, int>();

    sum = 0;

    // DFS call to find the sum
    dfs3(node, 0, k, vis, par);
    return sum;
  }

  static public void Main()
  {

    // Code
    Node root = new Node(1);
    root.left = new Node(2);
    root.right = new Node(9);
    root.left.left = new Node(4);
    root.right.left = new Node(5);
    root.right.right = new Node(7);
    root.left.left.left = new Node(8);
    root.left.left.right = new Node(19);
    root.right.right.left = new Node(20);
    root.right.right.right = new Node(11);
    root.left.left.left.left = new Node(30);
    root.left.left.right.left = new Node(40);
    root.left.left.right.right = new Node(50);

    int target = 9, K = 1;

    // Function call
    Console.Write(sum_at_distK(root, target, K));
  }
}

// This code is contributed by lokesh(lokeshmvs21).
JavaScript
        // JavaScript code for the above approach
        // Structure of a tree node: a value plus optional left/right
        // children (null when absent)
        class Node {
            constructor(val) {
                Object.assign(this, { data: val, left: null, right: null });
            }
        }

        // Function for marking the parent node of every node below `root`
        // in the Map `par` (pre-order DFS; the caller seeds the root entry)
        function dfs(root, par) {
            if (root === null) return;
            if (root.left !== null) par.set(root.left, root);
            if (root.right !== null) par.set(root.right, root);
            dfs(root.left, par);
            dfs(root.right, par);
        }

        // Running total filled in by dfs3 and returned by sumAtDistK
        let sum = 0;

        // Adds to `sum` the value of every node within distance k of `root`,
        // where h is the distance already travelled from the target. The walk
        // goes to both children and to the parent (looked up in `par`); `vis`
        // ensures every node is counted only once.
        function dfs3(root, h, k, vis, par) {
            // `h > k` (rather than `h === k + 1`) also stops at once for a
            // negative k
            if (h > k) return;
            // `== null` covers both null (the root's parent) and undefined
            // (a node with no entry in `par`)
            if (root == null) return;
            if (vis.has(root)) return;
            sum += root.data;
            vis.set(root, 1);
            dfs3(root.left, h + 1, k, vis, par);
            dfs3(root.right, h + 1, k, vis, par);
            // Always step to the parent: the visited check above keeps the
            // walk from going back to a node it has already counted.
            dfs3(par.get(root), h + 1, k, vis, par);
        }

        // Function for finding the first node (in pre-order) whose data
        // equals target; returns null when there is no such node
        function dfs2(root, target) {
            if (root === null) return null;
            if (root.data === target) return root;
            const found = dfs2(root.left, target);
            return found !== null ? found : dfs2(root.right, target);
        }

        // Returns the sum of all nodes within distance k of the node whose
        // value is target (0 if the value does not occur or k is negative)
        function sumAtDistK(root, target, k)
        {
            // Map to store the parent of a node; the root has none
            let par = new Map();
            par.set(root, null);

            // Mark the parent node for all the nodes using DFS
            dfs(root, par);

            // Find the target node in the tree
            let node = dfs2(root, target);

            // Map to mark the visited nodes
            let vis = new Map();

            // Start from an empty sum
            sum = 0;

            // DFS call to find the sum
            dfs3(node, 0, k, vis, par);
            return sum;
        }

        // Taking Input: the sample tree from the article, built top-down
        let root = new Node(1);
        const two = (root.left = new Node(2));
        const nine = (root.right = new Node(9));
        const four = (two.left = new Node(4));
        nine.left = new Node(5);
        const seven = (nine.right = new Node(7));
        const eight = (four.left = new Node(8));
        const nineteen = (four.right = new Node(19));
        seven.left = new Node(20);
        seven.right = new Node(11);
        eight.left = new Node(30);
        nineteen.left = new Node(40);
        nineteen.right = new Node(50);

        let target = 9;
        let K = 1;

        console.log(sumAtDistK(root, target, K));

 // This code is contributed by Potta Lokesh

Output
22

Time Complexity: O(N) where N is the number of nodes in the tree
Auxiliary Space: O(N)

Approach using BFS:-

  • We will be using level order traversal to find the sum of nodes

Implementation:-

  • First we will find the target node using level order traversal.
  • While finding the target node we will store the parent of each node so that we can move towards the parent of the node as well.
  • Finally, we perform a BFS from the target node in every direction of the tree (towards both the children and the parent) up to distance K, adding the value of every node reached to our answer.
C++
// C++ code to implement above approach

#include <bits/stdc++.h>
using namespace std;

// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val)
    {
        this->data = val;
        this->left = nullptr;
        this->right = nullptr;
    }
};

// Function to find the sum of all nodes within distance k of the node
// whose value is `target`, using two breadth-first searches:
//   1. a level-order walk of the whole tree that records each node's
//      parent and locates the target (if several nodes share the value,
//      the last one in level order is used);
//   2. a BFS outward from the target, to children and parent alike,
//      that adds up the nodes on levels 0..k.
// Returns 0 for an empty tree, an absent target, or a negative k.
int sum_at_distK(Node* root, int target,
                 int k)
{
    if (root == nullptr)
        return 0;

    // map to store parent of each node (the root has no entry)
    std::unordered_map<Node*, Node*> parent;

    // target node; stays nullptr when no node holds `target`
    Node* need = nullptr;

    // queue for bfs
    std::queue<Node*> q;
    q.push(root);
    while (!q.empty()) {
        Node* temp = q.front();
        q.pop();

        if (temp->data == target)
            need = temp;

        if (temp->left) {
            parent[temp->left] = temp;
            q.push(temp->left);
        }
        if (temp->right) {
            parent[temp->right] = temp;
            q.push(temp->right);
        }
    }

    // Without a target node there is nothing to sum
    if (need == nullptr)
        return 0;

    // Nodes already queued. Marking on enqueue (emplace().second is true
    // only for a new key) guarantees each node is summed at most once.
    std::unordered_map<Node*, bool> queued;
    auto enqueue = [&](Node* next) {
        if (next != nullptr && queued.emplace(next, true).second)
            q.push(next);
    };

    int ans = 0;
    enqueue(need);

    // Each outer iteration handles one level, i.e. one unit of distance
    for (int dist = 0; dist <= k && !q.empty(); ++dist) {
        for (auto remaining = q.size(); remaining > 0; --remaining) {
            Node* temp = q.front();
            q.pop();
            ans += temp->data;

            // moving left, right and to the parent
            enqueue(temp->left);
            enqueue(temp->right);
            const auto up = parent.find(temp);
            if (up != parent.end())
                enqueue(up->second);
        }
    }
    return ans;
}

// Driver Code
int main()
{
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
    root->right->right->right
        = new Node(11);
    root->left->left->left->left
        = new Node(30);
    root->left->left->right->left
        = new Node(40);
    root->left->left->right->right
        = new Node(50);

    int target = 9, K = 1;

    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;
}
//code contributed by shubhamrajput6156
Java
import java.util.*;

// Structure of a tree node: an int payload and two child links that
// are null when the child is absent
class Node {
    int data;
    Node left, right;

    public Node(int val) {
        data = val;
        left = null;
        right = null;
    }
}

public class Main {

    // Function to find the sum of all nodes within distance k of the node
    // whose value is `target`, using two breadth-first searches:
    //  1. a level-order walk of the whole tree that records each node's
    //     parent and locates the target (if several nodes share the
    //     value, the last one in level order is used);
    //  2. a BFS outward from the target, to children and parent alike,
    //     that adds up the nodes on levels 0..k.
    // Returns 0 for an empty tree, an absent target, or a negative k.
    public static int sumAtDistK(Node root, int target, int k) {
        if (root == null) {
            return 0;
        }

        // Map to store the parent of each node (the root has no entry)
        Map<Node, Node> parentMap = new HashMap<>();

        // To store the target node; stays null if the value is absent
        Node need = null;

        // Queue for BFS
        Queue<Node> q = new ArrayDeque<>();
        q.add(root);
        while (!q.isEmpty()) {
            Node temp = q.poll();
            if (temp.data == target) {
                need = temp;
            }
            if (temp.left != null) {
                parentMap.put(temp.left, temp);
                q.add(temp.left);
            }
            if (temp.right != null) {
                parentMap.put(temp.right, temp);
                q.add(temp.right);
            }
        }

        // Without a target node there is nothing to sum
        if (need == null) {
            return 0;
        }

        // Nodes already queued; marking on enqueue means every node is
        // added to the answer at most once
        Set<Node> visited = new HashSet<>();
        visited.add(need);
        q.add(need);
        int ans = 0;

        // Each outer iteration handles one level, i.e. one unit of distance
        for (int dist = 0; dist <= k && !q.isEmpty(); dist++) {
            for (int remaining = q.size(); remaining > 0; remaining--) {
                Node temp = q.poll();
                ans += temp.data;

                // Moving left, right and to the parent
                Node[] neighbours = { temp.left, temp.right, parentMap.get(temp) };
                for (Node next : neighbours) {
                    if (next != null && visited.add(next)) {
                        q.add(next);
                    }
                }
            }
        }

        return ans;
    }

    // Driver code
    public static void main(String[] args) {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println(sumAtDistK(root, target, K));
    }
}
Python3
from collections import deque

# Structure of a tree node: a value plus optional left/right children
# (None when absent).
class Node:
    def __init__(self, val):
        self.data, self.left, self.right = val, None, None

# Function to find the sum of all nodes within distance k of the node
# whose value is `target`, using two breadth-first searches:
#   1. a level-order walk of the whole tree that records each node's
#      parent and locates the target (if several nodes share the value,
#      the last one in level order is used);
#   2. a BFS outward from the target, to children and parent alike,
#      that adds up the nodes on levels 0..k.
# Returns 0 for an empty tree, an absent target, or a negative k.
def sum_at_distK(root, target, k):
    if root is None:
        return 0

    # Dictionary mapping each node to its parent (the root has no entry)
    parent = {}
    need = None

    # BFS traversal over the whole tree: record parents, find the target
    q = deque([root])
    while q:
        temp = q.popleft()
        if temp.data == target:
            need = temp
        for child in (temp.left, temp.right):
            if child is not None:
                parent[child] = temp
                q.append(child)

    # Without a target node there is nothing to sum
    if need is None:
        return 0

    # Nodes already queued; marking on enqueue means every node is
    # added to the answer at most once
    seen = {need}
    q.append(need)
    ans = 0
    dist = 0

    # BFS traversal within K distance, one level (= one unit of
    # distance) per outer iteration
    while q and dist <= k:
        for _ in range(len(q)):
            temp = q.popleft()
            ans += temp.data
            # moving left, right and to the parent
            for nxt in (temp.left, temp.right, parent.get(temp)):
                if nxt is not None and nxt not in seen:
                    seen.add(nxt)
                    q.append(nxt)
        dist += 1

    return ans

# Driver Code
# Build the sample tree from the article
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)
root.left.left.left, root.left.left.right = Node(8), Node(19)
root.right.right.left, root.right.right.right = Node(20), Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)

target, K = 9, 1

# Function call
print(sum_at_distK(root, target, K))
C#
using System;
using System.Collections.Generic;

// Structure of a tree node: an integer payload plus links to the
// left and right children (null when absent)
class Node
{
    public int data;
    public Node left, right;

    public Node(int val)
    {
        data = val;
        left = null;
        right = null;
    }
}

class GFG
{
    // Function to find the sum of all nodes within distance k of the node
    // whose value is `target`, using two breadth-first searches:
    //  1. a level-order walk of the whole tree that records each node's
    //     parent and locates the target (if several nodes share the
    //     value, the last one in level order is used);
    //  2. a BFS outward from the target, to children and parent alike,
    //     that adds up the nodes on levels 0..k.
    // Returns 0 for an empty tree, an absent target, or a negative k.
    public static int SumAtDistK(Node root, int target, int k)
    {
        if (root == null)
        {
            return 0;
        }

        // Dictionary to store the parent of each node (the root has none)
        Dictionary<Node, Node> parentMap = new Dictionary<Node, Node>();

        // To store the target node; stays null if the value is absent
        Node need = null;

        // Queue for BFS
        Queue<Node> q = new Queue<Node>();
        q.Enqueue(root);
        while (q.Count > 0)
        {
            Node temp = q.Dequeue();
            if (temp.data == target)
            {
                need = temp;
            }
            if (temp.left != null)
            {
                parentMap[temp.left] = temp;
                q.Enqueue(temp.left);
            }
            if (temp.right != null)
            {
                parentMap[temp.right] = temp;
                q.Enqueue(temp.right);
            }
        }

        // Without a target node there is nothing to sum
        if (need == null)
        {
            return 0;
        }

        // Nodes already queued; marking on enqueue means every node is
        // added to the answer at most once
        HashSet<Node> visited = new HashSet<Node> { need };
        q.Enqueue(need);
        int ans = 0;

        // Each outer iteration handles one level, i.e. one unit of distance
        for (int dist = 0; dist <= k && q.Count > 0; dist++)
        {
            for (int remaining = q.Count; remaining > 0; remaining--)
            {
                Node temp = q.Dequeue();
                ans += temp.data;

                // Moving left, right and to the parent (null for the root)
                parentMap.TryGetValue(temp, out Node parent);
                foreach (Node next in new[] { temp.left, temp.right, parent })
                {
                    if (next != null && visited.Add(next))
                    {
                        q.Enqueue(next);
                    }
                }
            }
        }

        return ans;
    }

    // Driver code
    public static void Main(string[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        Console.WriteLine(SumAtDistK(root, target, K));
    }
}
JavaScript
// Structure of a tree node: a value plus optional left/right children
// (null when absent)
class Node {
    constructor(val) {
        Object.assign(this, { data: val, left: null, right: null });
    }
}

// Function to find the sum of all nodes within distance k of the node
// whose value is `target`, using two breadth-first searches:
//   1. a level-order walk of the whole tree that records each node's
//      parent and locates the target (if several nodes share the value,
//      the last one in level order is used);
//   2. a BFS outward from the target, to children and parent alike,
//      that adds up the nodes on levels 0..k.
// Returns 0 for an empty tree, an absent target, or a negative k.
function sum_at_distK(root, target, k) {
    if (root === null || root === undefined) return 0;

    // Map to store parent of each node (the root has no entry)
    const parent = new Map();
    let need = null;

    // Level-order walk over the whole tree. An array with a moving head
    // index is used instead of Array.prototype.shift(), which is O(n)
    // per call and would make the traversal quadratic.
    const order = [root];
    for (let head = 0; head < order.length; head++) {
        const temp = order[head];
        if (temp.data === target) need = temp;
        for (const child of [temp.left, temp.right]) {
            if (child) {
                parent.set(child, temp);
                order.push(child);
            }
        }
    }

    // Without a target node there is nothing to sum
    if (need === null) return 0;

    // Nodes already queued; marking on enqueue means every node is
    // added to the answer at most once
    const seen = new Set([need]);
    let level = [need];
    let ans = 0;

    // Each iteration handles one level, i.e. one unit of distance
    for (let dist = 0; dist <= k && level.length > 0; dist++) {
        const nextLevel = [];
        for (const temp of level) {
            ans += temp.data;
            // Moving left, right and to the parent (undefined for the root)
            for (const nb of [temp.left, temp.right, parent.get(temp)]) {
                if (nb && !seen.has(nb)) {
                    seen.add(nb);
                    nextLevel.push(nb);
                }
            }
        }
        level = nextLevel;
    }
    return ans;
}

// Driver Code
// Taking Input: the sample tree from the article, built top-down
let root = new Node(1);
const two = (root.left = new Node(2));
const nine = (root.right = new Node(9));
const four = (two.left = new Node(4));
nine.left = new Node(5);
const seven = (nine.right = new Node(7));
const eight = (four.left = new Node(8));
const nineteen = (four.right = new Node(19));
seven.left = new Node(20);
seven.right = new Node(11);
eight.left = new Node(30);
nineteen.left = new Node(40);
nineteen.right = new Node(50);

let target = 9, K = 1;

// Function call
console.log(sum_at_distK(root, target, K));

Output
22

Time Complexity:- O(N) Where N is the number of nodes
Auxiliary Space:- O(N)


Next Article

Similar Reads