Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add root
  • Loading branch information
forfudan committed Apr 4, 2025
commit 89d6fbbdd6101b07049cf5036f5de205efbd6aa6
16 changes: 16 additions & 0 deletions src/decimojo/bigdecimal/bigdecimal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,11 @@ struct BigDecimal:
"""Returns the minimum of two BigDecimal numbers."""
return decimojo.bigdecimal.comparison.min(self, other)

@always_inline
fn root(self, root: Self, precision: Int = 28) raises -> Self:
"""Returns the root of the BigDecimal number."""
return decimojo.bigdecimal.exponential.root(self, root, precision)

@always_inline
fn sqrt(self, precision: Int = 28) raises -> Self:
"""Returns the square root of the BigDecimal number."""
Expand Down Expand Up @@ -766,6 +771,17 @@ struct BigDecimal:
else:
return True

@always_inline
fn is_one(self) raises -> Bool:
"""Returns True if this number represents one."""
if self.sign:
return False
if self.coefficient.number_of_digits() - self.scale != 1:
return False
if self.coefficient.ith_digit(-self.scale) != 1:
return False
return True

@always_inline
fn is_zero(self) -> Bool:
"""Returns True if this number represents zero."""
Expand Down
193 changes: 193 additions & 0 deletions src/decimojo/bigdecimal/exponential.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ import decimojo.utility

# ===----------------------------------------------------------------------=== #
# Power and root functions
# power(base, exponent, precision)
# integer_power(base, exponent, precision)
# sqrt(x, precision)
# ===----------------------------------------------------------------------=== #


Expand Down Expand Up @@ -169,6 +172,196 @@ fn integer_power(
return result^


fn root(x: BigDecimal, n: BigDecimal, precision: Int) raises -> BigDecimal:
"""Calculate the nth root of a BigDecimal number.

Args:
x: The number to calculate the root of.
n: The root value.
precision: The precision (number of significant digits) of the result.

Returns:
The nth root of x with the specified precision.

Raises:
Error: If x is negative and n is not an odd integer.
Error: If n is zero.

Notes:
Uses the identity x^(1/n) = exp(ln(|x|)/n) for calculation.
For integer roots, calls the specialized integer_root function.
"""
alias BUFFER_DIGITS = 9
var working_precision = precision + BUFFER_DIGITS

# Check for n = 0
if n.coefficient.is_zero():
raise Error("Error in `root`: Cannot compute zeroth root")

# Special case for integer roots - use more efficient implementation
if not n.sign and n.is_integer():
return integer_root(x, n, precision)

# Handle negative n as 1/(x^(1/|n|))
if n.sign:
var positive_root = root(x, -n, working_precision)
var result = BigDecimal(BigUInt.ONE, 0, False).true_divide(
positive_root, precision
)
return result^

# Handle special cases for x
if x.coefficient.is_zero():
return BigDecimal(BigUInt.ZERO, 0, False)

if x.is_one():
return BigDecimal(BigUInt.ONE, 0, False)

# Check if x is negative - only odd integer roots of negative numbers are defined
if x.sign:
if not n.is_integer() or not is_odd_reciprocal(n):
raise Error(
"Error in `root`: Cannot compute non-odd-integer root of a"
" negative number"
)

# Compute root using the identity: x^(1/n) = exp(ln(|x|)/n)
var abs_x = abs(x)
var ln_x = ln(abs_x, working_precision)
var ln_divided = ln_x.true_divide(n, working_precision)
var result = exp(ln_divided, working_precision)

# Handle sign for negative inputs (only possible with odd integer roots)
if x.sign:
result.sign = True

result.round_to_precision(
precision=precision,
rounding_mode=RoundingMode.ROUND_HALF_EVEN,
remove_extra_digit_due_to_rounding=True,
)

return result^


fn integer_root(
x: BigDecimal, n: BigDecimal, precision: Int
) raises -> BigDecimal:
"""Calculate the nth integer root of a BigDecimal number.

Args:
x: The number to calculate the root of.
n: The root value (must be a positive integer).
precision: The precision (number of significant digits) of the result.

Returns:
The nth root of x with the specified precision.

Raises:
Error: If x is negative and n is even.
Error: If n is not a positive integer.
Error: If n is zero.

Notes:
Uses the identity x^(1/n) = exp(ln(|x|)/n) for calculation.
Optimizes for special cases including n=1 and n=2.
"""
alias BUFFER_DIGITS = 9
var working_precision = precision + BUFFER_DIGITS

# Handle special case: n must be a positive integer
if n.sign:
raise Error("Error in `root`: Root value must be positive")

if not n.is_integer():
raise Error("Error in `root`: Root value must be an integer")

if n.coefficient.is_zero():
raise Error("Error in `root`: Cannot compute zeroth root")

# Special case: n = 1 (1st root is just the number itself)
if n.is_one():
var result = x
result.round_to_precision(
precision,
rounding_mode=RoundingMode.ROUND_HALF_EVEN,
remove_extra_digit_due_to_rounding=True,
)
return result^

# Special case: n = 2 (use dedicated sqrt function for better performance)
if n == BigDecimal(BigUInt(2), 0, False):
return sqrt(x, precision)

# Handle special cases for x
if x.coefficient.is_zero():
return BigDecimal(BigUInt.ZERO, 0, False)

# For x = 1, the result is always 1
if x.is_one():
return BigDecimal(BigUInt.ONE, 0, False)

var result_sign = False
# Check if x is negative
if x.sign:
# Convert n to integer to check odd/even
var n_uint: BigUInt
if n.scale > 0:
n_uint = n.coefficient.scale_down_by_power_of_10(n.scale)
else: # n.scale <= 0
n_uint = n.coefficient

if n_uint.words[0] % 2 == 1: # Odd root
result_sign = True
else: # n_uint.words[0] % 2 == 0: # Even root
raise Error(
"Error in `root`: Cannot compute even root of a negative number"
)

# Compute root using the identity: x^(1/n) = exp(ln(|x|)/n)
var abs_x = abs(x)
var ln_x = ln(abs_x, working_precision)
var ln_divided = ln_x.true_divide(n, working_precision)
var result = exp(ln_divided, working_precision)
result.sign = result_sign

result.round_to_precision(
precision=precision,
rounding_mode=RoundingMode.ROUND_HALF_EVEN,
remove_extra_digit_due_to_rounding=True,
)

return result^


fn is_odd_reciprocal(n: BigDecimal) raises -> Bool:
"""Check if 1/n represents an odd integer.

Args:
n: The value to check.

Returns:
True if 1/n is an odd integer, False otherwise.
"""
# If n is of form 1/m where m is an odd integer, then 1/n = m is odd
# This is true when n = 1/m for odd integer m

# n must have only one significant digit that equals 1
if n.coefficient.number_of_digits() != 1 or n.coefficient.words[0] != 1:
return False

# Check if n = 1/(2k+1) for some integer k
# This means n = 1/m where m is odd
# For this to be true, n's scale must be positive and 10^scale must be odd
if n.scale <= 0:
return False

# For the scale to represent a valid odd denominator:
# n = 1/10^scale and 10^scale = 2^scale * 5^scale
# This is odd only when scale = 0, which means n = 1
return n.is_one()


fn sqrt(x: BigDecimal, precision: Int) raises -> BigDecimal:
"""Calculate the square root of a BigDecimal number.

Expand Down