Persistent Segment Tree in Python
Persistent data structures let us keep and access multiple versions of a data structure over time. One such structure is the Persistent Segment Tree. A segment tree answers range queries (such as the sum of a subarray) and applies point updates in O(log n) time. Making it persistent means that every update produces a new version while all earlier versions remain queryable, which is particularly useful in competitive programming and in applications that need rollback or point-in-time queries. This article explains the concept and walks through an implementation in Python.

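What is a Persistent Segment Tree?
Persistence in data structures refers to the ability to access previous versions of a data structure even after it has been modified. A common technique for achieving it is path copying: instead of modifying nodes in place, an update copies only the nodes that change and lets the new version share every unchanged node with the old one, which saves both space and time. A persistent segment tree applies path copying to a segment tree. A point update only touches the nodes on one root-to-leaf path, so each new version costs O(log n) new nodes and gets its own root.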
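To make path copying concrete before applying it to a tree, here is a minimal sketch on an immutable singly linked list. The names `ListNode`, `from_list`, `set_at`, and `to_list` are hypothetical helpers for illustration only. Setting position i copies the nodes on the path from the head to position i and shares the rest of the list with the previous version.

```python
class ListNode:
    """Immutable singly linked list node (hypothetical helper for illustration)."""
    __slots__ = ("value", "rest")

    def __init__(self, value, rest=None):
        self.value = value
        self.rest = rest


def from_list(values):
    """Build a linked list from a Python list."""
    head = None
    for v in reversed(values):
        head = ListNode(v, head)
    return head


def set_at(head, i, value):
    """Return a new version with position i set to value.

    Only the nodes from the head up to position i are copied (path copying);
    everything after position i is shared with the old version.
    """
    if i == 0:
        return ListNode(value, head.rest)
    return ListNode(head.value, set_at(head.rest, i - 1, value))


def to_list(head):
    """Convert a linked list back to a Python list."""
    out = []
    while head is not None:
        out.append(head.value)
        head = head.rest
    return out


v0 = from_list([1, 2, 3, 4, 5])
v1 = set_at(v0, 1, 20)
print(to_list(v0))                   # [1, 2, 3, 4, 5] (old version untouched)
print(to_list(v1))                   # [1, 20, 3, 4, 5]
print(v0.rest.rest is v1.rest.rest)  # True: the suffix [3, 4, 5] is shared
```

A persistent segment tree uses the same idea, except that the copied "path" runs from the root down to a single leaf.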
Representation of a Persistent Segment Tree:
Each node in a Persistent Segment Tree contains:
- Value: The value stored at this node, which typically represents an aggregate (like sum, min, or max) over a segment of the array.
- Left: A reference to the left child node.
- Right: A reference to the right child node.
```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        # Initialize a new node
        self.value = value  # The value stored in this node
        self.left = left    # Reference to the left child node
        self.right = right  # Reference to the right child node
```
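Persistence only works if nodes are never modified after they are created: a single in-place write to a shared node would silently change every version that shares it. The `Node` class above relies on convention for this. As an optional alternative (a design choice, not something the original code requires), a frozen dataclass makes accidental writes raise an error. The `build`, `update`, and `query` functions in this article only construct nodes and read their fields, so they should work unchanged with this class.

```python
from dataclasses import dataclass, FrozenInstanceError
from typing import Optional


@dataclass(frozen=True)
class Node:
    """Immutable segment-tree node: fields cannot be reassigned after construction."""
    value: int = 0
    left: Optional["Node"] = None
    right: Optional["Node"] = None


leaf_a = Node(1)
leaf_b = Node(2)
parent = Node(value=leaf_a.value + leaf_b.value, left=leaf_a, right=leaf_b)
print(parent.value)  # 3

try:
    parent.value = 99  # an in-place write like this would corrupt shared versions
except FrozenInstanceError:
    print("nodes are immutable")
```

The trade-off is a small construction overhead, since a frozen dataclass sets its fields through `object.__setattr__` in the generated `__init__`.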
Persistent Segment Tree Operations
Persistent Segment Trees allow you to perform updates and queries efficiently while preserving the history of changes. Here are the primary operations you can perform on a Persistent Segment Tree:
- Build: Construct the initial segment tree from an array.
- Update: Create a new version of the tree with an updated value at a specific index.
- Query: Perform range queries on any version of the tree.
1. Building the Initial Segment Tree
The `build` function constructs the segment tree from an array. It is identical to building a regular segment tree; the resulting root becomes version 0, the base version that later updates share nodes with.
Algorithm:
- If the current segment is a single element (i.e., `left == right`), create a leaf node with the value of that element.
- Otherwise, split the segment into two halves and recursively build the left and right subtrees.
- Create a new node whose value is the sum (or other aggregate function) of the values of the left and right children.
Implementation:
```python
def build(arr, left, right):
    # Function to build the initial segment tree
    if left == right:
        # If the current segment is a single element, create a leaf node
        return Node(value=arr[left])
    # Calculate the mid-point of the current segment
    mid = (left + right) // 2
    # Recursively build the left and right subtrees
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)
    # Create a new node with the sum of values of left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)
```
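Each call to `build` creates exactly one node, and each of the n leaves corresponds to one array element, so the tree has n leaves and n - 1 internal nodes: 2n - 1 nodes in total, built in O(n) time. The sketch below checks this by counting reachable nodes; `count_nodes` is a hypothetical helper added for illustration, and `Node` and `build` are repeated so the snippet runs on its own.

```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right):
    # Same recursive build as above: exactly one new node per call
    if left == right:
        return Node(value=arr[left])
    mid = (left + right) // 2
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def count_nodes(node):
    """Count the nodes reachable from node (hypothetical helper)."""
    if node is None:
        return 0
    return 1 + count_nodes(node.left) + count_nodes(node.right)


for n in (1, 5, 8, 1000):
    root = build(list(range(n)), 0, n - 1)
    print(n, count_nodes(root), 2 * n - 1)  # the last two columns match
```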
2. Updating the Segment Tree Persistently
The `update` function creates a new version of the segment tree by copying only the nodes on the path from the root to the updated leaf. It returns the root of the new version and leaves all previous versions intact.
Algorithm:
- If the current segment is a single element (i.e., `left == right`), create a new leaf node with the updated value.
- Otherwise, determine whether the index to be updated lies in the left or right subtree.
- Recursively update that subtree and reuse the other subtree as-is; it is shared with the previous version, not copied.
- Create a new node whose value is the sum (or other aggregate function) of the values of the new left and right children.
Implementation:
```python
def update(prev_node, left, right, idx, new_value):
    # Function to perform an update operation on the segment tree persistently
    if left == right:
        # If the current segment is a single element, create a new node with the updated value
        return Node(value=new_value)
    # Calculate the mid-point of the current segment
    mid = (left + right) // 2
    # Determine whether the index to be updated lies in the left or right subtree
    if idx <= mid:
        # Recursively update the left subtree, keep the right subtree unchanged
        left_child = update(prev_node.left, left, mid, idx, new_value)
        right_child = prev_node.right
    else:
        # Recursively update the right subtree, keep the left subtree unchanged
        left_child = prev_node.left
        right_child = update(prev_node.right, mid + 1, right, idx, new_value)
    # Create a new node with the sum of values of the updated left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)
```
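To see path copying at work, the sketch below repeats `Node`, `build`, and `update` from this article, updates index 2 of `[1, 2, 3, 4, 5]`, and checks which nodes the two versions share. Only the three nodes on the path to leaf [2, 2] are new; the subtrees for [0, 1] and [3, 4] are the very same objects in both versions.

```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right):
    if left == right:
        return Node(value=arr[left])
    mid = (left + right) // 2
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def update(prev_node, left, right, idx, new_value):
    if left == right:
        return Node(value=new_value)
    mid = (left + right) // 2
    if idx <= mid:
        left_child = update(prev_node.left, left, mid, idx, new_value)
        right_child = prev_node.right  # shared with the previous version, not copied
    else:
        left_child = prev_node.left  # shared with the previous version, not copied
        right_child = update(prev_node.right, mid + 1, right, idx, new_value)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


arr = [1, 2, 3, 4, 5]
v0 = build(arr, 0, len(arr) - 1)
v1 = update(v0, 0, len(arr) - 1, 2, 10)

print(v0.value, v1.value)            # 15 22
print(v1.left.left is v0.left.left)  # True: subtree [0, 1] is shared
print(v1.right is v0.right)          # True: subtree [3, 4] is shared
print(v1.left is v0.left)            # False: node [0, 2] was copied
print(v0.left.right.value, v1.left.right.value)  # 3 10: each version has its own leaf [2, 2]
```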
3. Querying the Segment Tree
The `query` function computes the aggregate over a range of indices. Because every version has its own root, the same function works on any version: pass in the root of the version you want to query.
Algorithm:
- If the query range does not overlap with the current segment, return 0 (or the identity value for the aggregate function).
- If the current segment is completely within the query range, return the value of the current node.
- Otherwise, recursively query the left and right subtrees with the same query range.
- Return the sum (or other aggregate function) of the results of the left and right subtree queries.
Implementation:
```python
def query(node, left, right, query_left, query_right):
    # Function to perform a range query on the segment tree
    if query_left > right or query_right < left:
        # If the query range does not overlap with the current segment, return 0
        return 0
    if query_left <= left and right <= query_right:
        # If the current segment is completely within the query range, return the value of the current node
        return node.value
    # Calculate the mid-point of the current segment
    mid = (left + right) // 2
    # Recursively query the left and right subtrees and return the sum of results
    return (query(node.left, left, mid, query_left, query_right)
            + query(node.right, mid + 1, right, query_left, query_right))
```
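The algorithm notes that the sum can be replaced by another aggregate, provided the "no overlap" case returns that aggregate's identity value. The sketch below is one hypothetical way to generalize the three functions by passing a `combine` function and its `identity` (for example `min` with `math.inf`); these extra parameters are not part of the article's code. It builds a persistent range-minimum tree, queries two versions, and then shows the original sum tree as a special case.

```python
import math
import operator


class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right, combine):
    if left == right:
        return Node(arr[left])
    mid = (left + right) // 2
    left_child = build(arr, left, mid, combine)
    right_child = build(arr, mid + 1, right, combine)
    return Node(combine(left_child.value, right_child.value), left_child, right_child)


def update(prev_node, left, right, idx, new_value, combine):
    if left == right:
        return Node(new_value)
    mid = (left + right) // 2
    if idx <= mid:
        left_child = update(prev_node.left, left, mid, idx, new_value, combine)
        right_child = prev_node.right
    else:
        left_child = prev_node.left
        right_child = update(prev_node.right, mid + 1, right, idx, new_value, combine)
    return Node(combine(left_child.value, right_child.value), left_child, right_child)


def query(node, left, right, query_left, query_right, combine, identity):
    if query_left > right or query_right < left:
        return identity  # no overlap: contribute the identity element of combine
    if query_left <= left and right <= query_right:
        return node.value
    mid = (left + right) // 2
    return combine(
        query(node.left, left, mid, query_left, query_right, combine, identity),
        query(node.right, mid + 1, right, query_left, query_right, combine, identity),
    )


arr = [5, 3, 8, 6, 1]
n = len(arr)

# Persistent range-minimum tree: combine=min, identity=math.inf
m0 = build(arr, 0, n - 1, min)
m1 = update(m0, 0, n - 1, 1, 9, min)  # version 1: arr[1] becomes 9
print(query(m0, 0, n - 1, 0, 3, min, math.inf))  # 3
print(query(m1, 0, n - 1, 0, 3, min, math.inf))  # 5
print(query(m1, 0, n - 1, 0, 4, min, math.inf))  # 1

# The article's sum tree is the special case combine=operator.add, identity=0
s0 = build(arr, 0, n - 1, operator.add)
print(query(s0, 0, n - 1, 1, 3, operator.add, 0))  # 17
```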
Illustration of How a Persistent Segment Tree Works:
Let's walk through a detailed example step by step to illustrate how the persistent segment tree works:
Initial Array
- arr = [1,2,3,4,5]
Building the Initial Segment Tree
- Build the leaf nodes:
  - Leaf node for arr[0] (range [0, 0]): value = 1
  - Leaf node for arr[1] (range [1, 1]): value = 2
  - Leaf node for arr[2] (range [2, 2]): value = 3
  - Leaf node for arr[3] (range [3, 3]): value = 4
  - Leaf node for arr[4] (range [4, 4]): value = 5
- Build the internal nodes. The root [0, 4] splits at mid = 2 into [0, 2] and [3, 4], and [0, 2] splits at mid = 1 into [0, 1] and the leaf [2, 2]:
  - Node for range [0, 1]: value = 1 + 2 = 3
  - Node for range [0, 2]: value = 3 + 3 = 6 (children: [0, 1] and the leaf [2, 2])
  - Node for range [3, 4]: value = 4 + 5 = 9
  - Node for range [0, 4]: value = 6 + 9 = 15
Updating the Segment Tree
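- Update arr[2] from 3 to 10. Only the nodes on the path from the root to the leaf [2, 2] are copied:
  - New leaf node for range [2, 2]: value = 10
  - New node for range [0, 2]: value = 3 + 10 = 13 (its left child, the node for [0, 1], is shared with the original tree)
  - New root for range [0, 4]: value = 13 + 9 = 22 (its right child, the node for [3, 4], is shared with the original tree)
- The original root still has value 15, so the original version is unchanged.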
Querying the Segment Tree
- Query range [1, 3] in the original tree:
  - Starting at the root [0, 4], the query descends into [0, 2], [0, 1], and [3, 4]. The fully covered nodes [1, 1], [2, 2], and [3, 3] contribute their values, while [0, 0] and [4, 4] lie outside the range and contribute 0.
  - Result: sum = 2 + 3 + 4 = 9
- Query range [1, 3] in the updated tree:
  - The same ranges are visited, but starting from the new root, so the leaf for [2, 2] now holds 10.
  - Result: sum = 2 + 10 + 4 = 16
This example demonstrates how the persistent segment tree maintains different versions efficiently and supports range queries on any version.
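To see which nodes a query touches, the hypothetical `query_traced` below is the article's `query` with two extra lists: `visited` records every node the recursion enters, and `used` records the fully covered nodes whose values are added. `Node`, `build`, and `update` are repeated from this article so the snippet is self-contained.

```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right):
    if left == right:
        return Node(value=arr[left])
    mid = (left + right) // 2
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def update(prev_node, left, right, idx, new_value):
    if left == right:
        return Node(value=new_value)
    mid = (left + right) // 2
    if idx <= mid:
        left_child, right_child = update(prev_node.left, left, mid, idx, new_value), prev_node.right
    else:
        left_child, right_child = prev_node.left, update(prev_node.right, mid + 1, right, idx, new_value)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def query_traced(node, left, right, query_left, query_right, visited, used):
    """Same logic as query, but records visited and fully covered ranges (hypothetical helper)."""
    visited.append((left, right))
    if query_left > right or query_right < left:
        return 0
    if query_left <= left and right <= query_right:
        used.append((left, right))
        return node.value
    mid = (left + right) // 2
    return (query_traced(node.left, left, mid, query_left, query_right, visited, used)
            + query_traced(node.right, mid + 1, right, query_left, query_right, visited, used))


arr = [1, 2, 3, 4, 5]
n = len(arr)
v0 = build(arr, 0, n - 1)
v1 = update(v0, 0, n - 1, 2, 10)

for name, root in (("original", v0), ("updated", v1)):
    visited, used = [], []
    total = query_traced(root, 0, n - 1, 1, 3, visited, used)
    print(name, total, used)
print(visited)  # ranges entered by the last query, in visiting order
```

Both queries report `used == [(1, 1), (2, 2), (3, 3)]`, with totals 9 and 16, and both enter the same nine ranges; only the leaf reached for [2, 2] differs between versions.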
Complete Implementation of a Persistent Segment Tree in Python
Below is the complete implementation of a Persistent Segment Tree in Python:
```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        # Initialize a new node
        self.value = value  # The value stored in this node
        self.left = left    # Reference to the left child node
        self.right = right  # Reference to the right child node


def build(arr, left, right):
    # Function to build the initial segment tree
    if left == right:
        # If the current segment is a single element, create a leaf node
        return Node(value=arr[left])
    # Calculate the mid-point of the current segment
    mid = (left + right) // 2
    # Recursively build the left and right subtrees
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)
    # Create a new node with the sum of values of left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def update(prev_node, left, right, idx, new_value):
    # Function to perform an update operation on the segment tree persistently
    if left == right:
        # If the current segment is a single element, create a new node with the updated value
        return Node(value=new_value)
    # Calculate the mid-point of the current segment
    mid = (left + right) // 2
    # Determine whether the index to be updated lies in the left or right subtree
    if idx <= mid:
        # Recursively update the left subtree, keep the right subtree unchanged
        left_child = update(prev_node.left, left, mid, idx, new_value)
        right_child = prev_node.right
    else:
        # Recursively update the right subtree, keep the left subtree unchanged
        left_child = prev_node.left
        right_child = update(prev_node.right, mid + 1, right, idx, new_value)
    # Create a new node with the sum of values of the updated left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def query(node, left, right, query_left, query_right):
    # Function to perform a range query on the segment tree
    if query_left > right or query_right < left:
        # If the query range does not overlap with the current segment, return 0
        return 0
    if query_left <= left and right <= query_right:
        # If the current segment is completely within the query range, return the value of the current node
        return node.value
    # Calculate the mid-point of the current segment
    mid = (left + right) // 2
    # Recursively query the left and right subtrees and return the sum of results
    return (query(node.left, left, mid, query_left, query_right)
            + query(node.right, mid + 1, right, query_left, query_right))


# Example usage
if __name__ == "__main__":
    # Initial array
    arr = [1, 2, 3, 4, 5]
    # Build the initial segment tree
    root = build(arr, 0, len(arr) - 1)
    # Create a new version with an update (change the value at index 2 to 10)
    new_root = update(root, 0, len(arr) - 1, 2, 10)
    # Query the original and new versions
    print(query(root, 0, len(arr) - 1, 1, 3))      # Output: 9 (2+3+4)
    print(query(new_root, 0, len(arr) - 1, 1, 3))  # Output: 16 (2+10+4)
```
Output:
```
9
16
```
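The example above keeps two roots in separate variables. A minimal sketch of a common pattern for many versions (an assumption about how you might organize them, not part of the original code) stores every root in a list, so version k is `roots[k]`. Rolling back means reading an older root, and an update may even start from an old version, which creates a new branch of history without disturbing the others.

```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right):
    if left == right:
        return Node(value=arr[left])
    mid = (left + right) // 2
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def update(prev_node, left, right, idx, new_value):
    if left == right:
        return Node(value=new_value)
    mid = (left + right) // 2
    if idx <= mid:
        left_child, right_child = update(prev_node.left, left, mid, idx, new_value), prev_node.right
    else:
        left_child, right_child = prev_node.left, update(prev_node.right, mid + 1, right, idx, new_value)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def query(node, left, right, query_left, query_right):
    if query_left > right or query_right < left:
        return 0
    if query_left <= left and right <= query_right:
        return node.value
    mid = (left + right) // 2
    return (query(node.left, left, mid, query_left, query_right)
            + query(node.right, mid + 1, right, query_left, query_right))


arr = [1, 2, 3, 4, 5]
n = len(arr)
roots = [build(arr, 0, n - 1)]  # roots[0] is version 0

# Each point update produces a new version based on the latest one.
for idx, value in [(2, 10), (0, 7), (4, 0)]:
    roots.append(update(roots[-1], 0, n - 1, idx, value))

for version, root in enumerate(roots):
    print(version, query(root, 0, n - 1, 0, n - 1))  # totals: 15, 22, 28, 23

# Point-in-time query: version 1 predates the change to index 0.
print(query(roots[1], 0, n - 1, 0, 1))  # 3

# Branching: update an old version; roots[1:] are unaffected.
branch = update(roots[0], 0, n - 1, 4, 0)
print(query(branch, 0, n - 1, 0, n - 1))  # 10
```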
Complexity Analysis of a Persistent Segment Tree:
Operation | Time Complexity | Space Complexity | Description |
---|---|---|---|
Build | O(n) | O(n) (2n - 1 nodes) | Construct the initial segment tree from an array of size n |
Update | O(log n) | O(log n) new nodes per version | Create a new version of the tree with an updated value at a specific index |
Query | O(log n) | O(log n) recursion stack, no new nodes | Perform a range query on any version of the tree |
After k updates, all versions together use O(n + k log n) nodes.
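The space figures can be checked by counting allocations. In this sketch, `Node.created` is a hypothetical instrumentation counter incremented in the constructor; `build` and `update` are repeated from this article. For n = 8 every leaf sits at depth 3, so each update copies exactly 4 nodes (log2 8 + 1), while the initial build allocates 2n - 1 = 15. For sizes that are not powers of two, leaf depths vary, so an update copies at most ceil(log2 n) + 1 nodes.

```python
class Node:
    created = 0  # class-level counter (instrumentation for this demo only)

    def __init__(self, value=0, left=None, right=None):
        Node.created += 1
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right):
    if left == right:
        return Node(value=arr[left])
    mid = (left + right) // 2
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def update(prev_node, left, right, idx, new_value):
    if left == right:
        return Node(value=new_value)
    mid = (left + right) // 2
    if idx <= mid:
        left_child, right_child = update(prev_node.left, left, mid, idx, new_value), prev_node.right
    else:
        left_child, right_child = prev_node.left, update(prev_node.right, mid + 1, right, idx, new_value)
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


n = 8
arr = list(range(n))
roots = [build(arr, 0, n - 1)]
print(Node.created)  # 15 == 2n - 1

before = Node.created
for i in range(n):
    roots.append(update(roots[-1], 0, n - 1, i, 100 + i))
print(Node.created - before)  # 32 == 8 updates * 4 nodes

print(roots[0].value, roots[-1].value)  # 28 828: version 0 is untouched
```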