summary | refs | log | tree | commit | diff
path: root/prism
diff options
context:
space:
mode:
author: tompng <[email protected]> 2024-02-26 22:05:30 +0900
committer: Kevin Newton <[email protected]> 2024-03-07 18:02:33 -0500
commit: 5113d6b0591fe2c80cace33654e9088d1330277c (patch)
tree: 45d37308a233563969488276193ac7dc40aeb8b8 /prism
parent: 4186609d8871fb99eefa871c258658c86600ae3c (diff)
[ruby/prism] Faster pm_integer_parse pm_integer_string using karatsuba algorithm
https://github.com/ruby/prism/commit/ae4fb6b988
Diffstat (limited to 'prism')
-rw-r--r-- prism/util/pm_integer.c 406
1 files changed, 288 insertions, 118 deletions
diff --git a/prism/util/pm_integer.c b/prism/util/pm_integer.c
index c03b930ad3..5bcb508c1c 100644
--- a/prism/util/pm_integer.c
+++ b/prism/util/pm_integer.c
@@ -1,117 +1,139 @@
#include "prism/util/pm_integer.h"
/**
- * Create a new node for an integer in the linked list.
+ * Bigint with arbitrary base. In practice, base is 1<<32 or 10**9.
+ * When base is 10**9, it acts as bigdecimal.
*/
-static pm_integer_word_t *
-pm_integer_node_create(pm_integer_t *integer, uint32_t value) {
- integer->length++;
-
- pm_integer_word_t *node = xmalloc(sizeof(pm_integer_word_t));
- if (node == NULL) return NULL;
-
- *node = (pm_integer_word_t) { .next = NULL, .value = value };
- return node;
-}
+typedef struct {
+ size_t length;
+ uint32_t *values;
+} bigint_t;
/**
- * Copy one integer onto another.
+ * Adds two bigint_t with the given base.
*/
-static void
-pm_integer_copy(pm_integer_t *dest, const pm_integer_t *src) {
- dest->negative = src->negative;
- dest->length = 0;
-
- dest->head.value = src->head.value;
- dest->head.next = NULL;
-
- pm_integer_word_t *dest_current = &dest->head;
- const pm_integer_word_t *src_current = src->head.next;
-
- while (src_current != NULL) {
- dest_current->next = pm_integer_node_create(dest, src_current->value);
- if (dest_current->next == NULL) return;
-
- dest_current = dest_current->next;
- src_current = src_current->next;
+static bigint_t
+big_add(bigint_t left, bigint_t right, uint64_t base) {
+ size_t length = (left.length < right.length ? right.length : left.length);
+ uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * (length + 1));
+ uint64_t carry = 0;
+ for (size_t i = 0; i < length; i++) {
+ uint64_t sum = carry + (i < left.length ? left.values[i] : 0) + (i < right.length ? right.values[i] : 0);
+ values[i] = (uint32_t) (sum % base);
+ carry = sum / base;
}
-
- dest_current->next = NULL;
+ if (carry > 0) {
+ values[length] = (uint32_t) carry;
+ length++;
+ }
+ return (bigint_t) { length, values };
}
/**
- * Add a 32-bit integer to an integer.
+ * Calculates `a - b - c` with the given base.
+ * Result is assumed to be positive value. Internal use for karatsuba_multiply.
*/
-static void
-pm_integer_add(pm_integer_t *integer, uint32_t addend) {
- uint32_t carry = addend;
- pm_integer_word_t *current = &integer->head;
-
- while (carry > 0) {
- uint64_t result = (uint64_t) current->value + carry;
- carry = (uint32_t) (result >> 32);
- current->value = (uint32_t) result;
-
- if (carry > 0) {
- if (current->next == NULL) {
- current->next = pm_integer_node_create(integer, carry);
- break;
- }
-
- current = current->next;
+static bigint_t
+big_sub2(bigint_t a, bigint_t b, bigint_t c, uint64_t base) {
+ size_t length = a.length;
+ uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * length);
+ int64_t carry = 0;
+ for (size_t i = 0; i < length; i++) {
+ int64_t sub = carry + a.values[i] - (i < b.length ? b.values[i] : 0) - (i < c.length ? c.values[i] : 0);
+ if (sub >= 0) {
+ values[i] = (uint32_t) sub;
+ carry = 0;
+ } else {
+ sub += 2 * (int64_t) base;
+ values[i] = (uint32_t) ((uint64_t) sub % base);
+ carry = sub / (int64_t) base - 2;
}
}
+ while (length > 1 && values[length - 1] == 0) length--;
+ return (bigint_t) { length, values };
}
/**
- * Multiple an integer by a 32-bit integer. In practice, the multiplier is the
- * base of the integer, so this is 2, 8, 10, or 16.
+ * Multiply two bigint_t with the given base using karatsuba algorithm.
*/
-static void
-pm_integer_multiply(pm_integer_t *integer, uint32_t multiplier) {
- uint32_t carry = 0;
-
- for (pm_integer_word_t *current = &integer->head; current != NULL; current = current->next) {
- uint64_t result = (uint64_t) current->value * multiplier + carry;
- carry = (uint32_t) (result >> 32);
- current->value = (uint32_t) result;
-
- if (carry > 0 && current->next == NULL) {
- current->next = pm_integer_node_create(integer, carry);
- break;
+static bigint_t
+karatsuba_multiply(bigint_t left, bigint_t right, uint64_t base) {
+ if (left.length > right.length) {
+ bigint_t temp = left;
+ left = right;
+ right = temp;
+ }
+ if (left.length <= 10) {
+ size_t length = left.length + right.length;
+ uint32_t *values = (uint32_t*) calloc(length, sizeof(uint32_t));
+ for (size_t i = 0; i < left.length; i++) {
+ uint32_t carry = 0;
+ for (size_t j = 0; j < right.length; j++) {
+ uint64_t product = (uint64_t) left.values[i] * right.values[j] + values[i + j] + carry;
+ values[i + j] = (uint32_t) (product % base);
+ carry = (uint32_t) (product / base);
+ }
+ values[i + right.length] = carry;
}
+ while (length > 1 && values[length - 1] == 0) length--;
+ return (bigint_t) { length, values };
}
-}
-
-/**
- * Divide an individual word by a 32-bit integer. This will recursively divide
- * any subsequent nodes in the linked list.
- */
-static uint32_t
-pm_integer_divide_word(pm_integer_t *integer, pm_integer_word_t *word, uint32_t dividend) {
- uint32_t remainder = 0;
- if (word->next != NULL) {
- remainder = pm_integer_divide_word(integer, word->next, dividend);
-
- if (integer->length > 0 && word->next->value == 0) {
- xfree(word->next);
- word->next = NULL;
- integer->length--;
+ if (left.length * 2 <= right.length) {
+ uint32_t *values = (uint32_t*) calloc(left.length + right.length, sizeof(uint32_t));
+ for (size_t start_offset = 0; start_offset < right.length; start_offset += left.length) {
+ size_t end_offset = start_offset + left.length;
+ if (end_offset > right.length) end_offset = right.length;
+ bigint_t sliced_right = { end_offset - start_offset, right.values + start_offset };
+ bigint_t v = karatsuba_multiply(left, sliced_right, base);
+ uint32_t carry = 0;
+ for (size_t i = 0; i < v.length; i++) {
+ uint64_t sum = (uint64_t) values[start_offset + i] + v.values[i] + carry;
+ values[start_offset + i] = (uint32_t) (sum % base);
+ carry = (uint32_t) (sum / base);
+ }
+ free(v.values);
+ values[start_offset + v.length] += carry;
}
+ return (bigint_t) { left.length + right.length, values };
}
-
- uint64_t value = ((uint64_t) remainder << 32) | word->value;
- word->value = (uint32_t) (value / dividend);
- return (uint32_t) (value % dividend);
-}
-
-/**
- * Divide an integer by a 32-bit integer. In practice, this is only 10 so that
- * we can format it as a string. It returns the remainder of the division.
- */
-static uint32_t
-pm_integer_divide(pm_integer_t *integer, uint32_t dividend) {
- return pm_integer_divide_word(integer, &integer->head, dividend);
+ size_t half = left.length / 2;
+ bigint_t x0 = { half, left.values };
+ bigint_t x1 = { left.length - half, left.values + half };
+ bigint_t y0 = { half, right.values };
+ bigint_t y1 = { right.length - half, right.values + half };
+ bigint_t z0 = karatsuba_multiply(x0, y0, base);
+ bigint_t z2 = karatsuba_multiply(x1, y1, base);
+
+ // For simplicity to avoid considering negative values,
+ // use `z1 = (x0 + x1) * (y0 + y1) - z0 - z2` instead of original karatsuba algorithm.
+ bigint_t x01 = big_add(x0, x1, base);
+ bigint_t y01 = big_add(y0, y1, base);
+ bigint_t xy = karatsuba_multiply(x01, y01, base);
+ bigint_t z1 = big_sub2(xy, z0, z2, base);
+
+ size_t length = left.length + right.length;
+ uint32_t *values = (uint32_t*) calloc(length, sizeof(uint32_t));
+ memcpy(values, z0.values, sizeof(uint32_t) * z0.length);
+ memcpy(values + 2 * half, z2.values, sizeof(uint32_t) * z2.length);
+ uint32_t carry = 0;
+ for(size_t i = 0; i < z1.length; i++) {
+ uint64_t sum = (uint64_t) carry + values[i + half] + z1.values[i];
+ values[i + half] = (uint32_t) (sum % base);
+ carry = (uint32_t) (sum / base);
+ }
+ for(size_t i = half + z1.length; carry > 0; i++) {
+ uint64_t sum = (uint64_t) carry + values[i];
+ values[i] = (uint32_t) (sum % base);
+ carry = (uint32_t) (sum / base);
+ }
+ while (length > 1 && values[length - 1] == 0) length--;
+ free(z0.values);
+ free(z1.values);
+ free(z2.values);
+ free(x01.values);
+ free(y01.values);
+ free(xy.values);
+ return (bigint_t) { length, values };
}
/**
@@ -141,6 +163,140 @@ pm_integer_parse_digit(const uint8_t character) {
}
/**
+ * Create a bigint_t from uint64_t with the given base.
+ */
+static bigint_t
+uint64_to_bigint(uint64_t value, uint64_t base) {
+ uint64_t v = value;
+ size_t len = 0;
+ while (value > 0) { len++; value /= base; }
+ if (len == 0) len = 1;
+ uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * len);
+ for (size_t i = 0; i < len; i++) {
+ values[i] = (uint32_t) (v % base);
+ v /= base;
+ }
+ return (bigint_t) { len, values };
+}
+
+/**
+ * Convert base of bigint.
+ * In practice, it converts 10**9 to 1<<32 or 1<<32 to 10**9.
+ */
+static bigint_t
+karatsuba_convert_base(bigint_t source, uint64_t base_from, uint64_t base_to) {
+ size_t bigints_length = (source.length + 1) / 2;
+ bigint_t *bigints = (bigint_t*) malloc(sizeof(bigint_t) * bigints_length);
+ for (size_t i = 0; i < source.length; i += 2) {
+ uint64_t v = source.values[i] + base_from * (i + 1 < source.length ? source.values[i + 1] : 0);
+ bigints[i / 2] = uint64_to_bigint(v, base_to);
+ }
+ bigint_t base = uint64_to_bigint(base_from, base_to);
+ while (bigints_length > 1) {
+ size_t new_length = (bigints_length + 1) / 2;
+ bigint_t new_base = karatsuba_multiply(base, base, base_to);
+ free(base.values);
+ base = new_base;
+ bigint_t *new_bigints = (bigint_t*) malloc(sizeof(bigint_t) * new_length);
+ for (size_t i = 0; i < bigints_length; i += 2) {
+ if (i + 1 == bigints_length) {
+ new_bigints[i / 2] = bigints[i];
+ } else {
+ bigint_t multiplied = karatsuba_multiply(base, bigints[i + 1], base_to);
+ new_bigints[i / 2] = big_add(bigints[i], multiplied, base_to);
+ free(bigints[i].values);
+ free(bigints[i + 1].values);
+ free(multiplied.values);
+ }
+ }
+ free(bigints);
+ bigints = new_bigints;
+ bigints_length = new_length;
+ }
+ free(base.values);
+ bigint_t result = bigints[0];
+ free(bigints);
+ return result;
+}
+
+/**
+ * Convert digits to bigint_t with the given power-of-two base.
+ */
+static bigint_t
+big_parse_powof2(uint32_t base, const uint8_t *digits, size_t digits_length) {
+ size_t bit = 1;
+ while (base > (uint32_t) (1 << bit)) bit++;
+ size_t length = (digits_length * bit + 31) / 32;
+ uint32_t *values = (uint32_t*) calloc(length, sizeof(uint32_t));
+ for (size_t i = 0; i < digits_length; i++) {
+ size_t bit_position = bit * (digits_length - i - 1);
+ uint32_t value = digits[i];
+ size_t index = bit_position / 32;
+ size_t shift = bit_position % 32;
+ values[index] |= value << shift;
+ if (32 - shift < bit) values[index + 1] |= value >> (32 - shift);
+ }
+ while (length > 1 && values[length - 1] == 0) length--;
+ return (bigint_t) { length, values };
+}
+
+/**
+ * Convert decimal digits to bigint.
+ */
+static bigint_t
+big_parse_decimal(const uint8_t *digits, size_t digits_length) {
+ // Construct a bigdecimal from the digits.
+ const size_t batch = 9;
+ const uint64_t batch_base = 1000000000;
+ size_t values_length = (digits_length + batch - 1) / batch;
+ bigint_t bigint = { values_length, (uint32_t*) calloc(values_length, sizeof(uint32_t)) };
+ uint32_t v = 0;
+ for (size_t i = 0; i < digits_length; i++) {
+ v = v * 10 + digits[i];
+ size_t reverse_index = digits_length - i - 1;
+ if (reverse_index % batch == 0) {
+ bigint.values[reverse_index / batch] = v;
+ v = 0;
+ }
+ }
+ // Convert bigint base from 10**9 to 1<<32.
+ bigint_t converted = karatsuba_convert_base(bigint, batch_base, ((uint64_t) 1 << 32));
+ free(bigint.values);
+ return converted;
+}
+
+/**
+ * Parse a large integer from a string that does not fit into uint32_t.
+ */
+static void
+pm_integer_parse_big(pm_integer_t *integer, uint32_t multiplier, const uint8_t *start, const uint8_t *end) {
+ // Allocate an array to store digits.
+ uint8_t *digits = malloc(sizeof(uint8_t) * (size_t) (end - start));
+ size_t digits_length = 0;
+ for (; start < end; start++) {
+ if (*start == '_') continue;
+ digits[digits_length++] = (uint8_t) pm_integer_parse_digit(*start);
+ }
+ // Construct bigint_t from the digits.
+ bigint_t bigint =
+ multiplier == 10 ? big_parse_decimal(digits, digits_length) : big_parse_powof2(multiplier, digits, digits_length);
+
+ // Pack bigint_t to pm_integer_t.
+ integer->length = bigint.length - 1;
+ integer->head.value = bigint.values[0];
+ pm_integer_word_t *current = &integer->head;
+ for (size_t i = 1; i < bigint.length; i++) {
+ current->next = malloc(sizeof(pm_integer_word_t));
+ current = current->next;
+ current->value = bigint.values[i];
+ current->next = NULL;
+ }
+
+ free(bigint.values);
+ free(digits);
+}
+
+/**
* Parse an integer from a string. This assumes that the format of the integer
* has already been validated, as internal validation checks are not performed
* here.
@@ -189,15 +345,19 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s
// invalid integer. If this is the case, we'll just return 0.
if (start >= end) return;
- // Add the first digit to the integer.
- pm_integer_add(integer, pm_integer_parse_digit(*start++));
-
- // Add the subsequent digits to the integer.
- for (; start < end; start++) {
- if (*start == '_') continue;
- pm_integer_multiply(integer, multiplier);
- pm_integer_add(integer, pm_integer_parse_digit(*start));
+ const uint8_t *ptr = start;
+ uint64_t value = pm_integer_parse_digit(*ptr++);
+ for (; ptr < end; ptr++) {
+ if (*ptr == '_') continue;
+ value = value * multiplier + pm_integer_parse_digit(*ptr);
+ if (value > UINT32_MAX) {
+ // If the integer is too large to fit into a single node, then we'll
+ // parse it as a big integer.
+ pm_integer_parse_big(integer, multiplier, start, end);
+ return;
+ }
}
+ integer->head.value = (uint32_t) value;
}
/**
@@ -254,29 +414,39 @@ pm_integer_string(pm_buffer_t *buffer, const pm_integer_t *integer) {
return;
}
default: {
- // First, allocate a buffer that we'll copy the decimal digits into.
- size_t length = (integer->length + 1) * 10;
- char *digits = xcalloc(length, sizeof(char));
+ // Pack pm_integer_t to bigint_t.
+ size_t length = integer->length + 1;
+ uint32_t *values = calloc(length, sizeof(uint32_t));
+ const pm_integer_word_t *current = &(integer->head);
+ for (size_t i = 0; i < length; i++) {
+ values[i] = current->value;
+ current = current->next;
+ }
+ bigint_t bigint = { length, values };
+ // Convert bigint base from 1<<32 to 10**9.
+ bigint_t converted = karatsuba_convert_base(bigint, (uint64_t) 1 << 32, 1000000000);
+ free(values);
+
+ // Allocate a buffer that we'll copy the decimal digits into.
+ size_t char_length = converted.length * 9;
+ char *digits = calloc(char_length, sizeof(char));
if (digits == NULL) return;
- // Next, create a new integer that we'll use to store the result of
- // the division and modulo operations.
- pm_integer_t copy;
- pm_integer_copy(&copy, integer);
-
- // Then, iterate through the integer, dividing by 10 and storing the
- // result in the buffer.
- char *ending = digits + length - 1;
- char *current = ending;
-
- while (copy.length > 0 || copy.head.value > 0) {
- uint32_t remainder = pm_integer_divide(&copy, 10);
- *current-- = (char) ('0' + remainder);
+ // Pack bigdecimal to digits.
+ for (size_t i = 0; i < converted.length; i++) {
+ uint32_t v = converted.values[i];
+ for (size_t j = 0; j < 9; j++) {
+ digits[char_length - 9 * i - j - 1] = (char) ('0' + v % 10);
+ v /= 10;
+ }
}
+ size_t start_offset = 0;
+ while (start_offset < char_length - 1 && digits[start_offset] == '0') start_offset++;
// Finally, append the string to the buffer and free the digits.
- pm_buffer_append_string(buffer, current + 1, (size_t) (ending - current));
- xfree(digits);
+ pm_buffer_append_string(buffer, digits + start_offset, char_length - start_offset);
+ free(digits);
+ free(converted.values);
return;
}
}