On Thu, 19 Feb 2026 20:01:11 GMT, Ben Perez <[email protected]> wrote:

> That would be really useful! I tinkered with it a bit but would be nice to 
> see what you had in mind

Like this:

  // Generates the AArch64 stub for StubId::stubgen_intpoly_montgomeryMult_P256_id.
  //
  // Arguments (C calling convention):
  //   c_rarg0 - a:      pointer to five 64-bit limbs (read; the register is
  //                     post-incremented as limbs are consumed)
  //   c_rarg1 - b:      pointer to five 64-bit limbs (read)
  //   c_rarg2 - result: pointer to five 64-bit limbs (written)
  // Returns 0 in r0.
  //
  // Limbs carry 52 bits each (shift2). For each limb a[i] the stub adds
  // a[i] * b, derives n = (low limb) & LIMB_MASK and adds n * modulus, so the
  // low limb can be shifted out (Montgomery-style reduction). a[0..2] are
  // handled with scalar umulh/mul; the products of a[3] and a[4] with
  // b[0..3] are computed up front with NEON (neon_partial_mult_64) and spilled
  // to a 128-byte stack area, while their products with b[4] are scalar.
  // A final subtraction of the modulus, selected on the sign of the top limb,
  // produces the result.
  //
  // NOTE(review): presumably mirrors the Java
  // IntegerPolynomialP256.montgomeryMultiply; confirm against that source.
  address generate_intpoly_montgomeryMult_P256() {

    __ align(CodeEntryAlignment);
    StubId stub_id = StubId::stubgen_intpoly_montgomeryMult_P256_id;
    StubCodeMark mark(this, stub_id);
    address start = __ pc();
    __ enter();

    // Modulus limbs 0, 1, 3 and 4, in that order: the loads below fill
    // mod_0, mod_1, mod_3 and mod_4 from consecutive entries; limb 2 of the
    // modulus is never used. The trailing 0 is not loaded by this stub.
    static const int64_t modulus[] = {
      0x000fffffffffffffL, 0x00000fffffffffffL,
      0x0000001000000000L, 0x0000ffffffff0000L,
      0L
    };

    int shift1 = 12; // 64 - bits per limb
    int shift2 = 52; // bits per limb

    // Registers that are used throughout entire routine
    const Register a = c_rarg0;
    const Register b = c_rarg1;
    const Register result = c_rarg2;

    // Allocatable general registers. rfp and lr must not be handed out:
    // the frame built by enter() is torn down through rfp in leave(), so
    // clobbering rfp corrupts sp on return. r18_tls is the platform register
    // on some OSes (e.g. macOS, Windows) and must not be touched either.
    RegSet regs = RegSet::range(r0, r28) - r18_tls - a - b - result;
    FloatRegSet floatRegs = FloatRegSet::range(v0, v31)
      - FloatRegSet::range(v8, v15)   // Callee-saved (low 64 bits) under AAPCS64; not saved here
      - FloatRegSet::range(v16, v31); // Manually-allocated vectors (VSeq A, B, C, D below)

    auto common_regs = regs.begin();
    Register limb_mask = *common_regs++,
      c_ptr = *common_regs++,
      mod_0 = *common_regs++,
      mod_1 = *common_regs++,
      mod_3 = *common_regs++,
      mod_4 = *common_regs++,
      b_0 = *common_regs++,
      b_1 = *common_regs++,
      b_2 = *common_regs++,
      b_3 = *common_regs++,
      b_4 = *common_regs++;
    regs = common_regs.remaining();

    auto common_vectors = floatRegs.begin();
    FloatRegister limb_mask_vec = *common_vectors++,
      b_lows = *common_vectors++,
      b_highs = *common_vectors++,
      a_vals = *common_vectors++;

    // Push callee saved registers on to the stack
    RegSet callee_saved = RegSet::range(r19, r28);
    __ push(callee_saved, sp);

    // Allocate space on the stack for carry values (5 x 8 bytes, rounded up
    // to 48 to keep sp 16-byte aligned)
    __ sub(sp, sp, 48);
    __ mov(c_ptr, sp);

    // Calculate limb mask
    __ mov(limb_mask, -UCONST64(1) >> (64 - shift2));
    __ dup(limb_mask_vec, __ T2D, limb_mask);

    // Load input arrays and modulus
    {
      auto r = regs.begin();
      Register a_ptr = *r++, mod_ptr = *r++;
      __ add(a_ptr, a, 24); // a_vals <- a[3], a[4]
      __ lea(mod_ptr, ExternalAddress((address)modulus));
      __ ldr(b_0, Address(b));
      __ ldr(b_1, Address(b, 8));
      __ ldr(b_2, Address(b, 16));
      __ ldr(b_3, Address(b, 24));
      __ ldr(b_4, Address(b, 32));
      __ ldr(mod_0, __ post(mod_ptr, 8));
      __ ldr(mod_1, __ post(mod_ptr, 8));
      __ ldr(mod_3, __ post(mod_ptr, 8));
      __ ldr(mod_4, mod_ptr);
      __ ld1(a_vals, __ T2D, a_ptr);
      __ ld2(b_lows, b_highs, __ T4S, b);
    }

    //Regs used throughout the main "loop", which is partially unrolled here
    auto loop_regs = regs.begin();
    Register high = *loop_regs++,
      low = *loop_regs++,
      mod_high = *loop_regs++,
      mod_low = *loop_regs++,
      a_i = *loop_regs++,
      c_i = *loop_regs++,
      tmp = *loop_regs++,
      n = *loop_regs++;

    VSeq<4> A(16);
    VSeq<4> B(20);
    VSeq<4> C(24);
    VSeq<4> D(28);

    /////////////////////////
    // a[0], interleaved with the NEON partial products for a[3], a[4]
    /////////////////////////

    {
      // mul_ptr borrows the next free loop register without consuming it;
      // it is dead once the NEON results are stored.
      auto r = loop_regs;
      Register mul_ptr = *r++;

      __ sub(sp, sp, 128);
      __ mov(mul_ptr, sp);

      neon_partial_mult_64(A, b_lows, a_vals, 0);

      // Limb 0
      __ ldr(a_i, __ post(a, 8));
      __ umulh(high, a_i, b_0);
      __ mul(low, a_i, b_0);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);
      __ andr(n, low, limb_mask);

      neon_partial_mult_64(B, b_highs, a_vals, 0);

      // Limb 0 cont
      __ umulh(mod_high, n, mod_0);
      __ mul(mod_low, n, mod_0);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ lsr(c_i, low, shift2);
      __ add(c_i, c_i, high);

      neon_partial_mult_64(C, b_lows, a_vals, 1);

      // Limb 1
      __ umulh(high, a_i, b_1);
      __ mul(low, a_i, b_1);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);

      neon_partial_mult_64(D, b_highs, a_vals, 1);

      __ umulh(mod_high, n, mod_1);
      __ mul(mod_low, n, mod_1);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ add(c_i, c_i, low);
      __ str(c_i, c_ptr);
      __ mov(c_i, high);

      vs_addv(B, __ T2D, B, C); // Store (B+C) in B

      // Limb 2 (modulus limb 2 is zero, so no n * mod term)
      __ umulh(high, a_i, b_2);
      __ mul(low, a_i, b_2);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);
      __ add(c_i, c_i, low);
      __ str(c_i, Address(c_ptr, 8));
      __ mov(c_i, high);

      vs_shl(D, __ T2D, D, 12);

      // Limb 3
      __ umulh(high, a_i, b_3); //compute next mult to avoid waiting for result
      __ mul(low, a_i, b_3);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);

      vs_ushr(C, __ T2D, B, 20); // Use C for ((B+C) >>> 20)

      __ umulh(mod_high, n, mod_3);
      __ mul(mod_low, n, mod_3);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ add(c_i, c_i, low);
      __ str(c_i, Address(c_ptr, 16));
      __ mov(c_i, high);

      vs_shl(B, __ T2D, B, 32);

      // Limb 4
      __ umulh(high, a_i, b_4);
      __ mul(low, a_i, b_4);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);

      vs_addv(D, __ T2D, D, C);

      __ umulh(mod_high, n, mod_4);
      __ mul(mod_low, n, mod_4);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ add(c_i, c_i, low);
      __ str(c_i, Address(c_ptr, 24));
      __ str(high, Address(c_ptr, 32));

      vs_ushr(C, __ T2D, A, 52); // C now holds (A >>> 52)
      vs_andr(B, B, limb_mask_vec);
      vs_andr(A, A, limb_mask_vec);
      vs_addv(D, __ T2D, D, C);
      vs_addv(A, __ T2D, A, B);

      vs_ushr(B, __ T2D, A, shift2);
      vs_andr(A, A, limb_mask_vec);
      vs_addv(D, __ T2D, D, B);

      // Spill layout (read back below): low words from A at offsets
      // 0, 32, 64, 96; high words from D at offsets 16, 48, 80, 112.
      __ st1(A[0], __ T2D, __ post(mul_ptr, 16));
      __ st1(D[0], __ T2D, __ post(mul_ptr, 16));
      __ st1(A[1], __ T2D, __ post(mul_ptr, 16));
      __ st1(D[1], __ T2D, __ post(mul_ptr, 16));

      __ st1(A[2], __ T2D, __ post(mul_ptr, 16));
      __ st1(D[2], __ T2D, __ post(mul_ptr, 16));
      __ st1(A[3], __ T2D, __ post(mul_ptr, 16));
      __ st1(D[3], __ T2D, mul_ptr);
    }
    /////////////////////////
    // Loop 2 & 3 (a[1], a[2])
    /////////////////////////

    for (int i = 0; i < 2; i++) {
      // Load a_i and increment by 8 bytes
      __ ldr(a_i, __ post(a, 8));
      __ ldr(c_i, c_ptr); //Load prior c_i

      // Limb 0
      __ umulh(high, a_i, b_0);
      __ mul(low, a_i, b_0);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);
      __ add(low, low, c_i);
      __ ldr(c_i, Address(c_ptr, 8));
      __ andr(n, low, limb_mask);
      __ umulh(mod_high, n, mod_0);
      __ mul(mod_low, n, mod_0);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ lsr(tmp, low, shift2);
      __ add(c_i, c_i, tmp);
      __ add(c_i, c_i, high);

      // Limb 1
      __ umulh(high, a_i, b_1);
      __ mul(low, a_i, b_1);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);
      __ umulh(mod_high, n, mod_1);
      __ mul(mod_low, n, mod_1);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ ldr(tmp, Address(c_ptr, 16));
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ add(c_i, c_i, low);
      __ str(c_i, c_ptr);
      __ add(c_i, tmp, high);

      // Limb 2 (modulus limb 2 is zero, so no n * mod term)
      __ umulh(high, a_i, b_2);
      __ mul(low, a_i, b_2);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ ldr(tmp, Address(c_ptr, 24));
      __ andr(low, low, limb_mask);
      __ add(c_i, c_i, low);
      __ str(c_i, Address(c_ptr, 8));
      __ add(c_i, tmp, high);

      // Limb 3
      __ umulh(high, a_i, b_3);
      __ mul(low, a_i, b_3);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);
      __ umulh(mod_high, n, mod_3);
      __ mul(mod_low, n, mod_3);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ ldr(tmp, Address(c_ptr, 32));
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ add(c_i, c_i, low);
      __ str(c_i, Address(c_ptr, 16));
      __ add(c_i, tmp, high);

      // Limb 4
      __ umulh(high, a_i, b_4);
      __ mul(low, a_i, b_4);
      __ lsl(high, high, shift1);
      __ lsr(tmp, low, shift2);
      __ orr(high, high, tmp);
      __ andr(low, low, limb_mask);
      __ umulh(mod_high, n, mod_4);
      __ mul(mod_low, n, mod_4);
      __ lsl(mod_high, mod_high, shift1);
      __ lsr(tmp, mod_low, shift2);
      __ orr(mod_high, mod_high, tmp);
      __ andr(mod_low, mod_low, limb_mask);
      __ add(low, low, mod_low);
      __ add(high, high, mod_high);
      __ add(c_i, c_i, low);
      __ str(c_i, Address(c_ptr, 24));
      __ str(high, Address(c_ptr, 32));
    }

    // Takes the register previously borrowed as mul_ptr (now dead)
    Register low_1 = *loop_regs++;
    Register high_1 = *loop_regs++;

    //////////////////////////////
    // a[3]: limbs 0..3 of a[3] * b come from the NEON spill area
    //////////////////////////////

    __ ldr(low_1, Address(sp));
    __ ldr(high_1, Address(sp, 16));

    __ ldr(low, Address(sp, 8));
    __ ldr(high, Address(sp, 24));
    __ ldr(a_i, __ post(a, 8));
    __ ldr(c_i, c_ptr);

    // Limb 0
    __ add(low_1, low_1, c_i);
    __ ldr(c_i, Address(c_ptr, 8));
    __ andr(n, low_1, limb_mask);
    __ umulh(mod_high, n, mod_0);
    __ mul(mod_low, n, mod_0);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low_1, low_1, mod_low);
    __ add(high_1, high_1, mod_high);
    __ lsr(tmp, low_1, shift2);
    __ add(c_i, c_i, tmp);
    __ add(c_i, c_i, high_1);

    // Limb 1
    __ ldr(low_1, Address(sp, 32));
    __ ldr(high_1, Address(sp, 48));
    __ umulh(mod_high, n, mod_1);
    __ mul(mod_low, n, mod_1);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ ldr(tmp, Address(c_ptr, 16));
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low, low, mod_low);
    __ add(high, high, mod_high);
    __ add(c_i, c_i, low);
    __ str(c_i, c_ptr);
    __ add(c_i, tmp, high);

    // Limb 2 (modulus limb 2 is zero, so no n * mod term)
    __ ldr(low, Address(sp, 40));
    __ ldr(high, Address(sp, 56));
    __ ldr(tmp, Address(c_ptr, 24));
    __ add(c_i, c_i, low_1);
    __ str(c_i, Address(c_ptr, 8));
    __ add(c_i, tmp, high_1);

    // Limb 3
    __ umulh(mod_high, n, mod_3);
    __ mul(mod_low, n, mod_3);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ ldr(tmp, Address(c_ptr, 32));
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low, low, mod_low);
    __ add(high, high, mod_high);
    __ add(c_i, c_i, low);
    __ str(c_i, Address(c_ptr, 16));
    __ add(c_i, tmp, high);

    // Limb 4 (a[3] * b[4] is scalar; also preload a[4] limb 0 from the spill)
    __ ldr(low, Address(sp, 64));
    __ ldr(high, Address(sp, 80));
    __ umulh(high_1, a_i, b_4);
    __ mul(low_1, a_i, b_4);
    __ lsl(high_1, high_1, shift1);
    __ lsr(tmp, low_1, shift2);
    __ orr(high_1, high_1, tmp);
    __ andr(low_1, low_1, limb_mask);
    __ umulh(mod_high, n, mod_4);
    __ mul(mod_low, n, mod_4);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low_1, low_1, mod_low);
    __ add(high_1, high_1, mod_high);
    __ add(c_i, c_i, low_1);
    __ str(c_i, Address(c_ptr, 24));
    __ str(high_1, Address(c_ptr, 32));

    //////////////////////////////
    // a[4]: carries are propagated into c5..c9 as they are produced
    //////////////////////////////
    __ ldr(a_i, a);
    __ ldr(c_i, c_ptr);

    // Limb 0
    __ ldr(low_1, Address(sp, 72));
    __ ldr(high_1, Address(sp, 88));

    __ add(low, low, c_i);
    __ ldr(c_i, Address(c_ptr, 8));
    __ andr(n, low, limb_mask);
    __ umulh(mod_high, n, mod_0);
    __ mul(mod_low, n, mod_0);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low, low, mod_low);
    __ add(high, high, mod_high);
    __ lsr(tmp, low, shift2);
    __ add(c_i, c_i, tmp);
    __ add(c_i, c_i, high);

    // Limb 1
    __ ldr(low, Address(sp, 96));
    __ ldr(high, Address(sp, 112));
    __ umulh(mod_high, n, mod_1);
    __ mul(mod_low, n, mod_1);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low_1, low_1, mod_low);
    __ add(high_1, high_1, mod_high);

    // b_0..b_2 are dead from here on; return them to the pool.
    loop_regs = loop_regs.remaining() + b_0 + b_1 + b_2;
    b_0 = b_1 = b_2 = noreg;
    Register c5 = *loop_regs++;

    __ add(c5, c_i, low_1);
    __ ldr(c_i, Address(c_ptr, 16));
    __ lsr(tmp, c5, shift2);
    __ add(c_i, c_i, tmp);
    __ add(c_i, c_i, high_1);

    // Limb 2
    __ ldr(low_1, Address(sp, 104));
    __ ldr(high_1, Address(sp, 120));
    Register c6 = *loop_regs++;
    __ add(c6, c_i, low);
    __ ldr(c_i, Address(c_ptr, 24));
    __ lsr(tmp, c6, shift2);
    __ add(c_i, c_i, tmp);
    __ add(c_i, c_i, high);

    // Limb 3
    __ umulh(mod_high, n, mod_3);
    __ mul(mod_low, n, mod_3);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low_1, low_1, mod_low);
    __ add(high_1, high_1, mod_high);

    Register c7 = *loop_regs++;
    __ add(c7, c_i, low_1);
    __ ldr(c_i, Address(c_ptr, 32));
    __ lsr(tmp, c7, shift2);
    __ add(c_i, c_i, tmp);
    __ add(c_i, c_i, high_1);

    // Limb 4
    __ umulh(high, a_i, b_4);
    __ mul(low, a_i, b_4);
    __ lsl(high, high, shift1);
    __ lsr(tmp, low, shift2);
    __ orr(high, high, tmp);
    __ andr(low, low, limb_mask);
    __ umulh(mod_high, n, mod_4);
    __ mul(mod_low, n, mod_4);
    __ lsl(mod_high, mod_high, shift1);
    __ lsr(tmp, mod_low, shift2);
    __ orr(mod_high, mod_high, tmp);
    __ andr(mod_low, mod_low, limb_mask);
    __ add(low, low, mod_low);
    __ add(high, high, mod_high);

    /////////////////////////////
    // Final carry propagate
    /////////////////////////////

    // c5 += d1 + dd0 + (d0 >>> BITS_PER_LIMB);
    // c6 += (c5 >>> BITS_PER_LIMB);
    // c7 += (c6 >>> BITS_PER_LIMB);
    // c8 += (c7 >>> BITS_PER_LIMB);
    // c9 += (c8 >>> BITS_PER_LIMB);

    Register c8 = *loop_regs++;
    Register c9 = *loop_regs++;
    __ add(c8, c_i, low);
    __ lsr(c9, c8, shift2);
    __ add(c9, c9, high);
    __ andr(c5, c5, limb_mask);
    __ andr(c6, c6, limb_mask);
    __ andr(c7, c7, limb_mask);
    __ andr(c8, c8, limb_mask);


    // c0 = c5 - modulus[0];
    // c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
    // c0 &= LIMB_MASK;
    // c2 = c7 + (c1 >> BITS_PER_LIMB);
    // c1 &= LIMB_MASK;
    // c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
    // c2 &= LIMB_MASK;
    // c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
    // c3 &= LIMB_MASK;

    // Push back dead registers (tmp stays live: it is used in the write back)
    loop_regs = loop_regs.remaining()
      + high + low + low_1 + mod_high + mod_low + a_i + c_i + n;
    high = low = low_1 = mod_high = mod_low = a_i = c_i = n = noreg;

    // Only seven registers are needed from this pool (c0..c4, mask, nmask);
    // allocating more would exhaust it.
    Register c0 = *loop_regs++,
      c1 = *loop_regs++,
      c2 = *loop_regs++,
      c3 = *loop_regs++,
      c4 = *loop_regs++;

    __ sub(c0, c5, mod_0);
    __ sub(c1, c6, mod_1);
    __ sub(c3, c8, mod_3);
    __ sub(c4, c9, mod_4);
    __ add(c1, c1, c0, Assembler::ASR, shift2);
    __ andr(c0, c0, limb_mask);
    __ add(c2, c7, c1, Assembler::ASR, shift2);
    __ andr(c1, c1, limb_mask);
    __ add(c3, c3, c2, Assembler::ASR, shift2);
    __ andr(c2, c2, limb_mask);
    __ add(c4, c4, c3, Assembler::ASR, shift2);
    __ andr(c3, c3, limb_mask);

    // Final write back
    // mask = c4 >> 63
    // r[0] = ((c5 & mask) | (c0 & ~mask));
    // r[1] = ((c6 & mask) | (c1 & ~mask));
    // r[2] = ((c7 & mask) | (c2 & ~mask));
    // r[3] = ((c8 & mask) | (c3 & ~mask));
    // r[4] = ((c9 & mask) | (c4 & ~mask));

    Register mask = *loop_regs++;
    Register nmask = *loop_regs++;

    __ asr(mask, c4, 63);
    __ mvn(nmask, mask);
    __ andr(c5, c5, mask);
    __ andr(tmp, c0, nmask);
    __ orr(c5, c5, tmp);
    __ andr(c6, c6, mask);
    __ andr(tmp, c1, nmask);
    __ orr(c6, c6, tmp);
    __ andr(c7, c7, mask);
    __ andr(tmp, c2, nmask);
    __ orr(c7, c7, tmp);
    __ andr(c8, c8, mask);
    __ andr(tmp, c3, nmask);
    __ orr(c8, c8, tmp);
    __ andr(c9, c9, mask);
    __ andr(tmp, c4, nmask);
    __ orr(c9, c9, tmp);

    __ str(c5, result);
    __ str(c6, Address(result, 8));
    __ str(c7, Address(result, 16));
    __ str(c8, Address(result, 24));
    __ str(c9, Address(result, 32));

    // End intrinsic call: release the 48-byte carry area and the 128-byte
    // NEON spill area, then restore r19..r28 and the frame.
    __ add(sp, sp, 176);
    __ pop(callee_saved, sp);
    __ leave();
    __ mov(r0, zr); // return 0
    __ ret(lr);

    return start;
  }

-------------

PR Review Comment: https://git.openjdk.org/jdk/pull/27946#discussion_r2834210598

Reply via email to