https://github.com/brendandahl created https://github.com/llvm/llvm-project/pull/93360
This reuses most of the code that was created for the f32x4 and f64x2 binary
instructions and tries to follow how they were implemented:

- add/sub/mul/div - use the regular LLVM IR instructions
- min/max - use the minimum/maximum intrinsics, and also have builtins
- pmin/pmax - use the wasm.pmin/wasm.pmax intrinsics, and also have builtins

Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md
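As a quick illustration of the intended use (not part of the patch): the C
sketch below exercises the new builtins, assuming the half-precision target
feature is enabled and an f16x8 vector typedef along the lines of the one
already used in clang/test/CodeGen/builtins-wasm.c.

// Sketch only: the typedef is an assumption matching the "V8h" builtin
// prototypes; it is not introduced by this patch.
typedef __fp16 f16x8 __attribute__((__vector_size__(16)));

// min/max lower through llvm.minimum/llvm.maximum and select
// f16x8.min / f16x8.max when half-precision is available.
f16x8 clamp_f16x8(f16x8 v, f16x8 lo, f16x8 hi) {
  return __builtin_wasm_min_f16x8(__builtin_wasm_max_f16x8(v, lo), hi);
}

// pmin/pmax lower through llvm.wasm.pmin / llvm.wasm.pmax to
// f16x8.pmin / f16x8.pmax.
f16x8 pmin_example_f16x8(f16x8 a, f16x8 b) {
  return __builtin_wasm_pmin_f16x8(a, b);
}

add/sub/mul/div need no builtins: ordinary fadd/fsub/fmul/fdiv on <8 x half>
vectors select the new instructions directly (see the half-precision.ll test
in the patch below).
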
>From c33801afebb6720bc4b51fb4064b59529c40d298 Mon Sep 17 00:00:00 2001
From: Brendan Dahl <brendan.d...@gmail.com>
Date: Thu, 23 May 2024 23:38:51 +0000
Subject: [PATCH] [WebAssembly] Implement all f16x8 binary instructions.

This reuses most of the code that was created for f32x4 and f64x2 binary
instructions and tries to follow how they were implemented.

add/sub/mul/div - use regular LL instructions
min/max - use the minimum/maximum intrinsic, and also have builtins
pmin/pmax - use the wasm.pmax/pmin intrinsics and also have builtins

Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md
---
 .../clang/Basic/BuiltinsWebAssembly.def       |  4 ++
 clang/lib/CodeGen/CGBuiltin.cpp               |  4 ++
 clang/test/CodeGen/builtins-wasm.c            | 24 +++++++
 .../WebAssembly/WebAssemblyISelLowering.cpp   |  5 ++
 .../WebAssembly/WebAssemblyInstrSIMD.td       | 37 +++++++---
 .../CodeGen/WebAssembly/half-precision.ll     | 68 +++++++++++++++++++
 llvm/test/MC/WebAssembly/simd-encodings.s     | 24 +++++++
 7 files changed, 157 insertions(+), 9 deletions(-)

diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index fd8c1b480d6da..4e48ff48b60f5 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -135,6 +135,10 @@ TARGET_BUILTIN(__builtin_wasm_min_f64x2, "V2dV2dV2d", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_max_f64x2, "V2dV2dV2d", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_pmin_f64x2, "V2dV2dV2d", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_pmax_f64x2, "V2dV2dV2d", "nc", "simd128")
+TARGET_BUILTIN(__builtin_wasm_min_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_max_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmin_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmax_f16x8, "V8hV8hV8h", "nc", "half-precision")
 
 TARGET_BUILTIN(__builtin_wasm_ceil_f32x4, "V4fV4f", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_floor_f32x4, "V4fV4f", "nc", "simd128")
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0549afa12e430..f8be7182b5267 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -20779,6 +20779,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
   }
   case WebAssembly::BI__builtin_wasm_min_f32:
   case WebAssembly::BI__builtin_wasm_min_f64:
+  case WebAssembly::BI__builtin_wasm_min_f16x8:
   case WebAssembly::BI__builtin_wasm_min_f32x4:
   case WebAssembly::BI__builtin_wasm_min_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20789,6 +20790,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
   }
   case WebAssembly::BI__builtin_wasm_max_f32:
   case WebAssembly::BI__builtin_wasm_max_f64:
+  case WebAssembly::BI__builtin_wasm_max_f16x8:
   case WebAssembly::BI__builtin_wasm_max_f32x4:
   case WebAssembly::BI__builtin_wasm_max_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20797,6 +20799,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
         CGM.getIntrinsic(Intrinsic::maximum, ConvertType(E->getType()));
     return Builder.CreateCall(Callee, {LHS, RHS});
   }
+  case WebAssembly::BI__builtin_wasm_pmin_f16x8:
   case WebAssembly::BI__builtin_wasm_pmin_f32x4:
   case WebAssembly::BI__builtin_wasm_pmin_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20805,6 +20808,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
         CGM.getIntrinsic(Intrinsic::wasm_pmin, ConvertType(E->getType()));
     return Builder.CreateCall(Callee, {LHS, RHS});
   }
+  case WebAssembly::BI__builtin_wasm_pmax_f16x8:
   case WebAssembly::BI__builtin_wasm_pmax_f32x4:
   case WebAssembly::BI__builtin_wasm_pmax_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGen/builtins-wasm.c b/clang/test/CodeGen/builtins-wasm.c
index 93a6ab06081c9..d6ee4f68700dc 100644
--- a/clang/test/CodeGen/builtins-wasm.c
+++ b/clang/test/CodeGen/builtins-wasm.c
@@ -825,6 +825,30 @@ float extract_lane_f16x8(f16x8 a, int i) {
   // WEBASSEMBLY-NEXT: ret float %0
   return __builtin_wasm_extract_lane_f16x8(a, i);
 }
+
+f16x8 min_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.minimum.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_min_f16x8(a, b);
+}
+
+f16x8 max_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.maximum.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_max_f16x8(a, b);
+}
+
+f16x8 pmin_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_pmin_f16x8(a, b);
+}
+
+f16x8 pmax_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_pmax_f16x8(a, b);
+}
 __externref_t externref_null() {
   return __builtin_wasm_ref_null_extern();
   // WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 518b6932a0c87..7cbae1bef8ef4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -142,6 +142,11 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     setTruncStoreAction(T, MVT::f16, Expand);
   }
 
+  if (Subtarget->hasHalfPrecision()) {
+    setOperationAction(ISD::FMINIMUM, MVT::v8f16, Legal);
+    setOperationAction(ISD::FMAXIMUM, MVT::v8f16, Legal);
+  }
+
   // Expand unavailable integer operations.
   for (auto Op :
        {ISD::BSWAP, ISD::SMUL_LOHI, ISD::UMUL_LOHI, ISD::MULHS, ISD::MULHU,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 558e3d859dcd8..83260fbaa700b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -16,33 +16,34 @@
 multiclass ABSTRACT_SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                            list<dag> pattern_r, string asmstr_r,
                            string asmstr_s, bits<32> simdop,
-                           Predicate simd_level> {
+                           list<Predicate> reqs> {
   defm "" : I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r, asmstr_s,
               !if(!ge(simdop, 0x100),
                   !or(0xfd0000, !and(0xffff, simdop)),
                   !or(0xfd00, !and(0xff, simdop)))>,
-            Requires<[simd_level]>;
+            Requires<reqs>;
 }
 
 multiclass SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                   list<dag> pattern_r, string asmstr_r = "",
-                  string asmstr_s = "", bits<32> simdop = -1> {
+                  string asmstr_s = "", bits<32> simdop = -1,
+                  list<Predicate> reqs = []> {
   defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
-                            asmstr_s, simdop, HasSIMD128>;
+                            asmstr_s, simdop, !listconcat([HasSIMD128], reqs)>;
 }
 
 multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                      list<dag> pattern_r, string asmstr_r = "",
                      string asmstr_s = "", bits<32> simdop = -1> {
   defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
-                            asmstr_s, simdop, HasRelaxedSIMD>;
+                            asmstr_s, simdop, [HasRelaxedSIMD]>;
 }
 
 multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                             list<dag> pattern_r, string asmstr_r = "",
                             string asmstr_s = "", bits<32> simdop = -1> {
   defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
-                            asmstr_s, simdop, HasHalfPrecision>;
+                            asmstr_s, simdop, [HasHalfPrecision]>;
 }
@@ -152,6 +153,18 @@ def F64x2 : Vec {
   let prefix = "f64x2";
 }
 
+def F16x8 : Vec {
+  let vt = v8f16;
+  let int_vt = v8i16;
+  let lane_vt = f32;
+  let lane_rc = F32;
+  let lane_bits = 16;
+  let lane_idx = LaneIdx8;
+  let lane_load = int_wasm_loadf16_f32;
+  let splat = PatFrag<(ops node:$x), (v8f16 (splat_vector (f16 $x)))>;
+  let prefix = "f16x8";
+}
+
 defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2];
 defvar IntVecs = [I8x16, I16x8, I32x4, I64x2];
@@ -781,13 +794,14 @@ def : Pat<(v2i64 (nodes[0] (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
 // Bitwise operations
 //===----------------------------------------------------------------------===//
 
-multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name, bits<32> simdop> {
+multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name,
+                      bits<32> simdop, list<Predicate> reqs = []> {
   defm _#vec : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs),
                       (outs), (ins),
                       [(set (vec.vt V128:$dst),
                             (node (vec.vt V128:$lhs), (vec.vt V128:$rhs)))],
                       vec.prefix#"."#name#"\t$dst, $lhs, $rhs",
-                      vec.prefix#"."#name, simdop>;
+                      vec.prefix#"."#name, simdop, reqs>;
 }
 
 multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
@@ -1199,6 +1213,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
 multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> {
   defm "" : SIMDBinary<F32x4, node, name, baseInst>;
   defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>;
+  defm "" : SIMDBinary<F16x8, node, name, !add(baseInst, 80), [HasHalfPrecision]>;
 }
 
 // Addition: add
@@ -1242,7 +1257,7 @@ defm PMAX : SIMDBinaryFP<pmax, "pmax", 235>;
 // Also match the pmin/pmax cases where the operands are int vectors (but the
 // comparison is still a floating point comparison). This can happen when using
 // the wasm_simd128.h intrinsics because v128_t is an integer vector.
-foreach vec = [F32x4, F64x2] in {
+foreach vec = [F32x4, F64x2, F16x8] in {
 defvar pmin = !cast<NI>("PMIN_"#vec);
 defvar pmax = !cast<NI>("PMAX_"#vec);
 def : Pat<(vec.int_vt (vselect
@@ -1266,6 +1281,10 @@ def : Pat<(v2f64 (int_wasm_pmin (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
           (PMIN_F64x2 V128:$lhs, V128:$rhs)>;
 def : Pat<(v2f64 (int_wasm_pmax (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
           (PMAX_F64x2 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmin (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+          (PMIN_F16x8 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmax (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+          (PMAX_F16x8 V128:$lhs, V128:$rhs)>;
 
 //===----------------------------------------------------------------------===//
 // Conversions
diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll
index d9d3f6be800fd..73ccea8d652db 100644
--- a/llvm/test/CodeGen/WebAssembly/half-precision.ll
+++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll
@@ -35,3 +35,71 @@ define float @extract_lane_v8f16(<8 x half> %v) {
   %r = call float @llvm.wasm.extract.lane.f16x8(<8 x half> %v, i32 1)
   ret float %r
 }
+
+; CHECK-LABEL: add_v8f16:
+; CHECK: f16x8.add $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @add_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fadd <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: sub_v8f16:
+; CHECK: f16x8.sub $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @sub_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fsub <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: mul_v8f16:
+; CHECK: f16x8.mul $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @mul_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fmul <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: div_v8f16:
+; CHECK: f16x8.div $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @div_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fdiv <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: min_intrinsic_v8f16:
+; CHECK: f16x8.min $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.minimum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @min_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+  %a = call <8 x half> @llvm.minimum.v8f16(<8 x half> %x, <8 x half> %y)
+  ret <8 x half> %a
+}
+
+; CHECK-LABEL: max_intrinsic_v8f16:
+; CHECK: f16x8.max $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.maximum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @max_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+  %a = call <8 x half> @llvm.maximum.v8f16(<8 x half> %x, <8 x half> %y)
+  ret <8 x half> %a
+}
+
+; CHECK-LABEL: pmin_intrinsic_v8f16:
+; CHECK: f16x8.pmin $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.wasm.pmin.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmin_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+  %v = call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+  ret <8 x half> %v
+}
+
+; CHECK-LABEL: pmax_intrinsic_v8f16:
+; CHECK: f16x8.pmax $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.wasm.pmax.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmax_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+  %v = call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+  ret <8 x half> %v
+}
diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s
index d397188a9882e..113a23da776fa 100644
--- a/llvm/test/MC/WebAssembly/simd-encodings.s
+++ b/llvm/test/MC/WebAssembly/simd-encodings.s
@@ -851,4 +851,28 @@ main:
     # CHECK: f16x8.extract_lane 1 # encoding: [0xfd,0xa1,0x02,0x01]
     f16x8.extract_lane 1
 
+    # CHECK: f16x8.add # encoding: [0xfd,0xb4,0x02]
+    f16x8.add
+
+    # CHECK: f16x8.sub # encoding: [0xfd,0xb5,0x02]
+    f16x8.sub
+
+    # CHECK: f16x8.mul # encoding: [0xfd,0xb6,0x02]
+    f16x8.mul
+
+    # CHECK: f16x8.div # encoding: [0xfd,0xb7,0x02]
+    f16x8.div
+
+    # CHECK: f16x8.min # encoding: [0xfd,0xb8,0x02]
+    f16x8.min
+
+    # CHECK: f16x8.max # encoding: [0xfd,0xb9,0x02]
+    f16x8.max
+
+    # CHECK: f16x8.pmin # encoding: [0xfd,0xba,0x02]
+    f16x8.pmin
+
+    # CHECK: f16x8.pmax # encoding: [0xfd,0xbb,0x02]
+    f16x8.pmax
+
     end_function

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits