On Mon, 14 Oct 2024 11:40:01 GMT, Jatin Bhateja <jbhat...@openjdk.org> wrote:

> Hi All,
> 
> This patch adds C2 compiler support for various Float16 operations added by 
> [PR#22128](https://github.com/openjdk/jdk/pull/22128)
> 
> Following is the summary of changes included with this patch:-
> 
> 1. Detection of various Float16 operations through inline expansion or 
> pattern folding idealizations.
> 2. Float16 operations like add, sub, mul, div, max, and min are inferred 
> through pattern folding idealization.
> 3. Float16 SQRT and FMA operations are inferred through inline expansion, and 
> their corresponding entry points are defined in the newly added Float16Math 
> class.
>     - These intrinsics receive unwrapped short arguments encoding IEEE 
> 754 binary16 values.
> 4. New specialized IR nodes for Float16 operations, associated idealizations, 
> and constant folding routines.
> 5. A new ideal type for constant and non-constant Float16 IR nodes. Please 
> refer to the [FAQs](https://github.com/openjdk/jdk/pull/21490#issuecomment-2482867818) for more 
> details.
> 6. Since Float16 uses short as its storage type, raw FP16 values are 
> always loaded into a general purpose register, but FP16 ISA instructions 
> generally operate on floating point registers; the compiler therefore injects 
> reinterpretation IR before and after each Float16 operation node to move the 
> short value into a floating point register and back.
> 7. New idealization routines to optimize redundant reinterpretation chains: 
> HF2S + S2HF = HF.
> 8. Auto-vectorization of the newly supported scalar operations.
> 9. X86 and AARCH64 backend implementations for all supported intrinsics.
> 10. Functional and performance validation tests.
> 
> **Missing Pieces:-**
> **-  AARCH64 Backend.**
> 
> Kindly review and share your feedback.
> 
> Best Regards,
> Jatin

Extending on John's thoughts. 
![image](https://github.com/user-attachments/assets/c795e79f-a857-4991-9b8a-c36d8525ba73)

![image](https://github.com/user-attachments/assets/264eeeea-86a0-43ed-a365-88b91e85d9cc)

There are two possibilities of a pattern match here, one rooted at node **A** 
and the other at node **B**.

With the pattern match rooted at **A**, we would need to inject an additional 
ConvHF2F after replacing AddF with AddHF to preserve the type semantics of the 
IR graph. However, the [significand bit preservation 
constraints](https://github.com/openjdk/jdk/blob/master/src/java.base/share/classes/java/lang/Float.java#L1103)
 for NaN values imposed by the Float.float16ToFloat API make the idealization 
toward the end infeasible, thereby reducing the operating vector size for the 
FP16 operation to half of what is otherwise possible, as depicted by the 
following ideal graph fragment.

![image](https://github.com/user-attachments/assets/0094e613-2c11-40db-b2bb-84ddf6b251f2)

Thus the only feasible match is the one rooted at node **B**.
 
![image](https://github.com/user-attachments/assets/22576617-9533-40e2-94f0-dd6048e295dd)


Please consider the Java side implementation of Float16.sqrt:

```java
Float16 sqrt(Float16 radicand) {
    return valueOf(Math.sqrt(radicand.doubleValue()));
}
```


Here, the radicand is first widened to a double value. Following the 2p+2 rule 
of IEEE 754, a square root computed at double precision is not subject to a 
double rounding penalty when the final result is narrowed back to a Float16 
value.
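This widening scheme can be sketched with the public half-float conversions on java.lang.Float (available since JDK 20); `sqrtFp16` below is a hypothetical helper name, not a JDK API. The float intermediate on the way back down is harmless here: float carries 24 significand bits, which meets the 2p+2 = 24 bound for binary16 (p = 11), so the final binary16 result is still correctly rounded.

```java
// Sketch (assumed helper, not a JDK API): binary16 square root via widening,
// mirroring the Float16.sqrt implementation quoted above.
public class Fp16SqrtSketch {
    static short sqrtFp16(short halfBits) {
        // binary16 -> float is exact, and float -> double is exact.
        double wide = Float.float16ToFloat(halfBits);
        // One sqrt at double precision, then round back down. The float
        // intermediate still has 2p+2 = 24 significand bits for p = 11,
        // so no harmful double rounding occurs.
        return Float.floatToFloat16((float) Math.sqrt(wide));
    }

    public static void main(String[] args) {
        short nine = Float.floatToFloat16(9.0f);
        System.out.println(Float.float16ToFloat(sqrtFp16(nine))); // 3.0
    }
}
```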

Following is the C2 IR for the above Java implementation:


```
T0 = Param0 (TypeInt::SHORT)

T1 = ConvHF2F T0
T2 = ConvF2D  T1
T3 = SqrtD    T2

T4 = ConvD2F  T3
T5 = ConvF2HF T4
```


To replace SqrtD with SqrtHF, we need the following IR modifications:


```
T0 = Param0 (TypeInt::SHORT)
// Replacing IR T1-T3 in the original fragment with the following IR T1-T6.
T1 = ReinterpretS2HF T0
T3 = SqrtHF T1
T4 = ReinterpretHF2S T3
T5 = ConvHF2F T4
T6 = ConvF2D  T5

T7 = ConvD2F  T6
T8 = ConvF2HF T7
```

  
Simplified IR after applying the Identity rules:


```
T0 = Param0 (TypeInt::SHORT)
T1 = ReinterpretS2HF T0
T3 = SqrtHF T1
T4 = ReinterpretHF2S T3
```

  
While the above transformations are valid replacements under the current 
intrinsic approach, which uses explicit entry points in the newly defined 
Float16Math helper class, they deviate from the implementation of several 
intrinsified java.lang methods that could instead be handled by pattern 
matches, e.g.
https://github.com/openjdk/jdk/blob/master/src/java.base/share/classes/java/lang/Math.java#L2022
https://github.com/openjdk/jdk/blob/master/src/java.base/share/classes/java/lang/Math.java#L2116

I think we need to carefully prefer pattern matching over intrinsification 
when the former handles more general cases.

If our intention is to capture various Float16 operation patterns in user code 
that does not directly use the Float16 API, then pattern matching looks 
appealing. However, APIs like SQRT and FMA are very carefully drafted with 
rounding effects in mind, and such patterns will be hard to find in the wild, 
so it should be fine to take the intrinsic route for them; simpler cases like 
add / sub / mul / div / max / min can be handled through a pattern matching 
approach.
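As a concrete illustration of such a capturable pattern (a sketch; `addFp16` is a hypothetical name), a half-precision add written purely against the public Float conversion methods lowers to the kind of ConvF2HF(AddF(ConvHF2F(a), ConvHF2F(b))) shape that a pattern folding idealization can rewrite to a single AddHF:

```java
// Sketch: user code performing a half-precision add without any Float16 API.
// In C2 IR this becomes ConvF2HF(AddF(ConvHF2F(a), ConvHF2F(b))), a shape
// the pattern folding idealization can recognize and fold to AddHF.
public class Fp16AddSketch {
    static short addFp16(short a, short b) {
        return Float.floatToFloat16(Float.float16ToFloat(a) + Float.float16ToFloat(b));
    }

    public static void main(String[] args) {
        short sum = addFp16(Float.floatToFloat16(1.5f), Float.floatToFloat16(2.5f));
        System.out.println(Float.float16ToFloat(sum)); // 4.0
    }
}
```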

There are also some issues around VM symbol creation for intrinsic entries 
defined in non-java.base modules, which did not surface when Float16 and 
Float16Math were part of the java.base module.

For this PR, taking a hybrid approach comprising both pattern matching and 
intrinsification looks reasonable to me.

Please let me know if you have any comments.

Some FAQs on the newly added ideal type for half-float IR nodes:-

Q. Why do we not use the existing TypeInt::SHORT instead of creating a new 
TypeH type?
A. The newly defined half-float type TypeH is special: its basic type is 
T_SHORT while its ideal register type is RegF. Thus, the C2 type system views 
its associated IR node as a 16-bit short value while the register allocator 
assigns it a floating point register.

Q. What is the problem with ConF?
A. During auto-vectorization, ConF replication constrains the operational 
vector lane count to half of what could otherwise be used for a regular 
Float16 operation, i.e. only 16 floats can be accommodated in a 512-bit 
vector, limiting the lane count of every vector in its use-def chain. One 
possible way to address this is a kludge in the auto-vectorizer that narrows 
such constants to 16 bits by analyzing their context. The newly defined 
Float16 constant node 'ConH' instead holds an inherently 16-bit encoded IEEE 
754 FP16 value and can be efficiently packed to leverage the full target 
vector width.

All Float16 IR nodes now carry the newly defined Type::HALF_FLOAT type instead 
of Type::FLOAT, so we no longer need special handling in the auto-vectorizer 
to prune their container type to short.

-------------

PR Comment: https://git.openjdk.org/jdk/pull/21490#issuecomment-2425873278
PR Comment: https://git.openjdk.org/jdk/pull/21490#issuecomment-2482867818
