https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/123615
>From ffb554ea1ff638237ffc9cb9a491e5a6ad66d8f6 Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Tue, 17 Dec 2024 17:10:38 +0000 Subject: [PATCH 1/2] [AArch64] Implement NEON FP8 fused multiply-add intrinsics (non-indexed) This patch adds the following intrinsics: float16x8_t vmlalbq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t) float16x8_t vmlaltq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t) float32x4_t vmlallbbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t) float32x4_t vmlallbtq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t) float32x4_t vmlalltbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t) float32x4_t vmlallttq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t) [fixup] Update intrinsics definitions [fixup] Remove some opt passes from RUN lines --- clang/include/clang/Basic/arm_neon.td | 10 ++ clang/lib/CodeGen/CGBuiltin.cpp | 43 +++++-- clang/lib/CodeGen/CodeGenFunction.h | 4 +- .../fp8-intrinsics/acle_neon_fp8_fmla.c | 121 ++++++++++++++++++ .../acle_neon_fp8_fmla.c | 22 ++++ llvm/include/llvm/IR/IntrinsicsAArch64.td | 17 +++ .../lib/Target/AArch64/AArch64InstrFormats.td | 9 +- llvm/lib/Target/AArch64/AArch64InstrInfo.td | 14 +- llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll | 56 ++++++++ 9 files changed, 274 insertions(+), 22 deletions(-) create mode 100644 clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c create mode 100644 clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c create mode 100644 llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td index c6609f312969ee..7e7faa68c55692 100644 --- a/clang/include/clang/Basic/arm_neon.td +++ b/clang/include/clang/Basic/arm_neon.td @@ -2161,6 +2161,16 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in { def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>; } +let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in { + def VMLALB_F16_F8 : VInst<"vmlalb_f16_mf8_fpm", "(>F)(>F)..V", "Qm">; + def VMLALT_F16_F8 : VInst<"vmlalt_f16_mf8_fpm", "(>F)(>F)..V", "Qm">; + + def VMLALLBB_F32_F8 : VInst<"vmlallbb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">; + def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">; + def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">; + def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">; +} + let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in { def FAMIN : WInst<"vamin", "...", "fhQdQfQh">; def FAMAX : WInst<"vamax", "...", "fhQdQfQh">; diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index b4b26eb84d5f92..8dbc8bfff95a4b 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -6759,11 +6759,14 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops, return Builder.CreateCall(F, Ops, name); } -Value *CodeGenFunction::EmitFP8NeonCall(Function *F, +Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID, + ArrayRef<llvm::Type *> Tys, SmallVectorImpl<Value *> &Ops, - Value *FPM, const char *name) { + const CallExpr *E, const char *name) { + llvm::Value *FPM = + EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E); Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM); - return EmitNeonCall(F, Ops, name); + return 
EmitNeonCall(CGM.getIntrinsic(IID, Tys), Ops, name); } llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall( @@ -6779,9 +6782,7 @@ llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall( Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2], Builder.getInt64(0)); } - llvm::Value *FPM = - EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E); - return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name); + return EmitFP8NeonCall(IID, Tys, Ops, E, name); } Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty, @@ -6802,9 +6803,7 @@ Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0, Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8); Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0)); } - llvm::Value *FPM = - EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E); - return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name); + return EmitFP8NeonCall(IID, Tys, Ops, E, name); } // Right-shift a vector by a constant. @@ -14072,6 +14071,32 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID, case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm: return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane, ExtendLane, FloatTy, Ops, E, "fdot4_lane"); + + case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm: + return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb, + {llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E, + "vmlal"); + case NEON::BI__builtin_neon_vmlaltq_f16_mf8_fpm: + return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalt, + {llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E, + "vmlal"); + case NEON::BI__builtin_neon_vmlallbbq_f32_mf8_fpm: + return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbb, + {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E, + "vmlall"); + case NEON::BI__builtin_neon_vmlallbtq_f32_mf8_fpm: + return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbt, + {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E, + "vmlall"); + case NEON::BI__builtin_neon_vmlalltbq_f32_mf8_fpm: + return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltb, + {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E, + "vmlall"); + case NEON::BI__builtin_neon_vmlallttq_f32_mf8_fpm: + return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt, + {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E, + "vmlall"); + case NEON::BI__builtin_neon_vamin_f16: case NEON::BI__builtin_neon_vaminq_f16: case NEON::BI__builtin_neon_vamin_f32: diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index fd6d44b2579b92..92aee1a05764cd 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4692,9 +4692,9 @@ class CodeGenFunction : public CodeGenTypeCache { SmallVectorImpl<llvm::Value*> &O, const char *name, unsigned shift = 0, bool rightshift = false); - llvm::Value *EmitFP8NeonCall(llvm::Function *F, + llvm::Value *EmitFP8NeonCall(unsigned IID, ArrayRef<llvm::Type *> Tys, SmallVectorImpl<llvm::Value *> &O, - llvm::Value *FPM, const char *name); + const CallExpr *E, const char *name); llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0, llvm::Type *Ty1, bool Extract, SmallVectorImpl<llvm::Value *> &Ops, diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c new file mode 100644 index 00000000000000..b0b96a4b6725da --- /dev/null +++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c @@ -0,0 +1,121 @@ +// 
NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 +// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s +// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s -check-prefix CHECK-CXX + +// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -S -o /dev/null %s + +// REQUIRES: aarch64-registered-target + +#include <arm_neon.h> + +// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb( +// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]] +// +// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalb13__Float16x8_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]] +// +float16x8_t test_vmlalb(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlalbq_f16_mf8_fpm(vd, vn, vm, fpm); +} + +// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt( +// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]] +// +// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalt13__Float16x8_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]] +// +float16x8_t test_vmlalt(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlaltq_f16_mf8_fpm(vd, vn, vm, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb( +// CHECK-SAME: <4 x float> noundef 
[[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-NEXT: ret <4 x float> [[VMLALL_I]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbb13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]] +// +float32x4_t test_vmlallbb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlallbbq_f32_mf8_fpm(vd, vn, vm, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-NEXT: ret <4 x float> [[VMLALL_I]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbt13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]] +// +float32x4_t test_vmlallbt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlallbtq_f32_mf8_fpm(vd, vn, vm, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-NEXT: ret <4 x float> [[VMLALL_I]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltb13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]] +// +float32x4_t test_vmlalltb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlalltbq_f32_mf8_fpm(vd, vn, vm, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], 
<16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-NEXT: ret <4 x float> [[VMLALL_I]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltt13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]]) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]] +// +float32x4_t test_vmlalltt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlallttq_f32_mf8_fpm(vd, vn, vm, fpm); +} diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c new file mode 100644 index 00000000000000..fcdd14e583101e --- /dev/null +++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c @@ -0,0 +1,22 @@ +// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +faminmax -target-feature +fp8 -emit-llvm -verify %s -o /dev/null + +// REQUIRES: aarch64-registered-target + +#include <arm_neon.h> + +void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) { + + (void) vmlalbq_f16_mf8_fpm(a, u, u, fpm); + // expected-error@-1 {{'vmlalbq_f16_mf8_fpm' requires target feature 'fp8fma'}} + (void) vmlaltq_f16_mf8_fpm(a, u, u, fpm); + // expected-error@-1 {{'vmlaltq_f16_mf8_fpm' requires target feature 'fp8fma'}} + (void) vmlallbbq_f32_mf8_fpm(b, u, u, fpm); + // expected-error@-1 {{'vmlallbbq_f32_mf8_fpm' requires target feature 'fp8fma'}} + (void) vmlallbtq_f32_mf8_fpm(b, u, u, fpm); + // expected-error@-1 {{'vmlallbtq_f32_mf8_fpm' requires target feature 'fp8fma'}} + (void) vmlalltbq_f32_mf8_fpm(b, u, u, fpm); + // expected-error@-1 {{'vmlalltbq_f32_mf8_fpm' requires target feature 'fp8fma'}} + (void) vmlallttq_f32_mf8_fpm(b, u, u, fpm); + // expected-error@-1 {{'vmlallttq_f32_mf8_fpm' requires target feature 'fp8fma'}} +} + diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td index 395db293063f45..9244c549dc469b 100644 --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -1036,6 +1036,23 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat def int_aarch64_neon_fp8_fdot4 : AdvSIMD_FP8_DOT_Intrinsic; def int_aarch64_neon_fp8_fdot4_lane : AdvSIMD_FP8_DOT_LANE_Intrinsic; + + +// Fused multiply-add + class AdvSIMD_FP8_FMLA_Intrinsic + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + llvm_v16i8_ty, + llvm_v16i8_ty], + [IntrReadMem, IntrInaccessibleMemOnly]>; + + def int_aarch64_neon_fp8_fmlalb : AdvSIMD_FP8_FMLA_Intrinsic; + def int_aarch64_neon_fp8_fmlalt : AdvSIMD_FP8_FMLA_Intrinsic; + + def int_aarch64_neon_fp8_fmlallbb : AdvSIMD_FP8_FMLA_Intrinsic; + def int_aarch64_neon_fp8_fmlallbt : AdvSIMD_FP8_FMLA_Intrinsic; + def int_aarch64_neon_fp8_fmlalltb : AdvSIMD_FP8_FMLA_Intrinsic; + def int_aarch64_neon_fp8_fmlalltt : AdvSIMD_FP8_FMLA_Intrinsic; } def llvm_nxv1i1_ty : LLVMType<nxv1i1>; diff 
--git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index dea2af16e3184a..38ab1ea785c706 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -6519,14 +6519,15 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm, v4f32, v8f16, OpNode>; } -multiclass SIMDThreeSameVectorMLA<bit Q, string asm>{ +multiclass SIMDThreeSameVectorMLA<bit Q, string asm, SDPatternOperator op> { + def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b", - V128, v8f16, v16i8, null_frag>; + V128, v8f16, v16i8, op>; } -multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm>{ +multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm, SDPatternOperator op> { def v4f32 : BaseSIMDThreeSameVectorDot<Q, 0b0, sz, 0b1000, asm, ".4s", ".16b", - V128, v4f32, v16i8, null_frag>; + V128, v4f32, v16i8, op>; } // FP8 assembly/disassembly classes diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 364566f63bca10..ff65cd1fca9e1c 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -10353,7 +10353,7 @@ let Predicates = [HasNEON, HasFAMINMAX] in { defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", AArch64famin>; } // End let Predicates = [HasNEON, HasFAMINMAX] -let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in { +let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in { defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">; defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">; defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">; @@ -10361,12 +10361,12 @@ let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in { defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">; defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">; - defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb">; - defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt">; - defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb">; - defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt">; - defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb">; - defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">; + defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb>; + defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt>; + defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb>; + defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt", int_aarch64_neon_fp8_fmlallbt>; + defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb", int_aarch64_neon_fp8_fmlalltb>; + defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt", int_aarch64_neon_fp8_fmlalltt>; } // End let Predicates = [HasFP8FMA] let Predicates = [HasFP8DOT2] in { diff --git a/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll new file mode 100644 index 00000000000000..008069ff63761f --- /dev/null +++ b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll @@ -0,0 +1,56 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -mtriple=aarch64-linux -mattr=+neon,+fp8fma < %s | FileCheck %s + +define <8 x half> @test_fmlalb(<8 x half> %d, <16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: test_fmlalb: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalb 
v0.8h, v1.16b, v2.16b +; CHECK-NEXT: ret + %r = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> %d, <16 x i8> %a, <16 x i8> %b) + ret <8 x half> %r +} + +define <8 x half> @test_fmlalt(<8 x half> %d, <16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: test_fmlalt: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalt v0.8h, v1.16b, v2.16b +; CHECK-NEXT: ret + %r = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> %d, <16 x i8> %a, <16 x i8> %b) + ret <8 x half> %r +} + +define <4 x float> @test_fmlallbb(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: test_fmlallbb: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlallbb v0.4s, v1.16b, v2.16b +; CHECK-NEXT: ret + %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) + ret <4 x float> %r +} + +define <4 x float> @test_fmlallbt(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: test_fmlallbt: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlallbt v0.4s, v1.16b, v2.16b +; CHECK-NEXT: ret + %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) + ret <4 x float> %r +} + +define <4 x float> @test_fmlalltb(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: test_fmlalltb: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalltb v0.4s, v1.16b, v2.16b +; CHECK-NEXT: ret + %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) + ret <4 x float> %r +} + +define <4 x float> @test_fmlalltt(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: test_fmlalltt: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalltt v0.4s, v1.16b, v2.16b +; CHECK-NEXT: ret + %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) + ret <4 x float> %r +} >From f79eab53b0485c8ee8dee4e64fd1288a57a8648f Mon Sep 17 00:00:00 2001 From: Momchil Velikov <momchil.veli...@arm.com> Date: Wed, 18 Dec 2024 10:52:34 +0000 Subject: [PATCH 2/2] [AArch64] Implement NEON FP8 intrinsics for fused multiply-add (indexed) This patch adds the following intrinsics: * Floating-point multiply-add long to half-precision (vector, by element) float16x8_t vmlalbq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlalbq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) * Floating-point multiply-add long-long to single-precision (vector, by element) float32x4_t vmlallbbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t 
vmlallttq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallttq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) [fixup] Update intrinsics definitions [fixup] Regenerate tests --- clang/include/clang/Basic/arm_neon.td | 14 + clang/lib/CodeGen/CGBuiltin.cpp | 66 ++++- clang/lib/CodeGen/CodeGenFunction.h | 6 +- .../fp8-intrinsics/acle_neon_fp8_fmla.c | 244 ++++++++++++++++++ .../acle_neon_fp8_fmla.c | 29 ++- llvm/include/llvm/IR/IntrinsicsAArch64.td | 16 ++ .../lib/Target/AArch64/AArch64InstrFormats.td | 24 +- llvm/lib/Target/AArch64/AArch64InstrInfo.td | 16 +- llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll | 54 ++++ 9 files changed, 445 insertions(+), 24 deletions(-) diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td index 7e7faa68c55692..577379e4921605 100644 --- a/clang/include/clang/Basic/arm_neon.td +++ b/clang/include/clang/Basic/arm_neon.td @@ -2169,6 +2169,20 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in { def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">; def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">; def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">; + + def VMLALB_F16_F8_LANE : VInst<"vmlalb_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>; + def VMLALB_F16_F8_LANEQ : VInst<"vmlalb_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>; + def VMLALT_F16_F8_LANE : VInst<"vmlalt_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>; + def VMLALT_F16_F8_LANEQ : VInst<"vmlalt_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>; + + def VMLALLBB_F32_F8_LANE : VInst<"vmlallbb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>; + def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>; + def VMLALLBT_F32_F8_LANE : VInst<"vmlallbt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>; + def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>; + def VMLALLTB_F32_F8_LANE : VInst<"vmlalltb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>; + def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>; + def VMLALLTT_F32_F8_LANE : VInst<"vmlalltt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>; + def VMLALLTT_F32_F8_LANEQ : VInst<"vmlalltt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>; } let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in { diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 8dbc8bfff95a4b..9ae133421d8dac 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -6770,14 +6770,14 @@ Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID, } llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall( - unsigned IID, bool ExtendLane, llvm::Type *RetTy, + unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy, SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) { const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() / RetTy->getPrimitiveSizeInBits(); llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount), 
Ops[1]->getType()}; - if (ExtendLane) { + if (ExtendLaneArg) { auto *VT = llvm::FixedVectorType::get(Int8Ty, 16); Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2], Builder.getInt64(0)); @@ -6785,6 +6785,21 @@ llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall( return EmitFP8NeonCall(IID, Tys, Ops, E, name); } +llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall( + unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy, + SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) { + + if (ExtendLaneArg) { + auto *VT = llvm::FixedVectorType::get(Int8Ty, 16); + Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2], + Builder.getInt64(0)); + } + const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() / + RetTy->getPrimitiveSizeInBits(); + return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)}, + Ops, E, name); +} + Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty, bool neg) { int SV = cast<ConstantInt>(V)->getSExtValue(); @@ -12778,7 +12793,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID, unsigned Int; bool ExtractLow = false; - bool ExtendLane = false; + bool ExtendLaneArg = false; switch (BuiltinID) { default: return nullptr; case NEON::BI__builtin_neon_vbsl_v: @@ -14053,24 +14068,24 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID, Ops, E, "fdot2"); case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm: case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm: - ExtendLane = true; + ExtendLaneArg = true; LLVM_FALLTHROUGH; case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm: case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm: return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane, - ExtendLane, HalfTy, Ops, E, "fdot2_lane"); + ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane"); case NEON::BI__builtin_neon_vdot_f32_mf8_fpm: case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm: return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false, FloatTy, Ops, E, "fdot4"); case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm: case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm: - ExtendLane = true; + ExtendLaneArg = true; LLVM_FALLTHROUGH; case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm: case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm: return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane, - ExtendLane, FloatTy, Ops, E, "fdot4_lane"); + ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane"); case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm: return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb, @@ -14096,7 +14111,42 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID, return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt, {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E, "vmlall"); - + case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm: + ExtendLaneArg = true; + LLVM_FALLTHROUGH; + case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm: + return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane, + ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane"); + case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm: + ExtendLaneArg = true; + LLVM_FALLTHROUGH; + case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm: + return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane, + ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane"); + case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm: + ExtendLaneArg = true; + LLVM_FALLTHROUGH; + case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm: + return 
EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane, + ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane"); + case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm: + ExtendLaneArg = true; + LLVM_FALLTHROUGH; + case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm: + return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane, + ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane"); + case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm: + ExtendLaneArg = true; + LLVM_FALLTHROUGH; + case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm: + return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane, + ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane"); + case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm: + ExtendLaneArg = true; + LLVM_FALLTHROUGH; + case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm: + return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane, + ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane"); case NEON::BI__builtin_neon_vamin_f16: case NEON::BI__builtin_neon_vaminq_f16: case NEON::BI__builtin_neon_vamin_f32: diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 92aee1a05764cd..f70e73fdab0e39 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4699,7 +4699,11 @@ class CodeGenFunction : public CodeGenTypeCache { llvm::Type *Ty1, bool Extract, SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name); - llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane, + llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg, + llvm::Type *RetTy, + SmallVectorImpl<llvm::Value *> &Ops, + const CallExpr *E, const char *name); + llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy, SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name); diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c index b0b96a4b6725da..736538073cb391 100644 --- a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c +++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c @@ -119,3 +119,247 @@ float32x4_t test_vmlalltb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_ float32x4_t test_vmlalltt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { return vmlallttq_f32_mf8_fpm(vd, vn, vm, fpm); } + +// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb_lane( +// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 0) +// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z16test_vmlalb_lane13__Float16x8_t14__Mfloat8x16_t13__Mfloat8x8_tm( +// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> 
[[VD]] to <16 x i8> +// CHECK-CXX-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 0) +// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +float16x8_t test_vmlalb_lane(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) { + return vmlalbq_lane_f16_mf8_fpm(vd, vn, vm, 0, fpm); +} + +// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb_laneq( +// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0) +// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z17test_vmlalb_laneq13__Float16x8_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0) +// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +float16x8_t test_vmlalb_laneq(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlalbq_laneq_f16_mf8_fpm(vd, vn, vm, 0, fpm); +} + +// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt_lane( +// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 7) +// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z16test_vmlalt_lane13__Float16x8_t14__Mfloat8x16_t13__Mfloat8x8_tm( +// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-CXX-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: 
[[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 7) +// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +float16x8_t test_vmlalt_lane(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) { + return vmlaltq_lane_f16_mf8_fpm(vd, vn, vm, 7, fpm); +} + +// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt_laneq( +// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15) +// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z17test_vmlalt_laneq13__Float16x8_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8> +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half> +// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15) +// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]] +// +float16x8_t test_vmlalt_laneq(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlaltq_laneq_f16_mf8_fpm(vd, vn, vm, 15, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb_lane( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 0) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlallbb_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 0) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlallbb_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) { + return vmlallbbq_lane_f32_mf8_fpm(vd, vn, vm, 0, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb_laneq( +// CHECK-SAME: <4 x 
float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlallbb_laneq13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlallbb_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlallbbq_laneq_f32_mf8_fpm(vd, vn, vm, 0, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt_lane( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 3) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlallbt_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 3) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlallbt_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) { + return vmlallbtq_lane_f32_mf8_fpm(vd, vn, vm, 3, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt_laneq( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 3) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlallbt_laneq13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> 
@llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 3) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlallbt_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlallbtq_laneq_f32_mf8_fpm(vd, vn, vm, 3, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb_lane( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlalltb_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlalltb_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) { + return vmlalltbq_lane_f32_mf8_fpm(vd, vn, vm, 7, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb_laneq( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 7) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlalltb_laneq13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 7) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlalltb_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlalltbq_laneq_f32_mf8_fpm(vd, vn, vm, 7, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt_lane( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x 
float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlalltt_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0) +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlalltt_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) { + return vmlallttq_lane_f32_mf8_fpm(vd, vn, vm, 7, fpm); +} + +// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt_laneq( +// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15) +// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlalltt_laneq13__Float32x4_t14__Mfloat8x16_tS0_m( +// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] { +// CHECK-CXX-NEXT: [[ENTRY:.*:]] +// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]]) +// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15) +// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]] +// +float32x4_t test_vmlalltt_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) { + return vmlallttq_laneq_f32_mf8_fpm(vd, vn, vm, 15, fpm); +} diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c index fcdd14e583101e..4a507b08040fff 100644 --- a/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c +++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c @@ -5,7 +5,6 @@ #include <arm_neon.h> void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) { - (void) vmlalbq_f16_mf8_fpm(a, u, u, fpm); // expected-error@-1 {{'vmlalbq_f16_mf8_fpm' requires target feature 'fp8fma'}} (void) vmlaltq_f16_mf8_fpm(a, u, u, fpm); @@ -20,3 +19,31 @@ void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) { // expected-error@-1 {{'vmlallttq_f32_mf8_fpm' requires target feature 'fp8fma'}} } +void test_imm(float16x8_t d, float32x4_t c, mfloat8x16_t a, mfloat8x8_t b, fpm_t fpm) { +(void) vmlalbq_lane_f16_mf8_fpm(d, a, b, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}} +(void) vmlalbq_laneq_f16_mf8_fpm(d, a, a, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}} +(void) vmlaltq_lane_f16_mf8_fpm(d, a, b, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}} +(void) vmlaltq_laneq_f16_mf8_fpm(d, a, a, -1, fpm); +// expected-error@-1 {{argument 
value -1 is outside the valid range [0, 15]}} + +(void) vmlallbbq_lane_f32_mf8_fpm(c, a, b, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}} +(void) vmlallbbq_laneq_f32_mf8_fpm(c, a, a, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}} +(void) vmlallbtq_lane_f32_mf8_fpm(c, a, b, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}} +(void) vmlallbtq_laneq_f32_mf8_fpm(c, a, a, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}} +(void) vmlalltbq_lane_f32_mf8_fpm(c, a, b, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}} +(void) vmlalltbq_laneq_f32_mf8_fpm(c, a, a, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}} +(void) vmlallttq_lane_f32_mf8_fpm(c, a, b, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}} +(void) vmlallttq_laneq_f32_mf8_fpm(c, a, a, -1, fpm); +// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}} +} + diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td index 9244c549dc469b..6dfc3c8f2a3931 100644 --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -1046,6 +1046,14 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat llvm_v16i8_ty], [IntrReadMem, IntrInaccessibleMemOnly]>; + class AdvSIMD_FP8_FMLA_LANE_Intrinsic + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + llvm_v16i8_ty, + llvm_v16i8_ty, + llvm_i32_ty], + [IntrReadMem, IntrInaccessibleMemOnly, ImmArg<ArgIndex<3>>]>; + def int_aarch64_neon_fp8_fmlalb : AdvSIMD_FP8_FMLA_Intrinsic; def int_aarch64_neon_fp8_fmlalt : AdvSIMD_FP8_FMLA_Intrinsic; @@ -1053,6 +1061,14 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat def int_aarch64_neon_fp8_fmlallbt : AdvSIMD_FP8_FMLA_Intrinsic; def int_aarch64_neon_fp8_fmlalltb : AdvSIMD_FP8_FMLA_Intrinsic; def int_aarch64_neon_fp8_fmlalltt : AdvSIMD_FP8_FMLA_Intrinsic; + + def int_aarch64_neon_fp8_fmlalb_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic; + def int_aarch64_neon_fp8_fmlalt_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic; + + def int_aarch64_neon_fp8_fmlallbb_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic; + def int_aarch64_neon_fp8_fmlallbt_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic; + def int_aarch64_neon_fp8_fmlalltb_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic; + def int_aarch64_neon_fp8_fmlalltt_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic; } def llvm_nxv1i1_ty : LLVMType<nxv1i1>; diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 38ab1ea785c706..3bb5d3cb4d09de 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -9105,7 +9105,7 @@ class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc, RegisterOperand RegType, RegisterOperand RegType_lo> : BaseSIMDIndexedTied<Q, U, 0b0, sz, opc, - RegType, RegType, RegType_lo, VectorIndexB, + RegType, RegType, RegType_lo, VectorIndexB32b_timm, asm, "", dst_kind, ".16b", ".b", []> { // idx = H:L:M @@ -9114,14 +9114,24 @@ class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc, let Inst{21-19} = idx{2-0}; } -multiclass SIMDThreeSameVectorMLAIndex<bit Q, string asm> { - def v8f16 : BaseSIMDThreeSameVectorIndexB<Q, 0b0, 0b11, 0b0000, asm, ".8h", - V128, V128_0to7>; 
+multiclass SIMDThreeSameVectorMLAIndex<bit Q, string asm, SDPatternOperator op> { + let Uses = [FPMR, FPCR], mayLoad = 1 in { + def v8f16 : BaseSIMDThreeSameVectorIndexB<Q, 0b0, 0b11, 0b0000, asm, ".8h", + V128, V128_0to7>; + } + + def : Pat<(v8f16 (op (v8f16 V128:$Rd), (v16i8 V128:$Rn), (v16i8 V128_0to7:$Rm), VectorIndexB32b_timm:$Idx)), + (!cast<Instruction>(NAME # v8f16) $Rd, $Rn, $Rm, $Idx)>; } -multiclass SIMDThreeSameVectorMLALIndex<bit Q, bits<2> sz, string asm> { - def v4f32 : BaseSIMDThreeSameVectorIndexB<Q, 0b1, sz, 0b1000, asm, ".4s", - V128, V128_0to7>; +multiclass SIMDThreeSameVectorMLALIndex<bit Q, bits<2> sz, string asm, SDPatternOperator op> { + let Uses = [FPMR, FPCR], mayLoad = 1 in { + def v4f32 : BaseSIMDThreeSameVectorIndexB<Q, 0b1, sz, 0b1000, asm, ".4s", + V128, V128_0to7>; + } + + def : Pat<(v4f32 (op (v4f32 V128:$Rd), (v16i8 V128:$Rn), (v16i8 V128_0to7:$Rm), VectorIndexB32b_timm:$Idx)), + (!cast<Instruction>(NAME # v4f32) $Rd, $Rn, $Rm, $Idx)>; } //---------------------------------------------------------------------------- diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index ff65cd1fca9e1c..d112d4f10e47d9 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -10353,14 +10353,16 @@ let Predicates = [HasNEON, HasFAMINMAX] in { defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", AArch64famin>; } // End let Predicates = [HasNEON, HasFAMINMAX] -let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in { - defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">; - defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">; - defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">; - defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt">; - defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">; - defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">; +let Predicates = [HasFP8FMA] in { + defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb_lane>; + defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt_lane>; + defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb_lane>; + defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt", int_aarch64_neon_fp8_fmlallbt_lane>; + defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb", int_aarch64_neon_fp8_fmlalltb_lane>; + defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt", int_aarch64_neon_fp8_fmlalltt_lane>; +} +let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in { defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb>; defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt>; defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb>; diff --git a/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll index 008069ff63761f..60957a7c0f2f41 100644 --- a/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll +++ b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll @@ -54,3 +54,57 @@ define <4 x float> @test_fmlalltt(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) { %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) ret <4 x float> %r } + +define <8 x half> @test_fmlalb_lane(<8 x half> %vd, <16 x i8> %vn, <16 x 
i8> %vm) { +; CHECK-LABEL: test_fmlalb_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalb v0.8h, v1.16b, v2.b[0] +; CHECK-NEXT: ret + %res = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 0) + ret <8 x half> %res +} + +define <8 x half> @test_fmlalt_lane(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm) { +; CHECK-LABEL: test_fmlalt_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalt v0.8h, v1.16b, v2.b[4] +; CHECK-NEXT: ret + %res = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 4) + ret <8 x half> %res +} + +define <4 x float> @test_fmlallbb_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) { +; CHECK-LABEL: test_fmlallbb_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlallbb v0.4s, v1.16b, v2.b[7] +; CHECK-NEXT: ret + %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 7) + ret <4 x float> %res +} + +define <4 x float> @test_fmlallbt_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) { +; CHECK-LABEL: test_fmlallbt_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlallbt v0.4s, v1.16b, v2.b[10] +; CHECK-NEXT: ret + %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 10) + ret <4 x float> %res +} + +define <4 x float> @test_fmlalltb_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) { +; CHECK-LABEL: test_fmlalltb_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalltb v0.4s, v1.16b, v2.b[13] +; CHECK-NEXT: ret + %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 13) + ret <4 x float> %res +} + +define <4 x float> @test_fmlalltt_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) { +; CHECK-LABEL: test_fmlalltt_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: fmlalltt v0.4s, v1.16b, v2.b[15] +; CHECK-NEXT: ret + %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 15) + ret <4 x float> %res +}
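
For readers unfamiliar with the new API, a minimal usage sketch of the intrinsics introduced by these two patches (not part of the patch itself). It assumes a compiler with the 'fp8fma' target feature enabled and an <arm_neon.h> that declares the intrinsics as above; the helper function names are illustrative only.

#include <arm_neon.h>

// Widening multiply-add: products of the bottom (even-numbered) and top
// (odd-numbered) FP8 element pairs are accumulated into the eight
// half-precision lanes. FPMR is set from 'fpm' before each operation.
float16x8_t acc_fp8_products(float16x8_t acc, mfloat8x16_t vn,
                             mfloat8x16_t vm, fpm_t fpm) {
  acc = vmlalbq_f16_mf8_fpm(acc, vn, vm, fpm);
  acc = vmlaltq_f16_mf8_fpm(acc, vn, vm, fpm);
  return acc;
}

// Indexed variant: the FP8 element selected from 'vm' (lane 3 here) is used
// as the second multiplicand. For the 64-bit 'vm' (_lane) forms the index
// must be a constant in [0, 7].
float32x4_t acc_fp8_products_lane(float32x4_t acc, mfloat8x16_t vn,
                                  mfloat8x8_t vm, fpm_t fpm) {
  return vmlallbbq_lane_f32_mf8_fpm(acc, vn, vm, 3, fpm);
}

The _laneq forms take a 128-bit 'vm' and accept lane indices up to 15, matching the ImmCheck0_7/ImmCheck0_15 constraints in the arm_neon.td changes above.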