Replace the stack with a bump allocator. This rids the algorithm of all the
mallocs, reallocs and frees it did, doubling its throughput.

`./benchmark' prev:

    zmul(c, a, b): 0.001108253 ms (152 bits)          (GMP x 104)
    zmodmul(c, a, b, d): 0.002137137 ms (152 bits)    (GMP x 34)
    zpow(c, a, d): 0.059366386 ms (152 bits)          (GMP x 616)

`./benchmark' current:

    zmul(c, a, b): 0.000560419 ms (152 bits)          (GMP x 52)
    zmodmul(c, a, b, d): 0.001514106 ms (152 bits)    (GMP x 24)
    zpow(c, a, d): 0.031881215 ms (152 bits)          (GMP x 331)

zmul_get_nb is a simplification of the following function.  As a function of u32
(actually of [2^5, 2^32)), it is quick to test that all 3 give the same result,
or the same result + 64 bytes (fine, as long as it overshoots).  It assumes an
n=m multiply.

    uint64_t
    get_nb(uint32_t n)
    {
        double ex, bits_to_alloc, bits_carried_into;
        size_t bytes_to_alloc, bytes_to_alloc_low;
        size_t N = n,nb;

        ex = ceil(log2(N)-1.0);

        bits_to_alloc = N * (
            (1.0 - pow(1.5, ex)) /
            (1.0 - 1.5)
        );
        bits_carried_into = 2 * (
            (1.0 - pow(3.0, ex)) /
            (1.0 - 3.0)
        );

        bytes_to_alloc = bits_to_alloc / 8.0;
        bytes_to_alloc_low = (bits_carried_into + bits_to_alloc) / 8.0;

        nb = bytes_to_alloc + bytes_to_alloc_low;
        nb = ((nb+64)/64)*64;
        return nb;
    }
---
 src/zmul.c | 143 +++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 127 insertions(+), 16 deletions(-)

diff --git a/src/zmul.c b/src/zmul.c
index bd8d311..1b09122 100644
--- a/src/zmul.c
+++ b/src/zmul.c
@@ -2,6 +2,47 @@
 #include "internals.h"
 
 
+struct zmul_heap {
+       size_t len, cap;
+       unsigned char * buf;
+};
+
+static void zmul_ll_recurse(z_t, z_t, z_t, struct zmul_heap *);
+static void zmul_bump_alloc_temps(size_t, z_t, z_t, z_t, z_t, struct zmul_heap *);
+static size_t zmul_get_nb(uint32_t);
+
+void
+zmul_ll(z_t a, z_t b, z_t c)
+{
+       uint64_t N, nb;
+       struct zmul_heap h;
+
+       N = zbits(b);
+       if (b != c) {
+               size_t w=zbits(c);
+               if (w>N) N=w;
+       }
+       if (N < (BITS_PER_CHAR>>1)) {
+               zmul_ll_recurse(a, b, c, 0);
+               return;
+       }
+       if (N >> 32) {
+               libzahl_failure(ERANGE);
+       }
+
+       nb = zmul_get_nb(N);
+       memset(&h, 0, sizeof(h));
+       h.buf = malloc(nb);
+       if (check(!h.buf)) {
+               libzahl_memfailure();
+       }
+       h.cap = nb;
+
+       zmul_ll_recurse(a, b, c, &h);
+
+       free(h.buf);
+}
+
 static inline void
 zmul_ll_single_char(z_t a, z_t b, z_t c)
 {
@@ -11,18 +52,18 @@ zmul_ll_single_char(z_t a, z_t b, z_t c)
        SET_SIGNUM(a, 1);
 }
 
-void
-zmul_ll(z_t a, z_t b, z_t c)
+static void
+zmul_ll_recurse(z_t a, z_t b, z_t c, struct zmul_heap * h)
 {
        /*
         * Karatsuba algorithm
-        * 
+        *
         * Basically, this is how you were taught to multiply large numbers
         * by hand in school: 4010⋅3020 = (4000 + 10)(3000 + 20) =
         * = 40⋅30⋅10⁴ + (40⋅20 + 30⋅10)⋅10² + 10⋅20, but the middle is
         * optimised to only one multiplication:
         * 40⋅20 + 30⋅10 = (40 + 10)(30 + 20) − 40⋅30 − 10⋅20.
-        * This optimisation is crucial. Without it, the algorithm with
+        * This optimisation is crucial. Without it, the algorithm would
         * run in O(n²).
         */
 
@@ -45,23 +86,20 @@ zmul_ll(z_t a, z_t b, z_t c)
                return;
        }
 
-        m = MAX(m, m2);
+       m = MAX(m, m2);
        m2 = m >> 1;
 
-       zinit_temp(b_high);
-       zinit_temp(b_low);
-       zinit_temp(c_high);
-       zinit_temp(c_low);
+       zmul_bump_alloc_temps(m, b_high, c_high, b_low, c_low, h);
 
        zsplit_pz(b_high, b_low, b, m2);
        zsplit_pz(c_high, c_low, c, m2);
 
 
