On Mon, 12 May 2025 09:05:10 GMT, Ferenc Rakoczi <d...@openjdk.org> wrote:

>> By using the AVX-512 vector registers the speed of the computation of the 
>> ML-KEM algorithms (key generation, encapsulation, decapsulation) can be 
>> approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with one additional 
> commit since the last revision:
> 
>   Restoring copyright notice on ML_KEM.java

I have only reviewed three intrinsics so far; more review to do.

src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 693:

> 691: // a (short[256]) = c_rarg1
> 692: // b (short[256]) = c_rarg2
> 693: // kyberConsts (short[40]) = c_rarg3

kyberConsts is not one of the arguments passed in.

src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 696:

> 694: address generate_kyberAddPoly_2_avx512(StubGenerator *stubgen,
> 695:                                        MacroAssembler *_masm) {
> 696: 

The Java code for "implKyberAddPoly(short[] result, short[] a, short[] b)" does 
a BarrettReduction, but the intrinsic code here does not. Is that intentional, 
and if so, how is the reduction handled?
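For context, a minimal scalar sketch of the Barrett reduction in question, assuming the standard Kyber reference formulation for q = 3329 (constant and rounding as in the pq-crystals reference code; this is illustrative, not the patch's own code):

```cpp
#include <cstdint>

// Barrett reduction modulo the ML-KEM prime q = 3329 (reference-style sketch).
// Returns a value congruent to a mod q with small magnitude; this is the
// reduction step the intrinsic appears to skip.
int16_t barrett_reduce(int16_t a) {
    const int16_t q = 3329;
    // v = round(2^26 / q), the precomputed Barrett constant (20159)
    const int16_t v = ((1 << 26) + q / 2) / q;
    int16_t t = (int16_t)(((int32_t)v * a + (1 << 25)) >> 26);
    return a - t * q;
}
```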

src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 742:

> 740: // b (short[256]) = c_rarg2
> 741: // c (short[256]) = c_rarg3
> 742: // kyberConsts (short[40]) = c_rarg4

kyberConsts is not one of the arguments passed in.

src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp line 799:

> 797: // parsedLength (int) = c_rarg3
> 798: address generate_kyber12To16_avx512(StubGenerator *stubgen,
> 799:                                     MacroAssembler *_masm) {

If AVX512_VBMI and AVX512_VBMI2 are available, it looks to me like the loop body 
of this algorithm could be implemented with more efficient instructions in 5 
simple steps:

Step 1:
Load condensed bytes 0-47, 48-95, 96-143, 144-191 into zmm0, zmm1, zmm2, zmm3 
respectively using masked loads.

Step 2:
Use vpermb to rearrange zmm0 so that bytes 1, 4, 7, ... are duplicated:
zmm0 before: b47, b46, ..., b0 (each b is a byte)
zmm0 after:  b47 b46 b46 b45, ..., b5 b4 b4 b3 b2 b1 b1 b0
Repeat this for zmm1, zmm2, zmm3.

Step 3:
Use vpshldvw to shift every word (16 bits) in zmm0 appropriately with a 
variable shift:
shift word 31 by 4, word 30 by 0, ..., word 3 by 4, word 2 by 0, word 1 by 4, 
word 0 by 0.
Repeat this for zmm1, zmm2, zmm3.

Step 4:
Use vpand to "and" each word element in zmm0 with 0xfff.
Repeat this for zmm1, zmm2, zmm3.

Step 5:
Store zmm0 into parsed
Store zmm1 into parsed + 64
Store zmm2 into parsed + 128
Store zmm3 into parsed + 192

If you think there is not sufficient time, we could look into it after the 
merge of this PR as well.

-------------

PR Review: https://git.openjdk.org/jdk/pull/24953#pullrequestreview-2837616051
PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2087361991
PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2087377640
PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2087331798
PR Review Comment: https://git.openjdk.org/jdk/pull/24953#discussion_r2087834072
