On Thu, 20 Mar 2025 20:37:25 GMT, Ferenc Rakoczi <d...@openjdk.org> wrote:
>> By using the AVX-512 vector registers the speed of the computation of the
>> ML-DSA algorithms (key generation, document signing, signature verification)
>> can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with one additional
> commit since the last revision:
>
>   Fix windows build

I was going to finish the rest of the functions, but I see you pushed an update, so I had better rebase! Here are the pending comments I had; some are perhaps no longer applicable. (Still working through the NTT math.)

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 121:

> 119: static void montmulEven(int outputReg, int inputReg1, int inputReg2,
> 120:                         int scratchReg1, int scratchReg2,
> 121:                         int parCnt, MacroAssembler *_masm) {

Nitpick: this could be made to look more like `montMul64()` by also taking in an array of registers.

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 160:

> 158:   for (int i = 0; i < 4; i++) {
> 159:     __ vpmuldq(xmm(scratchRegs[i]), xmm(inputRegs1[i]), xmm(inputRegs2[i]),
> 160:                Assembler::AVX_512bit);

Using an array of registers instead of an array of ints would read somewhat more compactly and with fewer 'indirections', i.e.

    static void montMul64(XMMRegister* outputRegs, XMMRegister* inputRegs1, XMMRegister* inputRegs2, ...
      __ vpmuldq(scratchRegs[i], inputRegs1[i], inputRegs2[i], Assembler::AVX_512bit);

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 216:

> 214: // Zmm8-Zmm23 used as scratch registers
> 215: // result goes to Zmm0-Zmm7
> 216: static void montMulByConst128(MacroAssembler *_masm) {

I wish the input and output register arrays were explicit; it is easier to follow that way.

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 230:

> 228: }
> 229:
> 230: static void sub_add(int subResult[], int addResult[],

Big fan of all these helper functions! They make reading the top-level functions way easier, thanks for refactoring!

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 279:

> 277: static int xmm4_20_24[] = {4, 5, 6, 7, 20, 21, 22, 23, 24, 25, 26, 27};
> 278: static int xmm16_27[] = {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27};
> 279: static int xmm29_29[] = {29, 29, 29, 29};

I very much like the new refactor, waaaay clearer now. Some 'Could Do' comments:
- I probably would have preferred 'even more symbolic' variable names (i.e. it's ideal when you can match the Java variable names!). Conversely, if 'forced to defend this style', these names are MUCH easier to debug from GDB; it's clear what the matching instruction is.
- Not sure about it being global. It works currently, but it is less 'future proof'.

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 645:

> 643: // poly1 (int[256]) = c_rarg1
> 644: // poly2 (int[256]) = c_rarg2
> 645: static address generate_dilithiumNttMult_avx512(StubGenerator *stubgen,

This would be 'nice to have', something 'lost' with the refactor. As I was reviewing this (original) function, I was thinking, "there is nothing here _that_ specific to AVX512, mostly columnar and independent operations... This function could be made 'vector-length-independent'...":
- double the loop length:

      int iter = vector_len==Assembler::AVX_512bit?4:8;
      __ movl(len, 4); -> __ movl(len, iter);

- halve the register arrays (or keep them the same but shuffle them to make SURE the first half are in the xmm0-xmm15 range):

      XMMRegister POLY1[] = {xmm0, xmm1, xmm12, xmm13};
      XMMRegister POLY2[] = {xmm4, xmm5, xmm16, xmm17};
      XMMRegister SCRATCH1[] = {xmm2, xmm3, xmm14, xmm15}; <<< here
      XMMRegister SCRATCH2[] = {xmm6, xmm7, xmm18, xmm19}; <<< and here
      XMMRegister SCRATCH3[] = {xmm8, xmm9, xmm10, xmm11};

- a couple of other int constants (like the memory 'step' and such)
- for assembler calls like `evmovdqul` and `evpsubd`, a few small new MacroAssembler helpers are needed to generate VEX-encoded versions instead (plenty of instructions already do that).
- I think only the perm instruction was unique to EVEX (I didn't really think of an alternative for AVX2, but it can be abstracted away with another helper).

Anyway, I am not suggesting it's something you do here, but it would be convenient to leave breadcrumbs/hooks for a future update so one of us can revisit this code and add AVX2 support. E.g. the `parCnt` variable was very convenient before for exactly this; now it's gone. It probably could be derived in each function from vector_len, but... It's now cleaner, but also harder to 'upgrade'?

Why AVX2? Many of the newer (Atom/Ecore-based/EnableX86ECoreOpts) processors do not have AVX512 support, so it's something I've been prioritizing recently.

The alternative would be to write a completely separate AVX2 implementation, but it would be a shame not to 'just' reuse this code.

"For fun", I had even gone and parametrized the mult function with the `vector_len` to see how it would look (almost identical to the original version):

    static void montmulEven2(XMMRegister* outputReg, XMMRegister* inputReg1,
                             XMMRegister* inputReg2, XMMRegister* scratchReg1,
                             XMMRegister* scratchReg2, XMMRegister montQInvModR,
                             XMMRegister dilithium_q, int parCnt, int vector_len,
                             MacroAssembler* _masm) {
      for (int i = 0; i < parCnt; i++) {
        // scratch1 = (int64)input1_even*input2_even
        // Java: long a = (long) b * (long) c;
        __ vpmuldq(scratchReg1[i], inputReg1[i], inputReg2[i], vector_len);
      }
      for (int i = 0; i < parCnt; i++) {
        // scratch2 = int32(montQInvModR*(int32)scratch1)
        // Java: int aLow = (int) a;
        // Java: int m = MONT_Q_INV_MOD_R * aLow; // signed low product
        __ vpmulld(scratchReg2[i], scratchReg1[i], montQInvModR, vector_len);
      }
      for (int i = 0; i < parCnt; i++) {
        // scratch2 = (int64)scratch2_even*dilithium_q_even
        // Java: ((long)m * MONT_Q)
        __ vpmuldq(scratchReg2[i], scratchReg2[i], dilithium_q, vector_len);
      }
      for (int i = 0; i < parCnt; i++) {
        // output_odd = scratch1_odd - scratch2_odd
        // Java: (aHigh - (int) (("scratch2") >> MONT_R_BITS))
        __ vpsubd(outputReg[i], scratchReg1[i], scratchReg2[i], vector_len);
      }
    }

-------------

PR Review: https://git.openjdk.org/jdk/pull/23860#pullrequestreview-2708079853
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008809855
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008811046
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008811541
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008811704
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008808110
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008824304
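
For readers following the `// Java:` comments in `montmulEven2`, here is a minimal scalar sketch of the Montgomery multiplication those four vector steps appear to implement, one lane at a time. It is an illustration, not the code from the PR. The constant values are assumptions taken from the commonly published Dilithium parameters (q = 8380417, q^-1 mod 2^32 = 58728449) and are not quoted anywhere in this thread. `montRModQ()` is a hypothetical helper added only so the result can be sanity-checked.

```cpp
#include <cstdint>

// Assumed constants (not quoted in this thread): the ML-DSA/Dilithium modulus q
// and its inverse modulo R = 2^32.
constexpr int32_t MONT_Q = 8380417;
constexpr int32_t MONT_Q_INV_MOD_R = 58728449;
constexpr int MONT_R_BITS = 32;

// Scalar Montgomery product: returns a value congruent to b * c * R^-1 (mod q).
// Relies on arithmetic right shift of negative values (guaranteed since C++20,
// and the behavior of mainstream compilers before that).
inline int32_t montMul(int32_t b, int32_t c) {
    // Step 1 (vpmuldq): full signed 64-bit product.
    int64_t a = static_cast<int64_t>(b) * c;
    int32_t aHigh = static_cast<int32_t>(a >> MONT_R_BITS);
    uint32_t aLow = static_cast<uint32_t>(a);
    // Step 2 (vpmulld): low 32 bits of qinv * aLow. Unsigned arithmetic is used
    // so the wrap-around is well defined in C++.
    int32_t m = static_cast<int32_t>(static_cast<uint32_t>(MONT_Q_INV_MOD_R) * aLow);
    // Step 3 (vpmuldq): full signed 64-bit product m * q.
    int64_t mq = static_cast<int64_t>(m) * MONT_Q;
    // Step 4 (vpsubd): subtract the high halves; the low halves are equal by
    // construction, so no borrow is lost.
    return aHigh - static_cast<int32_t>(mq >> MONT_R_BITS);
}

// Hypothetical helper: R mod q. Multiplying by it in Montgomery form should
// return a value congruent to the other operand.
inline int32_t montRModQ() {
    return static_cast<int32_t>((int64_t{1} << MONT_R_BITS) % MONT_Q);
}
```

In the vector version, `vpmuldq` only multiplies the even 32-bit lanes and writes 64-bit results, so the "high half" of each product sits in the odd lane, which is why the last comment reads `output_odd = scratch1_odd - scratch2_odd`. Nothing in the scalar sketch depends on the vector width, which is consistent with the suggestion that `vector_len` could simply be passed through.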