diff options
Diffstat (limited to 'prism/util/pm_integer.c')
-rw-r--r-- | prism/util/pm_integer.c | 332 |
1 files changed, 167 insertions, 165 deletions
diff --git a/prism/util/pm_integer.c b/prism/util/pm_integer.c index 5bcb508c1c..4caedc4121 100644 --- a/prism/util/pm_integer.c +++ b/prism/util/pm_integer.c @@ -1,20 +1,14 @@ #include "prism/util/pm_integer.h" /** - * Bigint with arbitary base. In practice, base is 1<<32 or 10**9. - * When base is 10**9, it acts as bigdecimal. + * Adds two positive pm_integer_t with the given base. + * Return pm_integer_t with values allocated. Not normalized. */ -typedef struct { - size_t length; - uint32_t *values; -} bigint_t; - -/** - * Adds two bigint_t with the given base. - */ -static bigint_t -big_add(bigint_t left, bigint_t right, uint64_t base) { - size_t length = (left.length < right.length ? right.length : left.length); +static pm_integer_t +big_add(pm_integer_t left_, pm_integer_t right_, uint64_t base) { + pm_integer_t left = left_.values ? left_ : (pm_integer_t) { 0, 1, &left_.value, false }; + pm_integer_t right = right_.values ? right_ : (pm_integer_t) { 0, 1, &right_.value, false }; + size_t length = left.length < right.length ? right.length : left.length; uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * (length + 1)); uint64_t carry = 0; for (size_t i = 0; i < length; i++) { @@ -26,15 +20,19 @@ big_add(bigint_t left, bigint_t right, uint64_t base) { values[length] = (uint32_t) carry; length++; } - return (bigint_t) { length, values }; + return (pm_integer_t) { 0, length, values, false }; } /** - * Calculates `a - b - c` with the given base. - * Result is assumed to be positive value. Internal use for karatsuba_multiply. + * Internal use for karatsuba_multiply. Calculates `a - b - c` with the given + * base. Assume a, b, c, a - b - c all to be poitive. + * Return pm_integer_t with values allocated. Not normalized. */ -static bigint_t -big_sub2(bigint_t a, bigint_t b, bigint_t c, uint64_t base) { +static pm_integer_t +big_sub2(pm_integer_t a_, pm_integer_t b_, pm_integer_t c_, uint64_t base) { + pm_integer_t a = a_.values ? a_ : (pm_integer_t) { 0, 1, &a_.value, false }; + pm_integer_t b = b_.values ? b_ : (pm_integer_t) { 0, 1, &b_.value, false }; + pm_integer_t c = c_.values ? c_ : (pm_integer_t) { 0, 1, &c_.value, false }; size_t length = a.length; uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * length); int64_t carry = 0; @@ -50,16 +48,19 @@ big_sub2(bigint_t a, bigint_t b, bigint_t c, uint64_t base) { } } while (length > 1 && values[length - 1] == 0) length--; - return (bigint_t) { length, values }; + return (pm_integer_t) { 0, length, values, false }; } /** - * Multiply two bigint_t with the given base using karatsuba algorithm. + * Multiply two positive integers with the given base using karatsuba algorithm. + * Return pm_integer_t with values allocated. Not normalized. */ -static bigint_t -karatsuba_multiply(bigint_t left, bigint_t right, uint64_t base) { +static pm_integer_t +karatsuba_multiply(pm_integer_t left_, pm_integer_t right_, uint64_t base) { + pm_integer_t left = left_.values ? left_ : (pm_integer_t) { 0, 1, &left_.value, false }; + pm_integer_t right = right_.values ? right_ : (pm_integer_t) { 0, 1, &right_.value, false }; if (left.length > right.length) { - bigint_t temp = left; + pm_integer_t temp = left; left = right; right = temp; } @@ -76,40 +77,40 @@ karatsuba_multiply(bigint_t left, bigint_t right, uint64_t base) { values[i + right.length] = carry; } while (length > 1 && values[length - 1] == 0) length--; - return (bigint_t) { length, values }; + return (pm_integer_t) { 0, length, values, false }; } if (left.length * 2 <= right.length) { uint32_t *values = (uint32_t*) calloc(left.length + right.length, sizeof(uint32_t)); for (size_t start_offset = 0; start_offset < right.length; start_offset += left.length) { size_t end_offset = start_offset + left.length; if (end_offset > right.length) end_offset = right.length; - bigint_t sliced_right = { end_offset - start_offset, right.values + start_offset }; - bigint_t v = karatsuba_multiply(left, sliced_right, base); + pm_integer_t sliced_right = { 0, end_offset - start_offset, right.values + start_offset, false }; + pm_integer_t v = karatsuba_multiply(left, sliced_right, base); uint32_t carry = 0; for (size_t i = 0; i < v.length; i++) { uint64_t sum = (uint64_t) values[start_offset + i] + v.values[i] + carry; values[start_offset + i] = (uint32_t) (sum % base); carry = (uint32_t) (sum / base); } - free(v.values); - values[start_offset + v.length] += carry; + if (carry > 0) values[start_offset + v.length] += carry; + pm_integer_free(&v); } - return (bigint_t) { left.length + right.length, values }; + return (pm_integer_t) { 0, left.length + right.length, values, false }; } size_t half = left.length / 2; - bigint_t x0 = { half, left.values }; - bigint_t x1 = { left.length - half, left.values + half }; - bigint_t y0 = { half, right.values }; - bigint_t y1 = { right.length - half, right.values + half }; - bigint_t z0 = karatsuba_multiply(x0, y0, base); - bigint_t z2 = karatsuba_multiply(x1, y1, base); + pm_integer_t x0 = { 0, half, left.values, false }; + pm_integer_t x1 = { 0, left.length - half, left.values + half, false }; + pm_integer_t y0 = { 0, half, right.values, false }; + pm_integer_t y1 = { 0, right.length - half, right.values + half, false }; + pm_integer_t z0 = karatsuba_multiply(x0, y0, base); + pm_integer_t z2 = karatsuba_multiply(x1, y1, base); // For simplicity to avoid considering negative values, // use `z1 = (x0 + x1) * (y0 + y1) - z0 - z2` instead of original karatsuba algorithm. - bigint_t x01 = big_add(x0, x1, base); - bigint_t y01 = big_add(y0, y1, base); - bigint_t xy = karatsuba_multiply(x01, y01, base); - bigint_t z1 = big_sub2(xy, z0, z2, base); + pm_integer_t x01 = big_add(x0, x1, base); + pm_integer_t y01 = big_add(y0, y1, base); + pm_integer_t xy = karatsuba_multiply(x01, y01, base); + pm_integer_t z1 = big_sub2(xy, z0, z2, base); size_t length = left.length + right.length; uint32_t *values = (uint32_t*) calloc(length, sizeof(uint32_t)); @@ -127,13 +128,13 @@ karatsuba_multiply(bigint_t left, bigint_t right, uint64_t base) { carry = (uint32_t) (sum / base); } while (length > 1 && values[length - 1] == 0) length--; - free(z0.values); - free(z1.values); - free(z2.values); - free(x01.values); - free(y01.values); - free(xy.values); - return (bigint_t) { length, values }; + pm_integer_free(&z0); + pm_integer_free(&z1); + pm_integer_free(&z2); + pm_integer_free(&x01); + pm_integer_free(&y01); + pm_integer_free(&xy); + return (pm_integer_t) { 0, length, values, false }; } /** @@ -163,67 +164,95 @@ pm_integer_parse_digit(const uint8_t character) { } /** - * Create a bigint_t from uint64_t with the given base. + * Create a pm_integer_t from uint64_t with the given base. */ -static bigint_t -uint64_to_bigint(uint64_t value, uint64_t base) { +static pm_integer_t +pm_integer_from_uint64(uint64_t value, uint64_t base) { + if (value < base) { + return (pm_integer_t) { (uint32_t) value, 0, NULL, false }; + } uint64_t v = value; size_t len = 0; while (value > 0) { len++; value /= base; } - if (len == 0) len = 1; uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * len); for (size_t i = 0; i < len; i++) { values[i] = (uint32_t) (v % base); v /= base; } - return (bigint_t) { len, values }; + return (pm_integer_t) { 0, len, values, false }; } /** - * Convert base of bigint. + * Normalize pm_integer_t. + * Heading zero values will be removed. If the integer fits into uint32_t, + * values is set to NULL, length is set to 0, and value field will be used. + */ +static void +pm_integer_normalize(pm_integer_t *integer) { + if (integer->values == NULL) { + return; + } + while (integer->length > 1 && integer->values[integer->length - 1] == 0) { + integer->length--; + } + if (integer->length > 1) { + return; + } + + uint32_t value = integer->values[0]; + bool negative = integer->negative && value != 0; + pm_integer_free(integer); + *integer = (pm_integer_t) { value, 0, NULL, negative }; +} + +/** + * Convert base of the integer. * In practice, it converts 10**9 to 1<<32 or 1<<32 to 10**9. */ -static bigint_t -karatsuba_convert_base(bigint_t source, uint64_t base_from, uint64_t base_to) { +static pm_integer_t +pm_integer_convert_base(pm_integer_t source_, uint64_t base_from, uint64_t base_to) { + pm_integer_t source = source_.values ? source_ : (pm_integer_t) { 0, 1, &source_.value, source_.negative }; size_t bigints_length = (source.length + 1) / 2; - bigint_t *bigints = (bigint_t*) malloc(sizeof(bigint_t) * bigints_length); + pm_integer_t *bigints = (pm_integer_t*) malloc(sizeof(pm_integer_t) * bigints_length); for (size_t i = 0; i < source.length; i += 2) { uint64_t v = source.values[i] + base_from * (i + 1 < source.length ? source.values[i + 1] : 0); - bigints[i / 2] = uint64_to_bigint(v, base_to); + bigints[i / 2] = pm_integer_from_uint64(v, base_to); } - bigint_t base = uint64_to_bigint(base_from, base_to); + pm_integer_t base = pm_integer_from_uint64(base_from, base_to); while (bigints_length > 1) { size_t new_length = (bigints_length + 1) / 2; - bigint_t new_base = karatsuba_multiply(base, base, base_to); - free(base.values); + pm_integer_t new_base = karatsuba_multiply(base, base, base_to); + pm_integer_free(&base); base = new_base; - bigint_t *new_bigints = (bigint_t*) malloc(sizeof(bigint_t) * new_length); + pm_integer_t *new_bigints = (pm_integer_t*) malloc(sizeof(pm_integer_t) * new_length); for (size_t i = 0; i < bigints_length; i += 2) { if (i + 1 == bigints_length) { new_bigints[i / 2] = bigints[i]; } else { - bigint_t multiplied = karatsuba_multiply(base, bigints[i + 1], base_to); + pm_integer_t multiplied = karatsuba_multiply(base, bigints[i + 1], base_to); new_bigints[i / 2] = big_add(bigints[i], multiplied, base_to); - free(bigints[i].values); - free(bigints[i + 1].values); - free(multiplied.values); + pm_integer_free(&bigints[i]); + pm_integer_free(&bigints[i + 1]); + pm_integer_free(&multiplied); } } free(bigints); bigints = new_bigints; bigints_length = new_length; } - free(base.values); - bigint_t result = bigints[0]; + pm_integer_free(&base); + pm_integer_t result = bigints[0]; + result.negative = source.negative; free(bigints); + pm_integer_normalize(&result); return result; } /** - * Convert digits to bigint_t with the given power-of-two base. + * Convert digits to integer with the given power-of-two base. */ -static bigint_t -big_parse_powof2(uint32_t base, const uint8_t *digits, size_t digits_length) { +static void +pm_integer_parse_powof2(pm_integer_t *integer, uint32_t base, const uint8_t *digits, size_t digits_length) { size_t bit = 1; while (base > (uint32_t) (1 << bit)) bit++; size_t length = (digits_length * bit + 31) / 32; @@ -237,32 +266,31 @@ big_parse_powof2(uint32_t base, const uint8_t *digits, size_t digits_length) { if (32 - shift < bit) values[index + 1] |= value >> (32 - shift); } while (length > 1 && values[length - 1] == 0) length--; - return (bigint_t) { length, values }; + *integer = (pm_integer_t) { 0, length, values, false }; + pm_integer_normalize(integer); } /** - * Convert decimal digits to bigint. + * Convert decimal digits to pm_integer_t. */ -static bigint_t -big_parse_decimal(const uint8_t *digits, size_t digits_length) { - // Construct a bigdecimal from the digits. +static void +pm_integer_parse_decimal(pm_integer_t *integer, const uint8_t *digits, size_t digits_length) { + // Construct a bigdecimal with base = 10**9 from the digits const size_t batch = 9; - const uint64_t batch_base = 1000000000; size_t values_length = (digits_length + batch - 1) / batch; - bigint_t bigint = { values_length, (uint32_t*) calloc(values_length, sizeof(uint32_t)) }; + pm_integer_t decimal = { 0, values_length, (uint32_t*) calloc(values_length, sizeof(uint32_t)), false }; uint32_t v = 0; for (size_t i = 0; i < digits_length; i++) { v = v * 10 + digits[i]; size_t reverse_index = digits_length - i - 1; if (reverse_index % batch == 0) { - bigint.values[reverse_index / batch] = v; + decimal.values[reverse_index / batch] = v; v = 0; } } - // Convert bigint base from 10**9 to 1<<32. - bigint_t converted = karatsuba_convert_base(bigint, batch_base, ((uint64_t) 1 << 32)); - free(bigint.values); - return converted; + // Convert base from 10**9 to 1<<32. + *integer = pm_integer_convert_base(decimal, 1000000000, ((uint64_t) 1 << 32)); + pm_integer_free(&decimal); } /** @@ -277,22 +305,12 @@ pm_integer_parse_big(pm_integer_t *integer, uint32_t multiplier, const uint8_t * if (*start == '_') continue; digits[digits_length++] = (uint8_t) pm_integer_parse_digit(*start); } - // Construct bigint_t from the digits. - bigint_t bigint = - multiplier == 10 ? big_parse_decimal(digits, digits_length) : big_parse_powof2(multiplier, digits, digits_length); - - // Pack bigint_t to pm_integer_t. - integer->length = bigint.length - 1; - integer->head.value = bigint.values[0]; - pm_integer_word_t *current = &integer->head; - for (size_t i = 1; i < bigint.length; i++) { - current->next = malloc(sizeof(pm_integer_word_t)); - current = current->next; - current->value = bigint.values[i]; - current->next = NULL; + // Construct pm_integer_t from the digits. + if (multiplier == 10) { + pm_integer_parse_decimal(integer, digits, digits_length); + } else { + pm_integer_parse_powof2(integer, multiplier, digits, digits_length); } - - free(bigint.values); free(digits); } @@ -351,13 +369,13 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s if (*ptr == '_') continue; value = value * multiplier + pm_integer_parse_digit(*ptr); if (value > UINT32_MAX) { - // If the integer is too large to fit into a single node, then we'll + // If the integer is too large to fit into a single uint32_t, then we'll // parse it as a big integer. pm_integer_parse_big(integer, multiplier, start, end); return; } } - integer->head.value = (uint32_t) value; + integer->value = (uint32_t) value; } /** @@ -365,7 +383,7 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s */ size_t pm_integer_memsize(const pm_integer_t *integer) { - return sizeof(pm_integer_t) + integer->length * sizeof(pm_integer_word_t); + return sizeof(pm_integer_t) + integer->length * sizeof(uint32_t); } /** @@ -378,16 +396,21 @@ pm_integer_compare(const pm_integer_t *left, const pm_integer_t *right) { if (left->negative != right->negative) return left->negative ? -1 : 1; int negative = left->negative ? -1 : 1; - if (left->length < right->length) return -1 * negative; - if (left->length > right->length) return 1 * negative; + if (left->values == right->values) { + if (left->value < right->value) return -1 * negative; + if (left->value > right->value) return 1 * negative; + return 0; + } + + if (left->values == NULL || left->length < right->length) return -1 * negative; + if (right->values == NULL || left->length > right->length) return 1 * negative; - for ( - const pm_integer_word_t *left_word = &left->head, *right_word = &right->head; - left_word != NULL && right_word != NULL; - left_word = left_word->next, right_word = right_word->next - ) { - if (left_word->value < right_word->value) return -1 * negative; - if (left_word->value > right_word->value) return 1 * negative; + for (size_t i = 0; i < left->length; i++) { + size_t index = left->length - i - 1; + uint32_t l = left->values[index]; + uint32_t r = right->values[index]; + if (l < r) return -1 * negative; + if (l > r) return 1 * negative; } return 0; @@ -402,75 +425,54 @@ pm_integer_string(pm_buffer_t *buffer, const pm_integer_t *integer) { pm_buffer_append_byte(buffer, '-'); } - switch (integer->length) { - case 0: { - const uint32_t value = integer->head.value; - pm_buffer_append_format(buffer, "%" PRIu32, value); - return; - } - case 1: { - const uint64_t value = ((uint64_t) integer->head.value) | (((uint64_t) integer->head.next->value) << 32); - pm_buffer_append_format(buffer, "%" PRIu64, value); - return; - } - default: { - // Pack pm_integer_t to bigint_t. - size_t length = integer->length + 1; - uint32_t *values = calloc(length, sizeof(uint32_t)); - const pm_integer_word_t *current = &(integer->head); - for (size_t i = 0; i < length; i++) { - values[i] = current->value; - current = current->next; - } - bigint_t bigint = { length, values }; - // Convert bigint base from 1<<32 to 10**9. - bigint_t converted = karatsuba_convert_base(bigint, (uint64_t) 1 << 32, 1000000000); - free(values); - - // Allocate a buffer that we'll copy the decimal digits into. - size_t char_length = converted.length * 9; - char *digits = calloc(char_length, sizeof(char)); - if (digits == NULL) return; + if (integer->values == NULL) { + pm_buffer_append_format(buffer, "%" PRIu32, integer->value); + return; + } + if (integer->length == 2) { + const uint64_t value = ((uint64_t) integer->values[0]) | ((uint64_t) integer->values[1] << 32); + pm_buffer_append_format(buffer, "%" PRIu64, value); + return; + } - // Pack bigdecimal to digits. - for (size_t i = 0; i < converted.length; i++) { - uint32_t v = converted.values[i]; - for (size_t j = 0; j < 9; j++) { - digits[char_length - 9 * i - j - 1] = (char) ('0' + v % 10); - v /= 10; - } - } - size_t start_offset = 0; - while (start_offset < char_length - 1 && digits[start_offset] == '0') start_offset++; + // Convert base from 1<<32 to 10**9. + pm_integer_t converted = pm_integer_convert_base(*integer, (uint64_t) 1 << 32, 1000000000); - // Finally, append the string to the buffer and free the digits. - pm_buffer_append_string(buffer, digits + start_offset, char_length - start_offset); - free(digits); - free(converted.values); - return; - } + if (converted.values == NULL) { + pm_buffer_append_format(buffer, "%" PRIu32, converted.value); + pm_integer_free(&converted); + return; } -} -/** - * Recursively destroy the linked list of an integer. - */ -static void -pm_integer_word_destroy(pm_integer_word_t *integer) { - if (integer->next != NULL) { - pm_integer_word_destroy(integer->next); + // Allocate a buffer that we'll copy the decimal digits into. + size_t char_length = converted.length * 9; + char *digits = calloc(char_length, sizeof(char)); + if (digits == NULL) return; + + // Pack bigdecimal to digits. + for (size_t i = 0; i < converted.length; i++) { + uint32_t v = converted.values[i]; + for (size_t j = 0; j < 9; j++) { + digits[char_length - 9 * i - j - 1] = (char) ('0' + v % 10); + v /= 10; + } } + size_t start_offset = 0; + while (start_offset < char_length - 1 && digits[start_offset] == '0') start_offset++; - xfree(integer); + // Finally, append the string to the buffer and free the digits. + pm_buffer_append_string(buffer, digits + start_offset, char_length - start_offset); + free(digits); + pm_integer_free(&converted); } /** * Free the internal memory of an integer. This memory will only be allocated if - * the integer exceeds the size of a single node in the linked list. + * the integer exceeds the size of a single uint32_t. */ PRISM_EXPORTED_FUNCTION void pm_integer_free(pm_integer_t *integer) { - if (integer->head.next) { - pm_integer_word_destroy(integer->head.next); + if (integer->values) { + free(integer->values); } } |