Skip to content

Commit 2685f20

Browse files
committed
Improve exp
1 parent eb8bb84 commit 2685f20

File tree

2 files changed

+51
-28
lines changed

2 files changed

+51
-28
lines changed

benches/bigdecimal/bench_bigdecimal_exp.mojo

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -557,10 +557,10 @@ fn main() raises:
557557
speedup_factors,
558558
)
559559

560-
# Case 43: exp(5000000)
560+
# Case 43: exp(5000.1234567890)
561561
run_benchmark_exp(
562-
"exp(5000000)",
563-
"5000000",
562+
"exp(5000.1234567890)",
563+
"5000.1234567890",
564564
iterations,
565565
log_file,
566566
speedup_factors,
@@ -611,10 +611,10 @@ fn main() raises:
611611
speedup_factors,
612612
)
613613

614-
# Case 49: exp(12345678)
614+
# Case 49: exp(1234.5678901234567890)
615615
run_benchmark_exp(
616-
"exp(12345678)",
617-
"12345678",
616+
"exp(1234.5678901234567890)",
617+
"1234.5678901234567890",
618618
iterations,
619619
log_file,
620620
speedup_factors,

src/decimojo/bigdecimal/exponential.mojo

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
"""Implements exponential functions for the BigDecimal type."""
2222

23+
import time
24+
2325
from decimojo.bigdecimal.bigdecimal import BigDecimal
2426
from decimojo.rounding_mode import RoundingMode
2527
import decimojo.utility
@@ -200,14 +202,13 @@ fn exp(x: BigDecimal, precision: Int = 28) raises -> BigDecimal:
200202

201203
# For very large positive values, result will overflow BigDecimal capacity
202204
# Calculate rough estimate to detect overflow early
203-
if not x.sign and x.scale <= -40: # x > 10^40
205+
# TODO: Use BigInt as scale can avoid overflow in this case
206+
if not x.sign and x.exponent() >= 20: # x > 10^20
204207
raise Error("Error in `exp`: Result too large to represent")
205208

206209
# For very large negative values, result will be effectively zero
207-
if x.sign and x.scale <= -40: # x < -10^40
208-
return BigDecimal(
209-
BigUInt.ONE, precision, False
210-
) # Return very small number
210+
if x.sign and x.exponent() >= 20: # x < -10^20
211+
return BigDecimal(BigUInt.ZERO, precision, False)
211212

212213
# Handle negative x using identity: exp(-x) = 1/exp(x)
213214
if x.sign:
@@ -217,37 +218,59 @@ fn exp(x: BigDecimal, precision: Int = 28) raises -> BigDecimal:
217218
)
218219

219220
# Range reduction for faster convergence
220-
# If x > 1, use exp(x) = exp(x/2)²
221-
if x > BigDecimal(BigUInt.ONE, 0, False):
222-
# Find k where 2^k > x
221+
# If x >= 0.1, use exp(x) = exp(x/2)²
222+
if x >= BigDecimal(BigUInt.ONE, 1, False):
223+
# var t_before_range_reduction = time.perf_counter_ns()
223224
var k = 0
224225
var threshold = BigDecimal(BigUInt.ONE, 0, False)
225-
while threshold <= x:
226-
threshold = threshold + threshold # Multiply by 2
226+
while threshold.exponent() <= x.exponent() + 1:
227+
threshold.coefficient = (
228+
threshold.coefficient + threshold.coefficient
229+
) # Multiply by 2
227230
k += 1
228231

229232
# Calculate exp(x/2^k)
230233
var reduced_x = x.true_divide_fast(threshold, working_precision)
231-
var reduced_exp = exp_taylor_series(reduced_x, working_precision)
234+
235+
# var t_after_range_reduction = time.perf_counter_ns()
236+
237+
var result = exp_taylor_series(reduced_x, working_precision)
238+
239+
# var t_after_taylor_series = time.perf_counter_ns()
232240

233241
# Square result k times: exp(x) = exp(x/2^k)^(2^k)
234-
var result = reduced_exp
235-
for i in range(k):
242+
for _ in range(k):
236243
result = result * result
237-
# Round intermediates to working precision to avoid explosion
238-
if i % 2 == 1: # Every few iterations
239-
result.round_to_precision(
240-
precision=working_precision,
241-
rounding_mode=RoundingMode.ROUND_HALF_UP,
242-
remove_extra_digit_due_to_rounding=False,
243-
)
244+
result.round_to_precision(
245+
precision=working_precision,
246+
rounding_mode=RoundingMode.ROUND_HALF_UP,
247+
remove_extra_digit_due_to_rounding=False,
248+
)
244249

245250
result.round_to_precision(
246251
precision=precision,
247252
rounding_mode=RoundingMode.ROUND_HALF_EVEN,
248253
remove_extra_digit_due_to_rounding=False,
249254
)
250255

256+
# var t_after_scale_up = time.perf_counter_ns()
257+
258+
# print(
259+
# "TIME: range reduction: {}ns".format(
260+
# t_after_range_reduction - t_before_range_reduction
261+
# )
262+
# )
263+
# print(
264+
# "TIME: taylor series: {}ns".format(
265+
# t_after_taylor_series - t_after_range_reduction
266+
# )
267+
# )
268+
# print(
269+
# "TIME: scale up: {}ns".format(
270+
# t_after_scale_up - t_after_taylor_series
271+
# )
272+
# )
273+
251274
return result^
252275

253276
# For small values, use Taylor series directly
@@ -284,7 +307,7 @@ fn exp_taylor_series(
284307
# There are intotal 2.3 * precision iterations
285308

286309
# print("DEBUG: exp_taylor_series")
287-
# print("DEBUG: x", x)
310+
# print("DEBUG: x =", x)
288311

289312
var max_number_of_terms = Int(minimum_precision * 2.5) + 1
290313
var result = BigDecimal(BigUInt.ONE, 0, False)
@@ -309,7 +332,7 @@ fn exp_taylor_series(
309332
# Add term to result
310333
result += term
311334

312-
print("DEUBG: round {}, term {}, result {}".format(n, term, result))
335+
# print("DEUBG: round {}, term {}, result {}".format(n, term, result))
313336

314337
# Check if we've reached desired precision
315338
if term.exponent() < -minimum_precision:

0 commit comments

Comments
 (0)