From c4e44224cf617c8cd33a734f888c045ac9575226 Mon Sep 17 00:00:00 2001 From: Dean Rasheed Date: Thu, 15 Aug 2024 10:33:12 +0100 Subject: [PATCH] Extend mul_var_short() to 5 and 6-digit inputs. Commit ca481d3c9a introduced mul_var_short(), which is used by mul_var() whenever the shorter input has 1-4 NBASE digits and the exact product is requested. As speculated on in that commit, it can be extended to work for more digits in the shorter input. This commit extends it up to 6 NBASE digits (up to 24 decimal digits), for which it also gives a significant speedup. This covers more cases likely to occur in real-world queries, for which using base-NBASE^2 arithmetic provides little benefit. To avoid code bloat and duplication, refactor it a bit using macros and exploiting the fact that some portions of the code are shared between the different cases. Dean Rasheed, reviewed by Joel Jacobson. Discussion: https://postgr.es/m/9d8a4a42-c354-41f3-bbf3-199e1957db97%40app.fastmail.com --- src/backend/utils/adt/numeric.c | 175 ++++++++++++++++++++++---------- 1 file changed, 123 insertions(+), 52 deletions(-) diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c index 763a7f4be0f..2a74312d354 100644 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -8714,10 +8714,10 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, } /* - * If var1 has 1-4 digits and the exact result was requested, delegate to + * If var1 has 1-6 digits and the exact result was requested, delegate to * mul_var_short() which uses a faster direct multiplication algorithm. 
*/ - if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale) + if (var1ndigits <= 6 && rscale == var1->dscale + var2->dscale) { mul_var_short(var1, var2, result); return; @@ -8876,7 +8876,7 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, /* * mul_var_short() - * - * Special-case multiplication function used when var1 has 1-4 digits, var2 + * Special-case multiplication function used when var1 has 1-6 digits, var2 * has at least as many digits as var1, and the exact product var1 * var2 is * requested. */ @@ -8898,7 +8898,7 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2, /* Check preconditions */ Assert(var1ndigits >= 1); - Assert(var1ndigits <= 4); + Assert(var1ndigits <= 6); Assert(var2ndigits >= var1ndigits); /* @@ -8925,6 +8925,13 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2, * carry up as we go. The i'th result digit consists of the sum of the * products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1. 
*/ +#define PRODSUM1(v1,i1,v2,i2) ((v1)[(i1)] * (v2)[(i2)]) +#define PRODSUM2(v1,i1,v2,i2) (PRODSUM1(v1,i1,v2,i2) + (v1)[(i1)+1] * (v2)[(i2)-1]) +#define PRODSUM3(v1,i1,v2,i2) (PRODSUM2(v1,i1,v2,i2) + (v1)[(i1)+2] * (v2)[(i2)-2]) +#define PRODSUM4(v1,i1,v2,i2) (PRODSUM3(v1,i1,v2,i2) + (v1)[(i1)+3] * (v2)[(i2)-3]) +#define PRODSUM5(v1,i1,v2,i2) (PRODSUM4(v1,i1,v2,i2) + (v1)[(i1)+4] * (v2)[(i2)-4]) +#define PRODSUM6(v1,i1,v2,i2) (PRODSUM5(v1,i1,v2,i2) + (v1)[(i1)+5] * (v2)[(i2)-5]) + switch (var1ndigits) { case 1: @@ -8936,9 +8943,9 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2, * ---------- */ carry = 0; - for (int i = res_ndigits - 2; i >= 0; i--) + for (int i = var2ndigits - 1; i >= 0; i--) { - term = (uint32) var1digits[0] * var2digits[i] + carry; + term = PRODSUM1(var1digits, 0, var2digits, i) + carry; res_digits[i + 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; } @@ -8954,23 +8961,17 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2, * ---------- */ /* last result digit and carry */ - term = (uint32) var1digits[1] * var2digits[res_ndigits - 3]; + term = PRODSUM1(var1digits, 1, var2digits, var2ndigits - 1); res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; /* remaining digits, except for the first two */ - for (int i = res_ndigits - 3; i >= 1; i--) + for (int i = var2ndigits - 1; i >= 1; i--) { - term = (uint32) var1digits[0] * var2digits[i] + - (uint32) var1digits[1] * var2digits[i - 1] + carry; + term = PRODSUM2(var1digits, 0, var2digits, i) + carry; res_digits[i + 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; } - - /* first two digits */ - term = (uint32) var1digits[0] * var2digits[0] + carry; - res_digits[1] = (NumericDigit) (term % NBASE); - res_digits[0] = (NumericDigit) (term / NBASE); break; case 3: @@ -8982,34 +8983,21 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2, * ---------- */ /* last two result digits */ - term = (uint32) var1digits[2] * 
var2digits[res_ndigits - 4]; + term = PRODSUM1(var1digits, 2, var2digits, var2ndigits - 1); res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; - term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] + - (uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry; + term = PRODSUM2(var1digits, 1, var2digits, var2ndigits - 1) + carry; res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); carry = term / NBASE; /* remaining digits, except for the first three */ - for (int i = res_ndigits - 4; i >= 2; i--) + for (int i = var2ndigits - 1; i >= 2; i--) { - term = (uint32) var1digits[0] * var2digits[i] + - (uint32) var1digits[1] * var2digits[i - 1] + - (uint32) var1digits[2] * var2digits[i - 2] + carry; + term = PRODSUM3(var1digits, 0, var2digits, i) + carry; res_digits[i + 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; } - - /* first three digits */ - term = (uint32) var1digits[0] * var2digits[1] + - (uint32) var1digits[1] * var2digits[0] + carry; - res_digits[2] = (NumericDigit) (term % NBASE); - carry = term / NBASE; - - term = (uint32) var1digits[0] * var2digits[0] + carry; - res_digits[1] = (NumericDigit) (term % NBASE); - res_digits[0] = (NumericDigit) (term / NBASE); break; case 4: @@ -9021,45 +9009,128 @@ mul_var_short(const NumericVar *var1, const NumericVar *var2, * ---------- */ /* last three result digits */ - term = (uint32) var1digits[3] * var2digits[res_ndigits - 5]; + term = PRODSUM1(var1digits, 3, var2digits, var2ndigits - 1); res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; - term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] + - (uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry; + term = PRODSUM2(var1digits, 2, var2digits, var2ndigits - 1) + carry; res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); carry = term / NBASE; - term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] + - (uint32) var1digits[2] * var2digits[res_ndigits - 
6] + - (uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry; + term = PRODSUM3(var1digits, 1, var2digits, var2ndigits - 1) + carry; res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE); carry = term / NBASE; /* remaining digits, except for the first four */ - for (int i = res_ndigits - 5; i >= 3; i--) + for (int i = var2ndigits - 1; i >= 3; i--) { - term = (uint32) var1digits[0] * var2digits[i] + - (uint32) var1digits[1] * var2digits[i - 1] + - (uint32) var1digits[2] * var2digits[i - 2] + - (uint32) var1digits[3] * var2digits[i - 3] + carry; + term = PRODSUM4(var1digits, 0, var2digits, i) + carry; res_digits[i + 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; } + break; - /* first four digits */ - term = (uint32) var1digits[0] * var2digits[2] + - (uint32) var1digits[1] * var2digits[1] + - (uint32) var1digits[2] * var2digits[0] + carry; - res_digits[3] = (NumericDigit) (term % NBASE); + case 5: + /* --------- + * 5-digit case: + * var1ndigits = 5 + * var2ndigits >= 5 + * res_ndigits = var2ndigits + 5 + * ---------- + */ + /* last four result digits */ + term = PRODSUM1(var1digits, 4, var2digits, var2ndigits - 1); + res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); carry = term / NBASE; - term = (uint32) var1digits[0] * var2digits[1] + - (uint32) var1digits[1] * var2digits[0] + carry; - res_digits[2] = (NumericDigit) (term % NBASE); + term = PRODSUM2(var1digits, 3, var2digits, var2ndigits - 1) + carry; + res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = PRODSUM3(var1digits, 2, var2digits, var2ndigits - 1) + carry; + res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE); carry = term / NBASE; - term = (uint32) var1digits[0] * var2digits[0] + carry; + term = PRODSUM4(var1digits, 1, var2digits, var2ndigits - 1) + carry; + res_digits[res_ndigits - 4] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* remaining digits, except for the first five */ + for (int i = 
var2ndigits - 1; i >= 4; i--) + { + term = PRODSUM5(var1digits, 0, var2digits, i) + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + break; + + case 6: + /* --------- + * 6-digit case: + * var1ndigits = 6 + * var2ndigits >= 6 + * res_ndigits = var2ndigits + 6 + * ---------- + */ + /* last five result digits */ + term = PRODSUM1(var1digits, 5, var2digits, var2ndigits - 1); + res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = PRODSUM2(var1digits, 4, var2digits, var2ndigits - 1) + carry; + res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = PRODSUM3(var1digits, 3, var2digits, var2ndigits - 1) + carry; + res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = PRODSUM4(var1digits, 2, var2digits, var2ndigits - 1) + carry; + res_digits[res_ndigits - 4] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = PRODSUM5(var1digits, 1, var2digits, var2ndigits - 1) + carry; + res_digits[res_ndigits - 5] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* remaining digits, except for the first six */ + for (int i = var2ndigits - 1; i >= 5; i--) + { + term = PRODSUM6(var1digits, 0, var2digits, i) + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + break; + } + + /* + * Finally, for var1ndigits > 1, compute the remaining var1ndigits most + * significant result digits. 
+ */ + switch (var1ndigits) + { + case 6: + term = PRODSUM5(var1digits, 0, var2digits, 4) + carry; + res_digits[5] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + /* FALLTHROUGH */ + case 5: + term = PRODSUM4(var1digits, 0, var2digits, 3) + carry; + res_digits[4] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + /* FALLTHROUGH */ + case 4: + term = PRODSUM3(var1digits, 0, var2digits, 2) + carry; + res_digits[3] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + /* FALLTHROUGH */ + case 3: + term = PRODSUM2(var1digits, 0, var2digits, 1) + carry; + res_digits[2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + /* FALLTHROUGH */ + case 2: + term = PRODSUM1(var1digits, 0, var2digits, 0) + carry; res_digits[1] = (NumericDigit) (term % NBASE); res_digits[0] = (NumericDigit) (term / NBASE); break; -- 2.30.2