2020
2121""" Implements exponential functions for the BigDecimal type."""
2222
23+ import time
24+
2325from decimojo.bigdecimal.bigdecimal import BigDecimal
2426from decimojo.rounding_mode import RoundingMode
2527import decimojo.utility
@@ -200,14 +202,13 @@ fn exp(x: BigDecimal, precision: Int = 28) raises -> BigDecimal:
200202
201203 # For very large positive values, result will overflow BigDecimal capacity
202204 # Calculate rough estimate to detect overflow early
203- if not x.sign and x.scale <= - 40 : # x > 10^40
205+ # TODO : Use BigInt as scale can avoid overflow in this case
206+ if not x.sign and x.exponent() >= 20 : # x > 10^20
204207 raise Error(" Error in `exp`: Result too large to represent" )
205208
206209 # For very large negative values, result will be effectively zero
207- if x.sign and x.scale <= - 40 : # x < -10^40
208- return BigDecimal(
209- BigUInt.ONE , precision, False
210- ) # Return very small number
210+ if x.sign and x.exponent() >= 20 : # x < -10^20
211+ return BigDecimal(BigUInt.ZERO , precision, False )
211212
212213 # Handle negative x using identity: exp(-x) = 1/exp(x)
213214 if x.sign:
@@ -217,37 +218,59 @@ fn exp(x: BigDecimal, precision: Int = 28) raises -> BigDecimal:
217218 )
218219
219220 # Range reduction for faster convergence
220- # If x > 1, use exp(x) = exp(x/2)²
221- if x > BigDecimal(BigUInt.ONE , 0 , False ):
222- # Find k where 2^k > x
221+ # If x >= 0. 1, use exp(x) = exp(x/2)²
222+ if x >= BigDecimal(BigUInt.ONE , 1 , False ):
223+ # var t_before_range_reduction = time.perf_counter_ns()
223224 var k = 0
224225 var threshold = BigDecimal(BigUInt.ONE , 0 , False )
225- while threshold <= x:
226- threshold = threshold + threshold # Multiply by 2
226+ while threshold.exponent() <= x.exponent() + 1 :
227+ threshold.coefficient = (
228+ threshold.coefficient + threshold.coefficient
229+ ) # Multiply by 2
227230 k += 1
228231
229232 # Calculate exp(x/2^k)
230233 var reduced_x = x.true_divide_fast(threshold, working_precision)
231- var reduced_exp = exp_taylor_series(reduced_x, working_precision)
234+
235+ # var t_after_range_reduction = time.perf_counter_ns()
236+
237+ var result = exp_taylor_series(reduced_x, working_precision)
238+
239+ # var t_after_taylor_series = time.perf_counter_ns()
232240
233241 # Square result k times: exp(x) = exp(x/2^k)^(2^k)
234- var result = reduced_exp
235- for i in range (k):
242+ for _ in range (k):
236243 result = result * result
237- # Round intermediates to working precision to avoid explosion
238- if i % 2 == 1 : # Every few iterations
239- result.round_to_precision(
240- precision = working_precision,
241- rounding_mode = RoundingMode.ROUND_HALF_UP ,
242- remove_extra_digit_due_to_rounding = False ,
243- )
244+ result.round_to_precision(
245+ precision = working_precision,
246+ rounding_mode = RoundingMode.ROUND_HALF_UP ,
247+ remove_extra_digit_due_to_rounding = False ,
248+ )
244249
245250 result.round_to_precision(
246251 precision = precision,
247252 rounding_mode = RoundingMode.ROUND_HALF_EVEN ,
248253 remove_extra_digit_due_to_rounding = False ,
249254 )
250255
256+ # var t_after_scale_up = time.perf_counter_ns()
257+
258+ # print(
259+ # "TIME: range reduction: {}ns".format(
260+ # t_after_range_reduction - t_before_range_reduction
261+ # )
262+ # )
263+ # print(
264+ # "TIME: taylor series: {}ns".format(
265+ # t_after_taylor_series - t_after_range_reduction
266+ # )
267+ # )
268+ # print(
269+ # "TIME: scale up: {}ns".format(
270+ # t_after_scale_up - t_after_taylor_series
271+ # )
272+ # )
273+
251274 return result^
252275
253276 # For small values, use Taylor series directly
@@ -284,7 +307,7 @@ fn exp_taylor_series(
284307 # There are intotal 2.3 * precision iterations
285308
286309 # print("DEBUG: exp_taylor_series")
287- # print("DEBUG: x", x)
310+ # print("DEBUG: x = ", x)
288311
289312 var max_number_of_terms = Int(minimum_precision * 2.5 ) + 1
290313 var result = BigDecimal(BigUInt.ONE , 0 , False )
@@ -309,7 +332,7 @@ fn exp_taylor_series(
309332 # Add term to result
310333 result += term
311334
312- print (" DEUBG: round {} , term {} , result {} " .format(n, term, result))
335+ # print("DEUBG: round {}, term {}, result {}".format(n, term, result))
313336
314337 # Check if we've reached desired precision
315338 if term.exponent() < - minimum_precision:
0 commit comments