-       zmul_ll(z0, b_low, c_low);
+       zmul_ll_recurse(z0, b_low, c_low, h);
        zadd_unsigned_assign(b_low, b_high);
        zadd_unsigned_assign(c_low, c_high);
-       zmul_ll(z1, b_low, c_low);
-       zmul_ll(z2, b_high, c_high);
+       zmul_ll_recurse(z1, b_low, c_low, h);
+       zmul_ll_recurse(z2, b_high, c_high, h);
 
        zsub_nonnegative_assign(z1, z0);
        zsub_nonnegative_assign(z1, z2);
@@ -71,10 +109,83 @@ zmul_ll(z_t a, z_t b, z_t c)
        zlsh(z2, z2, m2);
        zadd_unsigned_assign(a, z1);
        zadd_unsigned_assign(a, z2);
+}
+
+static void
+zmul_bump_alloc_temps(size_t m, z_t b_high, z_t c_high, z_t b_low, z_t c_low, struct zmul_heap * h)
+{
+       size_t ha,b_m_low,c_m_low,ba,ca;
+       m = (m+1)>>1;
+
+       /* account for carry */
+       b_m_low = (m + 1)*2 + 1 + m;
+       c_m_low = b_m_low + m;
+
+       /* high temps are constants in Karatsuba; no carry */
+       ha = 1 + ((m+BITS_PER_CHAR)/BITS_PER_CHAR);
+
+       ba = 1 + ((b_m_low+BITS_PER_CHAR)/BITS_PER_CHAR);
+       ca = 1 + ((c_m_low+BITS_PER_CHAR)/BITS_PER_CHAR);
+
+       b_high->alloced = c_high->alloced = ha;
+
+       b_high->chars = (void *)&h->buf[h->len];
+       h->len += ha<<3;
 
+       c_high->chars = (void *)&h->buf[h->len];
+       h->len += ha<<3;
 
-       zfree_temp(c_low);
-       zfree_temp(c_high);
-       zfree_temp(b_low);
-       zfree_temp(b_high);
+       b_low->alloced = ba;
+       b_low->chars = (void *)&h->buf[h->len];
+       h->len += ba<<3;
+
+       c_low->alloced = ca;
+       c_low->chars = (void *)&h->buf[h->len];
+       h->len += ca<<3;
+
+       b_high->used = c_high->used = b_low->used = c_low->used = 0;
+       b_high->sign = c_high->sign = b_low->sign = c_low->sign = 0;
+
+       if (check(unlikely(h->len > h->cap))) {
+               if (h->buf) free(h->buf);
+               h->buf = 0;
+               libzahl_memfailure();
+       }
+}
+
+/* tttbl[i] = pow(3,i)-pow(2,i) */
+const static uint64_t tttbl[32] = {0ULL, 1ULL, 5ULL, 19ULL, 65ULL, 211ULL, 665ULL, 2059ULL, 6305ULL, 19171ULL, 58025ULL, 175099ULL, 527345ULL, 1586131ULL, 4766585ULL, 14316139ULL, 42981185ULL, 129009091ULL, 387158345ULL, 1161737179ULL, 3485735825ULL, 10458256051ULL, 31376865305ULL, 94134790219ULL, 282412759265ULL, 847255055011ULL, 2541798719465ULL, 7625463267259ULL, 22876524019505ULL, 68629840493971ULL, 205890058352825ULL, 617671248800299ULL};
+
+static uint64_t
+zmul_get_nb(uint32_t n)
+{
+       uint64_t x,ttx,bits_to_alloc,bits_carried_into,ret;
+#if defined(__GNUC__) || defined(__clang__)
+       unsigned __int128 num;
+       /* x=ceil(log2(n)-1.0) */
+       x = 31 - !(n & (n-1)) - __builtin_clzg(n);
+       ttx = tttbl[x];
+       /* bits_to_alloc = N * ((1.0 - pow(1.5, x)) / (1.0 - 1.5)).  need 84 bits */
+       num = ttx;
+       num *= n;
+       bits_to_alloc = num >> (x-1);
+#else
+       uint64_t num_hi,num_lo,tmp,tmp2;
+       x = BITS_PER_CHAR - 1 - !(n & (n-1));
+       ZAHL_SUB_CLZ(x, (zahl_char_t)n);
+       ttx = tttbl[x];
+       num_lo = (ttx & 0xFFFFFFFF) * n;
+       tmp = (ttx >> 32) * n;
+       num_hi = tmp >> 32;
+       tmp2=tmp<<32;
+       num_hi += libzahl_add_overflow(&num_lo, num_lo, tmp2);
+       bits_to_alloc = (num_hi << (65-x)) | (num_lo >> (x-1));
+#endif
+       bits_to_alloc++;
+
+       /* bits_carried_into = 2.0 * ((1.0 - pow(3.0, x)) / (1.0 - 3.0)) */
+       bits_carried_into = ttx + (1ULL<<x) - 1;
+       ret = ((bits_to_alloc<<1) + bits_carried_into)>>3;
+       ret = ((ret + 64)>>6)<<6;
+       return ret;
 }
-- 
2.53.0



Reply via email to