On Thu, 19 Feb 2026 20:01:11 GMT, Ben Perez <[email protected]> wrote:
> That would be really useful! I tinkered with it a bit but would be nice to
> see what you had in mind
Like this (a sketch; the scalar reduction for a[0]..a[2] runs while NEON computes the a[3]/a[4] partial products in parallel):
address generate_intpoly_montgomeryMult_P256() {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_intpoly_montgomeryMult_P256_id;
StubCodeMark mark(this, stub_id);
address start = __ pc();
__ enter();
// Nonzero 52-bit limbs of the P-256 modulus. Limb 2 is zero, so it is
// omitted here and only mod_0, mod_1, mod_3 and mod_4 are loaded below;
// the trailing 0 is padding.
static const int64_t modulus[] = {
0x000fffffffffffffL, 0x00000fffffffffffL,
0x0000001000000000L, 0x0000ffffffff0000L,
0L
};
int shift1 = 12; // 64 - bits per limb
int shift2 = 52; // bits per limb
// Registers that are used throughout entire routine
const Register a = c_rarg0;
const Register b = c_rarg1;
const Register result = c_rarg2;
RegSet regs = RegSet::range(r0, r28) + rfp + lr - a - b - result
- r18_tls; // r18 is a reserved platform register, keep it out of the pool
FloatRegSet floatRegs = FloatRegSet::range(v0, v31)
- FloatRegSet::range(v8, v15) // Callee-saved vectors
- FloatRegSet::range(v16, v31); // Manually-allocated vectors
auto common_regs = regs.begin();
Register limb_mask = *common_regs++,
c_ptr = *common_regs++,
mod_0 = *common_regs++,
mod_1 = *common_regs++,
mod_3 = *common_regs++,
mod_4 = *common_regs++,
b_0 = *common_regs++,
b_1 = *common_regs++,
b_2 = *common_regs++,
b_3 = *common_regs++,
b_4 = *common_regs++;
regs = common_regs.remaining();
auto common_vectors = floatRegs.begin();
FloatRegister limb_mask_vec = *common_vectors++,
b_lows = *common_vectors++,
b_highs = *common_vectors++,
a_vals = *common_vectors++;
// Push callee saved registers on to the stack
RegSet callee_saved = RegSet::range(r19, r28);
__ push(callee_saved, sp);
// Allocate stack space for the five 8-byte carry values (rounded up to
// keep sp 16-byte aligned)
__ sub(sp, sp, 48);
__ mov(c_ptr, sp);
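// c[0..4] live at c_ptr + 0..32; each outer iteration consumes them and
// rewrites the next iteration's carries in place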
// Calculate limb mask
__ mov(limb_mask, -UCONST64(1) >> (64 - shift2));
__ dup(limb_mask_vec, __ T2D, limb_mask);
// Load input arrays and modulus
{
auto r = regs.begin();
Register a_ptr = *r++, mod_ptr = *r++;
__ add(a_ptr, a, 24); // a_ptr = &a[3]
__ lea(mod_ptr, ExternalAddress((address)modulus));
__ ldr(b_0, Address(b));
__ ldr(b_1, Address(b, 8));
__ ldr(b_2, Address(b, 16));
__ ldr(b_3, Address(b, 24));
__ ldr(b_4, Address(b, 32));
__ ldr(mod_0, __ post(mod_ptr, 8));
__ ldr(mod_1, __ post(mod_ptr, 8));
__ ldr(mod_3, __ post(mod_ptr, 8));
__ ldr(mod_4, mod_ptr);
__ ld1(a_vals, __ T2D, a_ptr); // a[3], a[4]
__ ld2(b_lows, b_highs, __ T4S, b); // de-interleaved 32-bit halves of b[0..3]
}
// Regs used throughout the main "loop", which is partially unrolled here
auto loop_regs = regs.begin();
Register high = *loop_regs++,
low = *loop_regs++,
mod_high = *loop_regs++,
mod_low = *loop_regs++,
a_i = *loop_regs++,
c_i = *loop_regs++,
tmp = *loop_regs++,
n = *loop_regs++;
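// Four 4-register vector sequences over the manually-allocated v16..v31,
// holding the NEON partial products of a[3..4] with b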
VSeq<4> A(16);
VSeq<4> B(20);
VSeq<4> C(24);
VSeq<4> D(28);
/////////////////////////
// a[0] (interleaved with the NEON partial products for a[3] and a[4])
/////////////////////////
{
auto r = loop_regs;
Register mul_ptr = *r++;
__ sub(sp, sp, 128);
__ mov(mul_ptr, sp);
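// While the scalar code below reduces a[0], the NEON unit computes the
// partial products of a[3] and a[4] with b[0..3] (the b_4 products stay
// scalar); the recombined 52-bit lows and highs are spilled to the
// 128-byte scratch area at mul_ptr and consumed by the a[3] and a[4]
// sections further down.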
neon_partial_mult_64(A, b_lows, a_vals, 0);
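// Recurring scalar pattern below: split the (at most 104-bit) product
// x * y into a 52-bit low part and its aligned high part:
//   high = (umulh(x, y) << 12) | (mul(x, y) >> 52)   // == (x*y) >> 52
//   low  =  mul(x, y) & LIMB_MASK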
// Limb 0
__ ldr(a_i, __ post(a, 8));
__ umulh(high, a_i, b_0);
__ mul(low, a_i, b_0);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
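// Montgomery quotient: modulus[0] == 2^52 - 1 means p == -1 (mod 2^52),
// so -p^-1 == 1 (mod 2^52) and n is simply the low limb itself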
__ andr(n, low, limb_mask);
neon_partial_mult_64(B, b_highs, a_vals, 0);
// Limb 0 cont
__ umulh(mod_high, n, mod_0);
__ mul(mod_low, n, mod_0);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ lsr(c_i, low, shift2);
__ add(c_i, c_i, high);
neon_partial_mult_64(C, b_lows, a_vals, 1);
// Limb 1
__ umulh(high, a_i, b_1);
__ mul(low, a_i, b_1);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
neon_partial_mult_64(D, b_highs, a_vals, 1);
__ umulh(mod_high, n, mod_1);
__ mul(mod_low, n, mod_1);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, c_ptr);
__ mov(c_i, high);
vs_addv(B, __ T2D, B, C); // Store (B+C) in B
// Limb 2 (no reduction term: modulus limb 2 is zero)
__ umulh(high, a_i, b_2);
__ mul(low, a_i, b_2);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
__ add(c_i, c_i, low);
__ str(c_i, Address(c_ptr, 8));
__ mov(c_i, high);
vs_shl(D, __ T2D, D, shift1); // align D for the high-word accumulation
// Limb 3
__ umulh(high, a_i, b_3); //compute next mult to avoid waiting for result
__ mul(low, a_i, b_3);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
vs_ushr(C, __ T2D, B, 20); // Use C for ((B+C) >>> 20)
__ umulh(mod_high, n, mod_3);
__ mul(mod_low, n, mod_3);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, Address(c_ptr, 16));
__ mov(c_i, high);
vs_shl(B, __ T2D, B, 32);
// Limb 4
__ umulh(high, a_i, b_4);
__ mul(low, a_i, b_4);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
vs_addv(D, __ T2D, D, C);
__ umulh(mod_high, n, mod_4);
__ mul(mod_low, n, mod_4);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, Address(c_ptr, 24));
__ str(high, Address(c_ptr, 32));
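// Finish recombining the vector partial products: A ends up holding the
// 52-bit low words and D the aligned high words of a[3..4] * b[0..3]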
vs_ushr(C, __ T2D, A, shift2); // C now holds (A >>> 52)
vs_andr(B, B, limb_mask_vec);
vs_andr(A, A, limb_mask_vec);
vs_addv(D, __ T2D, D, C);
vs_addv(A, __ T2D, A, B);
vs_ushr(B, __ T2D, A, shift2);
vs_andr(A, A, limb_mask_vec);
vs_addv(D, __ T2D, D, B);
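// Spill to the scratch area in (lows, highs) pairs of 16 bytes:
// a[3]*b[0..1], a[3]*b[2..3], a[4]*b[0..1], a[4]*b[2..3]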
__ st1(A[0], __ T2D, __ post(mul_ptr, 16));
__ st1(D[0], __ T2D, __ post(mul_ptr, 16));
__ st1(A[1], __ T2D, __ post(mul_ptr, 16));
__ st1(D[1], __ T2D, __ post(mul_ptr, 16));
__ st1(A[2], __ T2D, __ post(mul_ptr, 16));
__ st1(D[2], __ T2D, __ post(mul_ptr, 16));
__ st1(A[3], __ T2D, __ post(mul_ptr, 16));
__ st1(D[3], __ T2D, mul_ptr);
}
/////////////////////////
// Loop 2 & 3: a[1] and a[2]
/////////////////////////
for (int i = 0; i < 2; i++) {
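// Same limb-by-limb flow as for a[0], but the incoming carries c[0..4]
// are folded in and the carry array is rewritten in place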
// Load a_i and increment by 8 bytes
__ ldr(a_i, __ post(a, 8));
__ ldr(c_i, c_ptr); // load the incoming c[0]
// Limb 0
__ umulh(high, a_i, b_0);
__ mul(low, a_i, b_0);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
__ add(low, low, c_i);
__ ldr(c_i, Address(c_ptr, 8));
__ andr(n, low, limb_mask);
__ umulh(mod_high, n, mod_0);
__ mul(mod_low, n, mod_0);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ lsr(tmp, low, shift2);
__ add(c_i, c_i, tmp);
__ add(c_i, c_i, high);
// Limb 1
__ umulh(high, a_i, b_1);
__ mul(low, a_i, b_1);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
__ umulh(mod_high, n, mod_1);
__ mul(mod_low, n, mod_1);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ ldr(tmp, Address(c_ptr, 16));
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, c_ptr);
__ add(c_i, tmp, high);
// Limb 2
__ umulh(high, a_i, b_2);
__ mul(low, a_i, b_2);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ ldr(tmp, Address(c_ptr, 24));
__ andr(low, low, limb_mask);
__ add(c_i, c_i, low);
__ str(c_i, Address(c_ptr, 8));
__ add(c_i, tmp, high);
// Limb 3
__ umulh(high, a_i, b_3);
__ mul(low, a_i, b_3);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
__ umulh(mod_high, n, mod_3);
__ mul(mod_low, n, mod_3);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ ldr(tmp, Address(c_ptr, 32));
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, Address(c_ptr, 16));
__ add(c_i, tmp, high);
// Limb 4
__ umulh(high, a_i, b_4);
__ mul(low, a_i, b_4);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
__ umulh(mod_high, n, mod_4);
__ mul(mod_low, n, mod_4);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, Address(c_ptr, 24));
__ str(high, Address(c_ptr, 32));
}
Register low_1 = *loop_regs++;
Register high_1 = *loop_regs++;
//////////////////////////////
// a[3]: lows/highs come from the NEON scratch area (only a[3]*b[4] is scalar)
//////////////////////////////
__ ldr(low_1, Address(sp)); // a[3]*b[0] low
__ ldr(high_1, Address(sp, 16)); // a[3]*b[0] high
__ ldr(low, Address(sp, 8)); // a[3]*b[1] low
__ ldr(high, Address(sp, 24)); // a[3]*b[1] high
__ ldr(a_i, __ post(a, 8));
__ ldr(c_i, c_ptr);
// Limb 0
__ add(low_1, low_1, c_i);
__ ldr(c_i, Address(c_ptr, 8));
__ andr(n, low_1, limb_mask);
__ umulh(mod_high, n, mod_0);
__ mul(mod_low, n, mod_0);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low_1, low_1, mod_low);
__ add(high_1, high_1, mod_high);
__ lsr(tmp, low_1, shift2);
__ add(c_i, c_i, tmp);
__ add(c_i, c_i, high_1);
// Limb 1 (the loads prefetch the limb 2 products)
__ ldr(low_1, Address(sp, 32)); // a[3]*b[2] low
__ ldr(high_1, Address(sp, 48)); // a[3]*b[2] high
__ umulh(mod_high, n, mod_1);
__ mul(mod_low, n, mod_1);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ ldr(tmp, Address(c_ptr, 16));
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, c_ptr);
__ add(c_i, tmp, high);
// Limb 2 (the loads prefetch the limb 3 products)
__ ldr(low, Address(sp, 40)); // a[3]*b[3] low
__ ldr(high, Address(sp, 56)); // a[3]*b[3] high
__ ldr(tmp, Address(c_ptr, 24));
__ add(c_i, c_i, low_1);
__ str(c_i, Address(c_ptr, 8));
__ add(c_i, tmp, high_1);
// Limb 3
__ umulh(mod_high, n, mod_3);
__ mul(mod_low, n, mod_3);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ ldr(tmp, Address(c_ptr, 32));
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ add(c_i, c_i, low);
__ str(c_i, Address(c_ptr, 16));
__ add(c_i, tmp, high);
// Limb 4
__ ldr(low, Address(sp, 64)); // a[4]*b[0] low (prefetch for the a[4] block)
__ ldr(high, Address(sp, 80)); // a[4]*b[0] high
__ umulh(high_1, a_i, b_4);
__ mul(low_1, a_i, b_4);
__ lsl(high_1, high_1, shift1);
__ lsr(tmp, low_1, shift2);
__ orr(high_1, high_1, tmp);
__ andr(low_1, low_1, limb_mask);
__ umulh(mod_high, n, mod_4);
__ mul(mod_low, n, mod_4);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low_1, low_1, mod_low);
__ add(high_1, high_1, mod_high);
__ add(c_i, c_i, low_1);
__ str(c_i, Address(c_ptr, 24));
__ str(high_1, Address(c_ptr, 32));
//////////////////////////////
// a[4]
//////////////////////////////
__ ldr(a_i, a); // a[4] (a has been post-incremented past a[0..3])
__ ldr(c_i, c_ptr);
// Limb 0 (the loads prefetch the limb 1 products)
__ ldr(low_1, Address(sp, 72)); // a[4]*b[1] low
__ ldr(high_1, Address(sp, 88)); // a[4]*b[1] high
__ add(low, low, c_i);
__ ldr(c_i, Address(c_ptr, 8));
__ andr(n, low, limb_mask);
__ umulh(mod_high, n, mod_0);
__ mul(mod_low, n, mod_0);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
__ lsr(tmp, low, shift2);
__ add(c_i, c_i, tmp);
__ add(c_i, c_i, high);
// Limb 1 (the loads prefetch the limb 2 products)
__ ldr(low, Address(sp, 96)); // a[4]*b[2] low
__ ldr(high, Address(sp, 112)); // a[4]*b[2] high
__ umulh(mod_high, n, mod_1);
__ mul(mod_low, n, mod_1);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low_1, low_1, mod_low);
__ add(high_1, high_1, mod_high);
// b_0..b_2 are dead from here on; return them to the pool
loop_regs = loop_regs.remaining() + b_0 + b_1 + b_2;
b_0 = b_1 = b_2 = noreg;
Register c5 = *loop_regs++;
__ add(c5, c_i, low_1);
__ ldr(c_i, Address(c_ptr, 16));
__ lsr(tmp, c5, shift2);
__ add(c_i, c_i, tmp);
__ add(c_i, c_i, high_1);
// Limb 2
__ ldr(low_1, Address(sp, 104)); // a[4]*b[3] low
__ ldr(high_1, Address(sp, 120)); // a[4]*b[3] high
Register c6 = *loop_regs++;
__ add(c6, c_i, low);
__ ldr(c_i, Address(c_ptr, 24));
__ lsr(tmp, c6, shift2);
__ add(c_i, c_i, tmp);
__ add(c_i, c_i, high);
// Limb 3
__ umulh(mod_high, n, mod_3);
__ mul(mod_low, n, mod_3);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low_1, low_1, mod_low);
__ add(high_1, high_1, mod_high);
Register c7 = *loop_regs++;
__ add(c7, c_i, low_1);
__ ldr(c_i, Address(c_ptr, 32));
__ lsr(tmp, c7, shift2);
__ add(c_i, c_i, tmp);
__ add(c_i, c_i, high_1);
// Limb 4
__ umulh(high, a_i, b_4);
__ mul(low, a_i, b_4);
__ lsl(high, high, shift1);
__ lsr(tmp, low, shift2);
__ orr(high, high, tmp);
__ andr(low, low, limb_mask);
__ umulh(mod_high, n, mod_4);
__ mul(mod_low, n, mod_4);
__ lsl(mod_high, mod_high, shift1);
__ lsr(tmp, mod_low, shift2);
__ orr(mod_high, mod_high, tmp);
__ andr(mod_low, mod_low, limb_mask);
__ add(low, low, mod_low);
__ add(high, high, mod_high);
/////////////////////////////
// Final carry propagate
/////////////////////////////
// c5 += d1 + dd0 + (d0 >>> BITS_PER_LIMB);
// c6 += (c5 >>> BITS_PER_LIMB);
// c7 += (c6 >>> BITS_PER_LIMB);
// c8 += (c7 >>> BITS_PER_LIMB);
// c9 += (c8 >>> BITS_PER_LIMB);
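// (pseudocode from the Java implementation; d0/d1/dd0 are its
// intermediate values, not registers here)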
Register c8 = *loop_regs++;
Register c9 = *loop_regs++;
__ add(c8, c_i, low);
__ lsr(c9, c8, shift2);
__ add(c9, c9, high);
__ andr(c5, c5, limb_mask);
__ andr(c6, c6, limb_mask);
__ andr(c7, c7, limb_mask);
__ andr(c8, c8, limb_mask);
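// Subtract the modulus; limb 2 of the modulus is zero, so c2 has no
// subtraction term: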
// c0 = c5 - modulus[0];
// c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
// c0 &= LIMB_MASK;
// c2 = c7 + (c1 >> BITS_PER_LIMB);
// c1 &= LIMB_MASK;
// c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
// c2 &= LIMB_MASK;
// c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
// c3 &= LIMB_MASK;
// Push back dead registers
loop_regs = loop_regs.remaining()
+ high + low + low_1 + mod_high + mod_low + a_i + c_i + n;
high = low = low_1 = mod_high = mod_low = a_i = c_i = n = noreg;
Register c0 = *loop_regs++,
c1 = *loop_regs++,
c2 = *loop_regs++,
c3 = *loop_regs++,
c4 = *loop_regs++;
__ sub(c0, c5, mod_0);
__ sub(c1, c6, mod_1);
__ sub(c3, c8, mod_3);
__ sub(c4, c9, mod_4);
__ add(c1, c1, c0, Assembler::ASR, shift2);
__ andr(c0, c0, limb_mask);
__ add(c2, c7, c1, Assembler::ASR, shift2);
__ andr(c1, c1, limb_mask);
__ add(c3, c3, c2, Assembler::ASR, shift2);
__ andr(c2, c2, limb_mask);
__ add(c4, c4, c3, Assembler::ASR, shift2);
__ andr(c3, c3, limb_mask);
// Final write back: select the reduced value unless the subtraction
// borrowed (c4 < 0), in which case keep the unreduced c5..c9
// mask = c4 >> 63
// r[0] = ((c5 & mask) | (c0 & ~mask));
// r[1] = ((c6 & mask) | (c1 & ~mask));
// r[2] = ((c7 & mask) | (c2 & ~mask));
// r[3] = ((c8 & mask) | (c3 & ~mask));
// r[4] = ((c9 & mask) | (c4 & ~mask));
Register mask = *loop_regs++;
Register nmask = *loop_regs++;
__ asr(mask, c4, 63);
__ mvn(nmask, mask);
__ andr(c5, c5, mask);
__ andr(tmp, c0, nmask);
__ orr(c5, c5, tmp);
__ andr(c6, c6, mask);
__ andr(tmp, c1, nmask);
__ orr(c6, c6, tmp);
__ andr(c7, c7, mask);
__ andr(tmp, c2, nmask);
__ orr(c7, c7, tmp);
__ andr(c8, c8, mask);
__ andr(tmp, c3, nmask);
__ orr(c8, c8, tmp);
__ andr(c9, c9, mask);
__ andr(tmp, c4, nmask);
__ orr(c9, c9, tmp);
__ str(c5, Address(result));
__ str(c6, Address(result, 8));
__ str(c7, Address(result, 16));
__ str(c8, Address(result, 24));
__ str(c9, Address(result, 32));
// End intrinsic call
__ add(sp, sp, 176); // pop the 48-byte carry area and 128-byte NEON scratch
__ pop(callee_saved, sp);
__ leave();
__ mov(r0, zr); // return 0
__ ret(lr);
return start;
}
-------------
PR Review Comment: https://git.openjdk.org/jdk/pull/27946#discussion_r2834210598