On Mon, 24 Mar 2025 15:16:20 GMT, Volodymyr Paprotski <vpaprot...@openjdk.org> wrote:

>> Ferenc Rakoczi has updated the pull request incrementally with two additional commits since the last revision:
>> 
>>  - Further readability improvements.
>>  - Added asserts for array sizes
> 
> I still need to have a look at the sha3 changes, but I think I am done with the most complex part of the review. This was a really interesting bit of code to review!

@vpaprotsk , thanks a lot for the very thorough review!

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 270:
> 
>> 268: }
>> 269: 
>> 270: static void loadPerm(int destinationRegs[], Register perms,
> 
> `replXmm`? i.e. this function is replicating (any) Xmm register, not just perm?..

Since I only use it for permutation describers, I thought this name makes it easier to follow what is happening.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 327:
> 
>> 325: //
>> 326: //
>> 327: static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,
> 
> Similar comments as to `generate_dilithiumAlmostInverseNtt_avx512`
> 
> - similar comment about the 'pair-wise' operation, updating `[j]` and `[j+l]` at a time..
> - somehow had less trouble following the flow through registers here, perhaps I am getting used to it. FYI, ended renaming some as:
> 
>     // xmm16_27 = Temp1
>     // xmm0_3 = Coeffs1
>     // xmm4_7 = Coeffs2
>     // xmm8_11 = Coeffs3
>     // xmm12_15 = Coeffs4 = Temp2
>     // xmm16_27 = Scratch

For me, it was easier to follow what goes where using the xmm... names (with the symbolic names, you always have to remember which one overlaps with which and by how much).

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 421:
> 
>> 419: for (int i = 0; i < 8; i += 2) {
>> 420: __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
>> 421: }
> 
> Wish there was a more 'abstract' way to arrange this, so its obvious from the shape of the code what registers are input/outputs (i.e. and use the register arrays). Even though its just 'elementary index operations' `i/2 + 16` is still 'clever'. Couldnt think of anything myself though (same elsewhere in this function for the table permutes).

Well, this is how it is when we have three inputs, one of which also serves as the output... At least the output is always the first one (so that is the one that gets clobbered). This is why you have to replicate the permutation describer when you need both permutands later.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 509:
> 
>> 507: // coeffs (int[256]) = c_rarg0
>> 508: // zetas (int[256]) = c_rarg1
>> 509: static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,
> 
> Done with this function; Perhaps the 'permute table' is a common vector-algorithm pattern, but this is really clever!
> 
> Some general comments first, rest inline.
> 
> - The array names for registers helped a lot. And so did the new helper functions!
> - The java version of this code is quite intimidating to vectorize.. 3D loop, with geometric iteration variables.. and the literature is even more intimidating (discrete convolutions which I havent touched in two decades, ffts, ntts, etc.) Here is my attempt at a comment to 'un-scare' the next reader, though feel free to reword however you like.
> 
>     The core of the (Java) loop is this 'pair-wise' operation:
>       int a = coeffs[j];
>       int b = coeffs[j + offset];
>       coeffs[j] = (a + b);
>       coeffs[j + offset] = montMul(a - b, -MONT_ZETAS_FOR_NTT[m]);
> 
>     There are 8 'levels' (0-7); ('levels' are equivalent to (unrolling) the outer (Java) loop)
>     At each level, the 'pair-wise-offset' doubles (2^l: 1, 2, 4, 8, 16, 32, 64, 128).
> 
>     To vectorize this Java code, observe that at each level, REGARDLESS the offset, half the operations are the SUM, and the other half is the
>     montgomery MULTIPLICATION (of the pair-difference with a constant). At each level, one 'just' has to shuffle
>     the coefficients, so that SUMs and MULTIPLICATIONs line up accordingly.
> 
>     Otherwise, this pattern is 'lightly similar' to a discrete convolution (compute integral/summation of two functions at every offset)
> 
> - I still would prefer (more) symbolic register names.. I wouldn't hold my approval over it so won't object if nobody else does, but register numbers are harder to 'see' through the flow. I ended up search/replacing/'annotating' to make it easier on myself to follow the flow of data:
> 
>     // xmm8_11 = Perms1
>     // xmm12_15 = Perms2
>     // xmm16_27 = Scratch
>     // xmm0_3 = CoeffsPlus
>     // xmm4_7 = CoeffsMul
>     // xmm24_27 = CoeffsMinus (overlaps with Scratch)
> 
>   (I made a similar comment, but I think it is now hidden after the last refactor)
> - would prefer to see the helper functions to get ALL the registers passed explicitly (i.e. currently `montMulPerm`, `montQInvModR`, `dilithium_q`, `xmm29`, are implicit.). As a general rule, I've tried to set up all the registers up at the 'entry' function (`generate_dilithium*` in this case) and ...

I added some more comments, but I kept the xmm... names for the registers, just like in the ntt function.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 554:
> 
>> 552: for (int i = 0; i < 8; i += 2) {
>> 553: __ evpermi2d(xmm(i / 2 + 8), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
>> 554: __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 
> Took a bit to unscramble the flow, so a comment needed? Purpose 'fairly obvious' once I got the general shape of the level/algorithm (as per my top-level comment) but something like "shuffle xmm0-7 into xmm8-15"?

I hope the comment that I added at the beginning of the function sheds some light on the purpose of these permutations.
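To make the "three inputs, one of which also serves as the output" point concrete, here is a scalar Java model of what a single `evpermi2d` (VPERMI2D) does to the 16 dword lanes of a 512-bit register. It is an illustration under stated assumptions, not HotSpot code: the class, method, and describer names are hypothetical, and it assumes the operand order seen in the quoted calls (indices/destination first, then the two tables `xmm(i)` and `xmm(i + 1)`). The point it shows is that the index register is overwritten by the result while the two tables survive, so the describer has to be replicated before it can be applied to a second pair of registers.

```java
import java.util.Arrays;

// Scalar model of VPERMI2D on 16 dword lanes (what __ evpermi2d emits);
// class and method names are hypothetical, for illustration only.
final class Permi2dModel {
    static final int LANES = 16;   // 16 x 32-bit lanes in a 512-bit register

    // idx is both an input (the permutation describer) and the output, so it is
    // clobbered; t1 and t2 (the two tables, e.g. xmm(i) and xmm(i + 1)) are preserved.
    static void permi2d(int[] idx, int[] t1, int[] t2) {
        int[] out = new int[LANES];
        for (int i = 0; i < LANES; i++) {
            int sel = idx[i] & 0x1F;            // bits 3:0 pick the lane, bit 4 picks the table
            out[i] = (sel < LANES) ? t1[sel] : t2[sel - LANES];
        }
        System.arraycopy(out, 0, idx, 0, LANES);
    }

    // Hypothetical describer interleaving the low halves: t1[0], t2[0], t1[1], t2[1], ...
    static int[] lowInterleave() {
        int[] d = new int[LANES];
        for (int i = 0; i < LANES; i++) {
            d[i] = (i % 2 == 0) ? i / 2 : LANES + i / 2;
        }
        return d;
    }

    public static void main(String[] args) {
        int[] x0 = new int[LANES], x1 = new int[LANES];
        int[] x2 = new int[LANES], x3 = new int[LANES];
        for (int i = 0; i < LANES; i++) {
            x0[i] = i; x1[i] = 100 + i; x2[i] = 200 + i; x3[i] = 300 + i;
        }
        // Each permute consumes its describer, so the describer is replicated first
        // (the job loadPerm does in the stub) and then applied to two register pairs.
        int[] describer = lowInterleave();
        int[] r0 = describer.clone();
        int[] r1 = describer.clone();
        permi2d(r0, x0, x1);
        permi2d(r1, x2, x3);
        System.out.println(Arrays.toString(r0));
        System.out.println(Arrays.toString(r1));
    }
}
```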
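The 'pair-wise' operation described in the review of line 509, and its forward counterpart mentioned for line 327, can be sketched in scalar Java as below. This is a minimal sketch, not ML_DSA.java: it uses plain `long` modular multiplication where the real code uses Montgomery multiplication, it applies an explicit modular inverse where the quoted loop uses `-MONT_ZETAS_FOR_NTT[m]`, and the class name, helper names, and twiddle values (powers of 1753, the ML-DSA root of unity, not in the table's order) are illustrative assumptions. One level visits every pair `(j, j + offset)` exactly once, and running a forward level followed by the matching inverse level multiplies each coefficient by 2 mod q, which `main` checks.

```java
import java.math.BigInteger;

// Illustrative scalar model of the NTT butterflies; class and method names are hypothetical.
final class ButterflySketch {
    static final int Q = 8380417;   // ML-DSA modulus q = 2^23 - 2^13 + 1

    // Plain modular product in [0, q); the stubs use Montgomery multiplication instead
    static int mulMod(int a, int b) {
        return (int) Math.floorMod((long) a * b, (long) Q);
    }

    // Forward (Cooley-Tukey) butterfly on the pair (j, j + offset):
    // (a, b) -> (a + zeta * b, a - zeta * b)
    static void forward(int[] c, int j, int offset, int zeta) {
        int t = mulMod(c[j + offset], zeta);
        c[j + offset] = Math.floorMod(c[j] - t, Q);
        c[j] = Math.floorMod(c[j] + t, Q);
    }

    // Inverse (Gentleman-Sande) butterfly, shaped like the quoted Java loop body:
    // (a, b) -> (a + b, (a - b) * zetaInv)
    static void inverse(int[] c, int j, int offset, int zetaInv) {
        int a = c[j];
        int b = c[j + offset];
        c[j] = Math.floorMod(a + b, Q);
        c[j + offset] = mulMod(a - b, zetaInv);
    }

    // One level: blocks of 2 * offset coefficients, each block with its own twiddle
    static void forwardLevel(int[] c, int offset, int[] zetas) {
        for (int start = 0, k = 0; start < c.length; start += 2 * offset, k++) {
            for (int j = start; j < start + offset; j++) {
                forward(c, j, offset, zetas[k]);
            }
        }
    }

    static void inverseLevel(int[] c, int offset, int[] zetaInvs) {
        for (int start = 0, k = 0; start < c.length; start += 2 * offset, k++) {
            for (int j = start; j < start + offset; j++) {
                inverse(c, j, offset, zetaInvs[k]);
            }
        }
    }

    public static void main(String[] args) {
        int n = 256;
        int offset = 8;                      // level l = 3, pair distance 2^l
        int blocks = n / (2 * offset);
        BigInteger q = BigInteger.valueOf(Q);
        int[] zetas = new int[blocks];
        int[] zetaInvs = new int[blocks];
        for (int k = 0; k < blocks; k++) {
            // Illustrative twiddles: powers of 1753 (not the ML-DSA table order)
            zetas[k] = BigInteger.valueOf(1753).modPow(BigInteger.valueOf(k + 1), q).intValue();
            zetaInvs[k] = BigInteger.valueOf(zetas[k]).modInverse(q).intValue();
        }
        int[] c = new int[n];
        for (int i = 0; i < n; i++) {
            c[i] = (i * 7919) % Q;
        }
        int[] orig = c.clone();
        forwardLevel(c, offset, zetas);
        inverseLevel(c, offset, zetaInvs);
        // Forward then inverse at the same level doubles every coefficient mod q
        for (int i = 0; i < n; i++) {
            if (c[i] != mulMod(2, orig[i])) {
                throw new AssertionError("mismatch at index " + i);
            }
        }
        System.out.println("level round trip ok");
    }
}
```

The vectorized stubs do the same work per level, but shuffle the coefficients so that all the SUM lanes and all the MULTIPLICATION lanes line up in whole registers.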

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 656:
> 
>> 654: for (int i = 0; i < 8; i++) {
>> 655: __ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit);
>> 656: }
> 
> Fairly clean as is, but could also be two sub_add calls, I think (you have to swap order of add/sub in the helper, to be able to clobber `xmm(i)`.. or swap register usage downstream, so perhaps not.. but would be cleaner)
> 
>     sub_add(CoeffsPlus, Scratch, Perms1, CoeffsPlus, _masm);
>     sub_add(CoeffsMul, &Scratch[4], Perms2, CoeffsMul, _masm);
> 
> If nothing else, would had prefered to see the use of the register array variables

I would rather leave this alone, too. I considered the same, but decided that this is fairly easy to follow as is; it would be more complicated either to add a new helper function or to track where the symbolically named register sets overlap.

> src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 871:
> 
>> 869: __ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit);
>> 870: __ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit);
>> 871: __ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit);
> 
> Fairly 'straightforward' transcription of the java code.. no comments from me.
> 
> At first glance using `xmm0_3`, `xmm4_7`, etc. might had been a good idea, but you only save one line per 4x group. (Unless you have one big loop, but I suspect that give you worse performance? Is that something you tried already? Might be worth it otherwise..)

I considered this but decided to leave it alone (for the reason that you mentioned).

> src/java.base/share/classes/sun/security/provider/ML_DSA.java line 1418:
> 
>> 1416: int twoGamma2, int multiplier) {
>> 1417: assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
>> 1418: && (highPart.length == ML_DSA_N);
> 
> I wrote this test to test java-to-intrinsic correspondence. Might be good to include it (and add the other 4 intrinsics). This is very similar to all my other *Fuzz* tests I've been adding for my own intrinsics (and you made this test FAR easier to write by breaking out the java implementation; need to 'copy' that pattern myself)
> 
>     import java.util.Arrays;
>     import java.util.Random;
> 
>     import java.lang.invoke.MethodHandle;
>     import java.lang.invoke.MethodHandles;
>     import java.lang.reflect.Field;
>     import java.lang.reflect.Method;
>     import java.lang.reflect.Constructor;
> 
>     public class ML_DSA_Intrinsic_Test {
> 
>         public static void main(String[] args) throws Exception {
>             MethodHandles.Lookup lookup = MethodHandles.lookup();
>             Class<?> kClazz = Class.forName("sun.security.provider.ML_DSA");
>             Constructor<?> constructor = kClazz.getDeclaredConstructor(
>                     int.class);
>             constructor.setAccessible(true);
> 
>             Method m = kClazz.getDeclaredMethod("mlDsaNttMultiply",
>                     int[].class, int[].class, int[].class);
>             m.setAccessible(true);
>             MethodHandle mult = lookup.unreflect(m);
> 
>             m = kClazz.getDeclaredMethod("implDilithiumNttMultJava",
>                     int[].class, int[].class, int[].class);
>             m.setAccessible(true);
>             MethodHandle multJava = lookup.unreflect(m);
> 
>             Random rnd = new Random();
>             long seed = rnd.nextLong();
>             rnd.setSeed(seed);
>             //Note: it might be useful to increase this number during development of new intrinsics
>             final int repeat = 1000000;
>             int[] coeffs1 = new int[ML_DSA_N];
>             int[] coeffs2 = new int[ML_DSA_N];
>             int[] prod1 = new int[ML_DSA_N];
>             int[] prod2 = new int[ML_DSA_N];
>             try {
>                 for (int i = 0; i < repeat; i++) {
>                     run(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
>                 }
>                 System.out.println("Fuzz Success");
>             } catch (Throwable e) {
>                 System.out.println("Fuzz Failed: " + e);
>             }
>         }
> 
>         private static final int ML_DSA_N = 256;
>         public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
>             MethodH...

We will consider it for a follow-up PR.
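The `barrettAddend` additions in the line 871 snippet are presumably the rounding step of a reduction modulo q. As a hedged illustration only, the sketch below models the round-and-subtract reduction known from the Dilithium reference code (`reduce32`), assuming `barrettAddend` plays the role of the rounding constant `1 << 22`; whether ML_DSA.java uses exactly this constant and shift is an assumption, not something this thread confirms, and `ReduceSketch` is a hypothetical name. Because 2^23 is close to q = 8380417, subtracting round(a / 2^23) * q leaves a representative congruent to a with absolute value below q.

```java
// Hypothetical scalar model of a round-and-subtract reduction mod q; the mapping of
// ADDEND to the stub's barrettAddend is an assumption made for illustration.
final class ReduceSketch {
    static final int Q = 8380417;        // ML-DSA modulus, 2^23 - 2^13 + 1
    static final int ADDEND = 1 << 22;   // assumed rounding constant (half of 2^23)

    // Returns r with r == a (mod q) and |r| < q, valid for a <= Integer.MAX_VALUE - ADDEND
    static int reduce(int a) {
        int t = (a + ADDEND) >> 23;      // t = round(a / 2^23), arithmetic shift floors
        return a - t * Q;
    }

    public static void main(String[] args) {
        // Sweep the valid input range with a coarse step and check both properties
        for (long a = Integer.MIN_VALUE; a <= Integer.MAX_VALUE - ADDEND; a += 9973) {
            int r = reduce((int) a);
            if (Math.floorMod(r, Q) != Math.floorMod((int) a, Q) || Math.abs(r) >= Q) {
                throw new AssertionError("failed for a = " + a);
            }
        }
        System.out.println("ok");
    }
}
```

In a vector version, the add, the shift, the multiply by q, and the subtract each become one instruction per register, which matches the "4x group" shape of the quoted lines.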
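The reviewer's test follows a differential-fuzzing pattern: feed the same random inputs to the intrinsic-backed method and to the pure-Java implementation, compare the outputs, and report the seed so a failure can be reproduced. Since that test reflects into private `ML_DSA` methods and is truncated above, here is a self-contained sketch of the same pattern. `multFast` (a Montgomery-based pointwise product, standing in for the intrinsic) and `multReference` (a `BigInteger` computation of a*b*2^-32 mod q, standing in for the Java implementation) are hypothetical stand-ins, not `ML_DSA` methods. Outputs are compared modulo q, because the Montgomery result is only guaranteed to lie in (-q, q).

```java
import java.math.BigInteger;
import java.util.Random;

// Self-contained analogue of the java-vs-intrinsic fuzz pattern; both
// implementations are hypothetical stand-ins, not ML_DSA methods.
final class MontMulFuzz {
    static final int Q = 8380417;                 // ML-DSA modulus
    static final int N = 256;                     // coefficients per polynomial
    static final int QINV = inverseMod2to32(Q);   // q^-1 mod 2^32

    // Newton iteration: correct bits go 3 -> 6 -> 12 -> 24 -> 48 for odd q
    static int inverseMod2to32(int q) {
        int inv = q;
        for (int i = 0; i < 4; i++) {
            inv *= 2 - q * inv;
        }
        return inv;
    }

    // "Fast" candidate: Montgomery product, r == a * b * 2^-32 (mod q), |r| < q
    static int montMul(int a, int b) {
        long p = (long) a * b;
        int m = (int) p * QINV;                   // low 32 bits of p * q^-1
        return (int) ((p - (long) m * Q) >> 32);  // exact: low 32 bits are zero
    }

    static void multFast(int[] prod, int[] a, int[] b) {
        for (int i = 0; i < N; i++) {
            prod[i] = montMul(a[i], b[i]);
        }
    }

    // Reference: the same value computed with BigInteger, normalized to [0, q)
    static final BigInteger BQ = BigInteger.valueOf(Q);
    static final BigInteger RINV = BigInteger.ONE.shiftLeft(32).modInverse(BQ);

    static void multReference(int[] prod, int[] a, int[] b) {
        for (int i = 0; i < N; i++) {
            prod[i] = BigInteger.valueOf(a[i]).multiply(BigInteger.valueOf(b[i]))
                    .multiply(RINV).mod(BQ).intValue();
        }
    }

    // Returns -1 on success, otherwise the first iteration whose outputs disagree
    static int fuzz(long seed, int repeat) {
        Random rnd = new Random(seed);
        int[] a = new int[N], b = new int[N], p1 = new int[N], p2 = new int[N];
        for (int it = 0; it < repeat; it++) {
            for (int i = 0; i < N; i++) {
                a[i] = rnd.nextInt(2 * Q - 1) - (Q - 1);   // uniform in (-q, q)
                b[i] = rnd.nextInt(2 * Q - 1) - (Q - 1);
            }
            multFast(p1, a, b);
            multReference(p2, a, b);
            for (int i = 0; i < N; i++) {
                if (Math.floorMod(p1[i], Q) != p2[i]) {
                    return it;
                }
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        long seed = new Random().nextLong();      // printed so a failure can be replayed
        int failedAt = fuzz(seed, 1000);
        System.out.println(failedAt < 0 ? "Fuzz Success"
                : "Fuzz Failed at iteration " + failedAt + ", seed " + seed);
    }
}
```

Keeping the reference implementation as a separate method, as the PR already does for the Java fallbacks, is what makes this kind of side-by-side check cheap to write.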

-------------

PR Comment: https://git.openjdk.org/jdk/pull/23860#issuecomment-2766414076
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021150966
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021151152
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021151361
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021151680
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021152095
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021152962
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021154571
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2021156249