Open In App

Add all greater values to every node in a given BST

Last Updated : 17 Oct, 2024
Comments
Improve
Suggest changes
Like Article
Like
Report

Given a Binary Search Tree (BST), the task is to modify it so that every node's value is replaced by the sum of all values in the BST that are greater than or equal to it (the node's own value included).

Examples:

Input:

          50
        /    \
      30      70
     /  \    /  \
   20    40 60   80

Output:

          260
        /     \
     330       150
     /  \     /   \
   350  300  210   80


Explanation: The above tree represents the greater sum tree where each node contains the sum of all nodes greater than or equal to that node in original tree.

  • The root node 50 becomes 260 (sum of 50 + 60 + 70 + 80).
  • The left child of 50 becomes 330 (sum of 30 + 40 + 50 + 60 + 70 + 80).
  • The right child of 50 becomes 150 (70 + 80) and so on.

[Naive Approach] By Calculating Sum for Each Node – O(n^2) Time and O(n) Space

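The idea is to traverse the binary tree and, for each node, find the sum of all nodes with values greater than or equal to it. These sums are first recorded (for example, in a map keyed by node). They are written back into the nodes only after every sum has been computed, because overwriting a node's value earlier would corrupt the comparisons made for nodes processed after it.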

This method doesn’t require the tree to be a BST. Following are the steps:

  • Traverse node by node (in-order, pre-order, etc.).
  • For each node, find all the nodes whose values are greater than or equal to the current node's value and sum those values. Store each of these sums.
  • Replace each node’s value with its corresponding sum by traversing in the same order as in Step 1.

Below is the implementation of the above approach:

C++
// C++ program to transform a BST to
// sum tree

#include <bits/stdc++.h>
using namespace std;

// A binary tree node: an int key plus raw pointers to the left and
// right children, both null until assigned.
class Node {
public:
    int data;
    Node* left;
    Node* right;

    Node(int value) : data(value), left(nullptr), right(nullptr) {}
};

// Function to find nodes having greater value than
// current node.
void findGreaterNodes(Node* root, Node* curr, 
                      unordered_map<Node*,int> &mp) {
    if (root == nullptr) return;
    
    // if value is greater than equal to node,  
    // then increment it in the map
    if (root->data >= curr->data) 
        mp[curr] += root->data;
        
    findGreaterNodes(root->left, curr, mp);
    findGreaterNodes(root->right, curr, mp);
}

// Pre-order walk over every node `curr`. For each one, the whole tree
// (from `root`) is scanned to record its greater-or-equal sum in mp.
// Node values are not modified here.
void transformToGreaterSumTree(Node* curr, Node* root, 
                               unordered_map<Node*,int>&mp) {
    if (curr != nullptr) {
        findGreaterNodes(root, curr, mp);
        transformToGreaterSumTree(curr->left, root, mp);
        transformToGreaterSumTree(curr->right, root, mp);
    }
}

// Overwrites every node's data with the sum recorded for it in mp,
// visiting nodes in pre-order. operator[] yields 0 for a node that
// has no entry.
void preOrderTrav(Node* root, unordered_map<Node*, int> &mp) {
    if (root != nullptr) {
        root->data = mp[root];
        preOrderTrav(root->left, mp);
        preOrderTrav(root->right, mp);
    }
}

// Replaces each node's value with the sum of all values >= it.
// Every sum is collected first and only then written back, so that all
// comparisons are made against the original values.
void transformTree(Node* root) {
    unordered_map<Node*, int> greaterSum;
    transformToGreaterSumTree(root, root, greaterSum);
    preOrderTrav(root, greaterSum);
}

// Prints the tree's values in in-order (left, node, right), each
// followed by a single space.
void inorder(Node* root) {
    if (root != nullptr) {
        inorder(root->left);
        cout << root->data << ' ';
        inorder(root->right);
    }
}

int main() {
    /*
     * Representation of input binary tree:
     *           50
     *         /    \
     *        30    70
     *      /  \    / \
     *     20  40  60  80
     */
    Node* root = new Node(50);

    root->left = new Node(30);
    root->left->left = new Node(20);
    root->left->right = new Node(40);

    root->right = new Node(70);
    root->right->left = new Node(60);
    root->right->right = new Node(80);

    transformTree(root);
    inorder(root);

    return 0;
}
Java
// Java program to transform a BST to
// sum tree
import java.util.HashMap;

// Binary tree node: an int key plus references to the two children.
// The child references are left at their default value, null.
class Node {
    int data;
    Node left;
    Node right;

    Node(int value) {
        this.data = value;
    }
}

class GfG {

    // Adds to mp[curr] every value in the subtree rooted at `root` that is
    // greater than or equal to curr.data. The whole subtree is scanned, so
    // the BST ordering is not relied upon.
    static void findGreaterNodes(Node root, Node curr, 
                                 HashMap<Node, Integer> mp) {
        if (root == null) {
            return;
        }
        if (root.data >= curr.data) {
            // merge() inserts the value when absent, otherwise adds to it
            mp.merge(curr, root.data, Integer::sum);
        }
        findGreaterNodes(root.left, curr, mp);
        findGreaterNodes(root.right, curr, mp);
    }

    // Pre-order walk over every node `curr`. For each one, the whole tree
    // (from `root`) is scanned to record its greater-or-equal sum in mp.
    // Node values are left untouched here.
    static void transformToGreaterSumTree(Node curr, Node root,
                                          HashMap<Node, Integer> mp) {
        if (curr != null) {
            findGreaterNodes(root, curr, mp);
            transformToGreaterSumTree(curr.left, root, mp);
            transformToGreaterSumTree(curr.right, root, mp);
        }
    }

    // Writes each node's recorded sum (0 if none was recorded) into its
    // data field, visiting nodes in pre-order.
    static void preOrderTrav(Node root, HashMap<Node, Integer> mp) {
        if (root == null) {
            return;
        }
        Integer recorded = mp.get(root);
        root.data = (recorded == null) ? 0 : recorded;
        preOrderTrav(root.left, mp);
        preOrderTrav(root.right, mp);
    }

    // Replaces each value with the sum of all values >= it. All sums are
    // computed first and written back afterwards, so every comparison is
    // made against the original values.
    static void transformTree(Node root) {
        HashMap<Node, Integer> greaterSums = new HashMap<>();
        transformToGreaterSumTree(root, root, greaterSums);
        preOrderTrav(root, greaterSums);
    }

    // Prints the values in in-order, each followed by a space.
    static void inorder(Node root) {
        if (root != null) {
            inorder(root.left);
            System.out.print(root.data + " ");
            inorder(root.right);
        }
    }

    public static void main(String[] args) {

        // Representation of input binary tree:
        //           50
        //         /    \
        //        30    70
        //      /  \    / \
        //     20  40  60  80
        Node root = new Node(50);

        root.left = new Node(30);
        root.left.left = new Node(20);
        root.left.right = new Node(40);

        root.right = new Node(70);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        transformTree(root);
        inorder(root);
    }
}
Python
# Python program to transform a BST
# to sum tree

class Node:
    """Binary tree node: value in ``data``; ``left``/``right`` children start as None."""

    def __init__(self, value):
        self.data = value
        self.left = self.right = None

# Function to find nodes having greater 
# value than current node.
def findGreaterNodes(root, curr, map):
    """Add to ``map[curr]`` every value in the subtree under ``root`` that is >= ``curr.data``.

    ``map[curr]`` must already exist; the caller initialises it to 0. The
    whole subtree is scanned, so the BST ordering is not relied upon.
    """
    if root is None:
        return
    value = root.data
    if value >= curr.data:
        map[curr] += value
    for child in (root.left, root.right):
        findGreaterNodes(child, curr, map)

def transformToGreaterSumTree(curr, root, map):
    """Record, for every node ``curr`` (visited pre-order), its greater-or-equal sum.

    ``map[curr]`` is reset to 0 and then filled by scanning the whole tree
    rooted at ``root``. Node values are not changed here.
    """
    if curr is not None:
        map[curr] = 0
        findGreaterNodes(root, curr, map)
        transformToGreaterSumTree(curr.left, root, map)
        transformToGreaterSumTree(curr.right, root, map)

# Function to update value of each node.
def preOrderTrav(root, map):
    """Overwrite each node's ``data`` with ``map[node]``, visiting nodes in pre-order.

    Nodes that have no entry in ``map`` keep their current value.
    """
    if root is None:
        return
    if root in map:
        root.data = map[root]
    preOrderTrav(root.left, map)
    preOrderTrav(root.right, map)

def transformTree(root):
    """Replace every node's value with the sum of all values >= it.

    All sums are collected first and only then written back, so every
    comparison is made against the original values.
    """
    sums = {}
    transformToGreaterSumTree(root, root, sums)
    preOrderTrav(root, sums)

def inorder(root):
    """Print the values in in-order, each followed by a space (no trailing newline)."""
    if root is not None:
        inorder(root.left)
        print(root.data, end=" ")
        inorder(root.right)

if __name__ == "__main__":

    # Representation of input binary tree:
    #           50
    #         /    \
    #        30    70
    #      /  \    / \
    #     20  40  60  80
    root = Node(50)
    root.left, root.right = Node(30), Node(70)
    root.left.left, root.left.right = Node(20), Node(40)
    root.right.left, root.right.right = Node(60), Node(80)

    transformTree(root)
    inorder(root)
C#
// C# program to transform a BST
// to sum tree
using System;
using System.Collections.Generic;

// Binary tree node: an int key; the child links keep their default
// value, null, until assigned.
class Node {
    public int data;
    public Node left, right;

    public Node(int value) {
        this.data = value;
    }
}

class GfG {

    // Adds to map[curr] every value in the subtree rooted at `root` that is
    // greater than or equal to curr.data. The caller must already have
    // created the map[curr] entry.
    static void FindGreaterNodes(Node root, Node curr, 
                                 Dictionary<Node, int> map) {
        if (root == null) {
            return;
        }
        if (root.data >= curr.data) {
            map[curr] = map[curr] + root.data;
        }
        FindGreaterNodes(root.left, curr, map);
        FindGreaterNodes(root.right, curr, map);
    }

    // Pre-order walk over every node `curr`. Its entry is reset to 0, then
    // the whole tree (from `root`) is scanned to accumulate its
    // greater-or-equal sum. Node values are not changed here.
    static void TransformToGreaterSumTree(Node curr, Node root, 
                                          Dictionary<Node, int> map) {
        if (curr != null) {
            map[curr] = 0;
            FindGreaterNodes(root, curr, map);
            TransformToGreaterSumTree(curr.left, root, map);
            TransformToGreaterSumTree(curr.right, root, map);
        }
    }

    // Writes each node's recorded sum into its data field (pre-order);
    // a node with no entry keeps its current value.
    static void PreOrderTrav(Node root, Dictionary<Node, int> map) {
        if (root == null) {
            return;
        }
        if (map.TryGetValue(root, out int total)) {
            root.data = total;
        }
        PreOrderTrav(root.left, map);
        PreOrderTrav(root.right, map);
    }

    // Replaces each value with the sum of all values >= it. Every sum is
    // computed first and only then written back, so comparisons always use
    // the original values.
    static void TransformTree(Node root) {
        Dictionary<Node, int> greaterSums = new Dictionary<Node, int>();
        TransformToGreaterSumTree(root, root, greaterSums);
        PreOrderTrav(root, greaterSums);
    }

    // Prints the values in in-order, each followed by a space.
    static void Inorder(Node root) {
        if (root != null) {
            Inorder(root.left);
            Console.Write(root.data + " ");
            Inorder(root.right);
        }
    }

    static void Main(string[] args) {

        // Representation of input binary tree:
        //           50
        //         /    \
        //        30    70
        //      /  \    / \
        //     20  40  60  80
        Node root = new Node(50);

        root.left = new Node(30);
        root.left.left = new Node(20);
        root.left.right = new Node(40);

        root.right = new Node(70);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        TransformTree(root);
        Inorder(root);
    }
}
JavaScript
// JavaScript program to transform 
// a BST to sum tree
class Node {
    constructor(value) {
        this.data = value;
        this.left = null;
        this.right = null;
    }
}

// Function to find nodes having greater value 
// than current node.
function findGreaterNodes(root, curr, map) {
    if (root === null) return;

    // if value is greater than equal to node,  
    // then increment it in the map
    if (root.data >= curr.data) {
        map.set(curr, (map.get(curr) || 0) + root.data);
    }

    findGreaterNodes(root.left, curr, map);
    findGreaterNodes(root.right, curr, map);
}

// Pre-order walk over every node `curr`. For each one, the whole tree
// (from `root`) is scanned to record its greater-or-equal sum in map.
// Node values are not modified here.
function transformToGreaterSumTree(curr, root, map) {
    if (curr !== null) {
        findGreaterNodes(root, curr, map);
        transformToGreaterSumTree(curr.left, root, map);
        transformToGreaterSumTree(curr.right, root, map);
    }
}

// Function to update value of each node.
function preOrderTrav(root, map) {
    if (root === null) return;

    root.data = map.has(root) ? map.get(root) : 0;

    preOrderTrav(root.left, map);
    preOrderTrav(root.right, map);
}

// Replaces each node's value with the sum of all values >= it. Every sum
// is collected first and only then written back, so all comparisons use
// the original values.
function transformTree(root) {
    const sums = new Map();
    transformToGreaterSumTree(root, root, sums);
    preOrderTrav(root, sums);
}

function inorder(root) {
    if (root === null) {
        return;
    }
    inorder(root.left);
    console.log(root.data + " ");
    inorder(root.right);
}

// Representation of input binary tree:
//           50
//         /    \
//        30    70
//      /  \    / \
//     20  40  60  80
let root = new Node(50);

root.left = new Node(30);
root.left.left = new Node(20);
root.left.right = new Node(40);

root.right = new Node(70);
root.right.left = new Node(60);
root.right.right = new Node(80);

transformTree(root);
inorder(root);

Output
350 330 300 260 210 150 80 

Note: Since this approach runs in O(n^2) time, it is too slow for large trees (it will exceed typical time limits), so we need a more efficient approach.

[Expected Approach] Using Single Traversal – O(n) Time and O(h) Space

The idea is to traverse the tree in reverse in-order (right -> root -> left), which visits the values of a BST in decreasing order, while keeping a running sum of all nodes visited so far. When a node is visited, its value is added to the running sum and the node's value is replaced by that sum. This ensures that each node contains the sum of all nodes greater than or equal to it.

Below is the implementation of the above approach:

C++
// C++ program to transform a BST to sum tree
#include <bits/stdc++.h>
using namespace std;

// Binary tree node: an int key plus pointers to the left and right
// children, both null until assigned.
class Node {
public:
    int data;
    Node* left = nullptr;
    Node* right = nullptr;

    Node(int value) : data(value) {}
};

// Reverse in-order walk (right, node, left), which visits a BST's values
// from largest to smallest. Each value is added to the running total
// `sum`, which is shared across calls by reference, and the updated
// total is stored back into the node.
void transformToGreaterSumTree(Node* root, int& sum) {
    if (!root) {
        return;
    }
    transformToGreaterSumTree(root->right, sum);
    root->data = (sum += root->data);
    transformToGreaterSumTree(root->left, sum);
}

// Replaces each node's value with the sum of all values >= it, starting
// the running total at 0.
void transformTree(Node* root) {
    int runningTotal = 0;
    transformToGreaterSumTree(root, runningTotal);
}

// Prints the values in in-order (left, node, right), each followed by
// a single space.
void inorder(Node* root) {
    if (root != nullptr) {
        inorder(root->left);
        cout << root->data << ' ';
        inorder(root->right);
    }
}

int main() {
    /*
     * Representation of input binary tree:
     *           50
     *         /    \
     *        30    70
     *      /  \    / \
     *     20  40  60  80
     */
    Node* root = new Node(50);

    root->left = new Node(30);
    root->left->left = new Node(20);
    root->left->right = new Node(40);

    root->right = new Node(70);
    root->right->left = new Node(60);
    root->right->right = new Node(80);

    transformTree(root);
    inorder(root);

    return 0;
}
C
// C program to transform a BST 
// to sum tree
#include <stdio.h>
#include <stdlib.h>

/* Binary tree node: an int key plus links to the left/right children. */
struct Node {
    int data;
    struct Node *left, *right;
};

/*
 * Reverse in-order walk (right, node, left), which visits a BST's values
 * from largest to smallest. Each value is added to the running total
 * *sum, which is shared across calls, and the updated total is stored
 * back into the node.
 */
void transformToGreaterSumTree(struct Node* root, int* sum) {
    if (root == NULL) {
        return;
    }
    transformToGreaterSumTree(root->right, sum);
    root->data = (*sum += root->data);
    transformToGreaterSumTree(root->left, sum);
}

/* Replaces every node's value with the sum of all values >= it,
 * starting the running total at 0. */
void transformTree(struct Node* root) {
    int runningTotal = 0;
    transformToGreaterSumTree(root, &runningTotal);
}

/* Prints the values in in-order (left, node, right), each followed by a space. */
void inorder(struct Node* root) {
    if (root != NULL) {
        inorder(root->left);
        printf("%d ", root->data);
        inorder(root->right);
    }
}

/*
 * Allocates a node holding `data` with no children.
 * malloc returns NULL when the allocation fails, and dereferencing that
 * pointer is undefined behavior, so the failure is reported on stderr
 * and the program terminates. The returned node is heap-allocated; the
 * caller is responsible for free()ing it.
 */
struct Node* createNode(int data) {
    struct Node* node = malloc(sizeof *node);
    if (node == NULL) {
        fprintf(stderr, "createNode: out of memory\n");
        exit(EXIT_FAILURE);
    }
    node->data = data;
    node->left = NULL;
    node->right = NULL;
    return node;
}

int main() {
    /*
     * Representation of input binary tree:
     *           50
     *         /    \
     *        30    70
     *      /  \    / \
     *     20  40  60  80
     */
    struct Node* root = createNode(50);

    root->left = createNode(30);
    root->left->left = createNode(20);
    root->left->right = createNode(40);

    root->right = createNode(70);
    root->right->left = createNode(60);
    root->right->right = createNode(80);

    transformTree(root);
    inorder(root);
    return 0;
}
Java
// Java program to transform a 
// BST to sum tree
// Binary tree node holding an int key. Both child references keep their
// default value, null, until assigned.
class Node {
    int data;
    Node left;
    Node right;

    Node(int value) {
        this.data = value;
    }
}

class GfG {

    // Reverse in-order walk (right subtree, node, left subtree), which visits
    // the BST's values from largest to smallest. sum[0] is a running total
    // shared across the recursive calls: each node's value is added to it,
    // and the node then stores the updated total.
    static void transformToGreaterSumTree(Node root, int[] sum) {
        if (root != null) {
            transformToGreaterSumTree(root.right, sum);
            sum[0] += root.data;
            root.data = sum[0];
            transformToGreaterSumTree(root.left, sum);
        }
    }

    // Replaces every value with the sum of all values >= it.
    static void transformTree(Node root) {
        // One-element array (zero-initialised) so the callee can update
        // the total in place.
        int[] runningTotal = new int[1];
        transformToGreaterSumTree(root, runningTotal);
    }

    // Prints the values in in-order, each followed by a space.
    static void inorder(Node root) {
        if (root != null) {
            inorder(root.left);
            System.out.print(root.data + " ");
            inorder(root.right);
        }
    }

    public static void main(String[] args) {

        // Representation of input binary tree:
        //           50
        //         /    \
        //        30    70
        //      /  \    / \
        //     20  40  60  80
        Node root = new Node(50);

        root.left = new Node(30);
        root.left.left = new Node(20);
        root.left.right = new Node(40);

        root.right = new Node(70);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        transformTree(root);
        inorder(root);
    }
}
Python
# Python program to transform a 
# BST to sum tree
class Node:
    """A binary tree node.

    Attributes are ``data`` (the stored value), plus ``left`` and ``right``
    (child nodes, initially None).
    """

    def __init__(self, value):
        self.data = value
        self.left = None
        self.right = None

def transformToGreaterSumTree(root, sum):
    """Replace every node's value with the sum of all values >= it.

    Walks the BST in reverse in-order (right, node, left), i.e. from the
    largest value to the smallest. Each value is added to the running total
    ``sum[0]``, and that total is stored back into the node.

    Uses an explicit stack instead of recursion. A recursive walk needs one
    Python frame per tree level, so it raises RecursionError on a skewed
    tree deeper than the interpreter's recursion limit. Extra space is still
    O(h).

    Args:
        root: root of the (sub)tree to transform; may be None.
        sum: one-element list holding the running total. It is read and
            updated in place, so the total carries across calls.
    """
    stack = []
    node = root
    while stack or node is not None:
        # Descend to the largest value not yet visited.
        while node is not None:
            stack.append(node)
            node = node.right
        node = stack.pop()
        sum[0] += node.data
        node.data = sum[0]
        # Everything in the left subtree is smaller; visit it next.
        node = node.left

def transformTree(root):
    """Replace every node's value with the sum of all values >= it."""
    # One-element list so the callee can update the total in place.
    running_total = [0]
    transformToGreaterSumTree(root, running_total)

def inorder(root):
    """Print the tree's values in in-order (left, node, right).

    Each value is followed by a single space, and no newline is printed.
    An explicit stack replaces recursion so that very deep (skewed) trees
    do not hit Python's recursion limit.
    """
    stack = []
    node = root
    while stack or node is not None:
        # Descend to the smallest value not yet printed.
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        print(node.data, end=" ")
        node = node.right

if __name__ == "__main__":

    # Representation of input binary tree:
    #           50
    #         /    \
    #        30    70
    #      /  \    / \
    #     20  40  60  80
    root = Node(50)
    root.left, root.right = Node(30), Node(70)
    root.left.left, root.left.right = Node(20), Node(40)
    root.right.left, root.right.right = Node(60), Node(80)

    transformTree(root)
    inorder(root)
C#
// C# program to transform a BST to
// sum tree
using System;

// Binary tree node with an int key; both child links start out null.
class Node {
    public int data;
    public Node left;
    public Node right;

    public Node(int value) {
        this.data = value;
        this.left = null;
        this.right = null;
    }
}

class GfG {

    // Reverse in-order walk (right, node, left), which visits a BST's values
    // from largest to smallest. Each value is added to the running total
    // `sum` (passed by ref, so it is shared across calls), and the updated
    // total is stored back into the node.
    static void TransformToGreaterSumTree(Node root, ref int sum) {
        if (root == null) {
            return;
        }
        TransformToGreaterSumTree(root.right, ref sum);
        sum += root.data;
        root.data = sum;
        TransformToGreaterSumTree(root.left, ref sum);
    }

    // Replaces every value with the sum of all values >= it, starting the
    // running total at 0.
    static void TransformTree(Node root) {
        int runningTotal = 0;
        TransformToGreaterSumTree(root, ref runningTotal);
    }

    // Prints the values in in-order, each followed by a space.
    static void Inorder(Node root) {
        if (root != null) {
            Inorder(root.left);
            Console.Write(root.data + " ");
            Inorder(root.right);
        }
    }

    static void Main() {

        // Representation of input binary tree:
        //           50
        //         /    \
        //        30    70
        //      /  \    / \
        //     20  40  60  80
        Node root = new Node(50);

        root.left = new Node(30);
        root.left.left = new Node(20);
        root.left.right = new Node(40);

        root.right = new Node(70);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        TransformTree(root);
        Inorder(root);
    }
}
JavaScript
// JavaScript program to transform a
// BST to sum tree

class Node {
    constructor(value) {
        this.data = value;
        this.left = null;
        this.right = null;
    }
}

function transformToGreaterSumTree(root, sum) {
    if (root === null) {
        return;
    }

    // Traverse the right subtree first (larger values)
    transformToGreaterSumTree(root.right, sum);

    // Update the sum and the current node's value
    sum[0] += root.data;
    root.data = sum[0];

    // Traverse the left subtree (smaller values)
    transformToGreaterSumTree(root.left, sum);
}

// Replaces every node's value with the sum of all values >= it.
function transformTree(root) {
    // Single-element array so the recursive calls share one running total.
    const runningTotal = [0];
    transformToGreaterSumTree(root, runningTotal);
}

// Function to perform in-order traversal
function inorder(root) {
    if (root === null) {
        return;
    }
    inorder(root.left);
    console.log(root.data + " ");
    inorder(root.right);
}

// Representation of input binary tree:
//           50
//         /    \
//        30    70
//      /  \    / \
//     20  40  60  80
const root = new Node(50);

root.left = new Node(30);
root.left.left = new Node(20);
root.left.right = new Node(40);

root.right = new Node(70);
root.right.left = new Node(60);
root.right.right = new Node(80);

transformTree(root);
inorder(root);

Output
350 330 300 260 210 150 80 


Next Article

Similar Reads