https://github.com/fzou1 updated https://github.com/llvm/llvm-project/pull/115625
>From b1d9799b99b45b5af2b63868c4c3b139dbf9378c Mon Sep 17 00:00:00 2001 From: Feng Zou <feng....@intel.com> Date: Sat, 26 Oct 2024 18:44:32 +0800 Subject: [PATCH 1/3] [X86][AMX] Support AMX-TF32 Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368 --- clang/docs/ReleaseNotes.rst | 1 + clang/include/clang/Basic/BuiltinsX86_64.def | 15 +- clang/include/clang/Driver/Options.td | 2 + clang/lib/Basic/Targets/X86.cpp | 6 + clang/lib/Basic/Targets/X86.h | 1 + clang/lib/Headers/CMakeLists.txt | 1 + clang/lib/Headers/amxtf32intrin.h | 194 ++++++++++++++++++ clang/lib/Headers/immintrin.h | 4 + clang/lib/Sema/SemaX86.cpp | 2 + clang/test/CodeGen/X86/amx_tf32.c | 17 ++ clang/test/CodeGen/X86/amx_tf32_api.c | 27 +++ clang/test/CodeGen/X86/amx_tf32_errors.c | 23 +++ clang/test/CodeGen/X86/amx_tf32_inline_asm.c | 18 ++ clang/test/Driver/x86-target-features.c | 7 + clang/test/Preprocessor/x86_target_features.c | 9 + llvm/include/llvm/IR/IntrinsicsX86.td | 19 ++ .../llvm/TargetParser/X86TargetParser.def | 1 + llvm/lib/Target/X86/X86.td | 3 + llvm/lib/Target/X86/X86ExpandPseudo.cpp | 11 +- llvm/lib/Target/X86/X86ISelLowering.cpp | 22 ++ llvm/lib/Target/X86/X86InstrAMX.td | 52 +++++ llvm/lib/Target/X86/X86InstrPredicates.td | 1 + llvm/lib/Target/X86/X86LowerAMXType.cpp | 20 +- llvm/lib/Target/X86/X86RegisterInfo.cpp | 4 +- llvm/lib/TargetParser/Host.cpp | 1 + llvm/lib/TargetParser/X86TargetParser.cpp | 1 + llvm/test/CodeGen/X86/amx-tf32-internal.ll | 47 +++++ llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll | 23 +++ .../Disassembler/X86/AMX/x86-64-amx-tf32.txt | 19 ++ llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s | 17 ++ llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s | 17 ++ 31 files changed, 578 insertions(+), 7 deletions(-) create mode 100644 clang/lib/Headers/amxtf32intrin.h create mode 100644 clang/test/CodeGen/X86/amx_tf32.c create mode 100644 clang/test/CodeGen/X86/amx_tf32_api.c create mode 100644 clang/test/CodeGen/X86/amx_tf32_errors.c create mode 100644 
clang/test/CodeGen/X86/amx_tf32_inline_asm.c create mode 100644 llvm/test/CodeGen/X86/amx-tf32-internal.ll create mode 100644 llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll create mode 100644 llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt create mode 100644 llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s create mode 100644 llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index c3424e0e6f34c9..e235a04f78112b 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -740,6 +740,7 @@ X86 Support - Support ISA of ``AMX-FP8``. - Support ISA of ``AMX-TRANSPOSE``. - Support ISA of ``AMX-AVX512``. +- Support ISA of ``AMX-TF32``. Arm and AArch64 Support ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def index 9f7462b1e0d962..25c10d39df32e2 100644 --- a/clang/include/clang/Basic/BuiltinsX86_64.def +++ b/clang/include/clang/Basic/BuiltinsX86_64.def @@ -139,6 +139,9 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2pbf16l_internal, "V32yUsUsV256iUi", "n", TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phh_internal, "V32xUsUsV256iUi", "n", "amx-avx512,avx10.2-512") TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl_internal, "V32xUsUsV256iUi", "n", "amx-avx512,avx10.2-512") TARGET_BUILTIN(__builtin_ia32_tilemovrow_internal, "V16iUsUsV256iUi", "n", "amx-avx512,avx10.2-512") +TARGET_BUILTIN(__builtin_ia32_tmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32") +TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32,amx-transpose") + // AMX TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile") TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile") @@ -172,10 +175,6 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phh, "V32xIUcUi", "n", "amx-avx512,avx10 TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl, "V32xIUcUi", "n", "amx-avx512,avx10.2-512") 
TARGET_BUILTIN(__builtin_ia32_tilemovrow, "V16iIUcUi", "n", "amx-avx512,avx10.2-512") -TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi") -TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd") -TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiSLLi*SLLiSLLiIi", "n", "cmpccxadd") - // AMX_FP16 FP16 TARGET_BUILTIN(__builtin_ia32_tdpfp16ps, "vIUcIUcIUc", "n", "amx-fp16") @@ -185,6 +184,14 @@ TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps, "vIUcUIcUIc", "n", "amx-fp8") TARGET_BUILTIN(__builtin_ia32_tdphbf8ps, "vIUcUIcUIc", "n", "amx-fp8") TARGET_BUILTIN(__builtin_ia32_tdphf8ps, "vIUcUIcUIc", "n", "amx-fp8") +// AMX TF32 +TARGET_BUILTIN(__builtin_ia32_tmmultf32ps, "vIUcIUcIUc", "n", "amx-tf32") +TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps, "vIUcIUcIUc", "n", "amx-tf32,amx-transpose") + +TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi") +TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd") +TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiSLLi*SLLiSLLiIi", "n", "cmpccxadd") + // RAO-INT TARGET_BUILTIN(__builtin_ia32_aadd64, "vv*SOi", "n", "raoint") TARGET_BUILTIN(__builtin_ia32_aand64, "vv*SOi", "n", "raoint") diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index 0dba5672c5a85d..1304ef3c5a228b 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -6297,6 +6297,8 @@ def mamx_int8 : Flag<["-"], "mamx-int8">, Group<m_x86_Features_Group>; def mno_amx_int8 : Flag<["-"], "mno-amx-int8">, Group<m_x86_Features_Group>; def mamx_fp8 : Flag<["-"], "mamx-fp8">, Group<m_x86_Features_Group>; def mno_amx_fp8 : Flag<["-"], "mno-amx-fp8">, Group<m_x86_Features_Group>; +def mamx_tf32 : Flag<["-"], "mamx-tf32">, Group<m_x86_Features_Group>; +def mno_amx_tf32 : Flag<["-"], "mno-amx-tf32">, Group<m_x86_Features_Group>; def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>; def mno_amx_tile : Flag<["-"], "mno-amx-tile">, 
Group<m_x86_Features_Group>; def mamx_transpose : Flag<["-"], "mamx-transpose">, Group<m_x86_Features_Group>; diff --git a/clang/lib/Basic/Targets/X86.cpp b/clang/lib/Basic/Targets/X86.cpp index 3c3dbfa13e452b..dc85e9aa77cd3d 100644 --- a/clang/lib/Basic/Targets/X86.cpp +++ b/clang/lib/Basic/Targets/X86.cpp @@ -434,6 +434,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features, HasAMXTRANSPOSE = true; } else if (Feature == "+amx-avx512") { HasAMXAVX512 = true; + } else if (Feature == "+amx-tf32") { + HasAMXTF32 = true; } else if (Feature == "+cmpccxadd") { HasCMPCCXADD = true; } else if (Feature == "+raoint") { @@ -959,6 +961,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts, Builder.defineMacro("__AMX_TRANSPOSE__"); if (HasAMXAVX512) Builder.defineMacro("__AMX_AVX512__"); + if (HasAMXTF32) + Builder.defineMacro("__AMX_TF32__"); if (HasCMPCCXADD) Builder.defineMacro("__CMPCCXADD__"); if (HasRAOINT) @@ -1090,6 +1094,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const { .Case("amx-fp16", true) .Case("amx-fp8", true) .Case("amx-int8", true) + .Case("amx-tf32", true) .Case("amx-tile", true) .Case("amx-transpose", true) .Case("avx", true) @@ -1211,6 +1216,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const { .Case("amx-fp16", HasAMXFP16) .Case("amx-fp8", HasAMXFP8) .Case("amx-int8", HasAMXINT8) + .Case("amx-tf32", HasAMXTF32) .Case("amx-tile", HasAMXTILE) .Case("amx-transpose", HasAMXTRANSPOSE) .Case("avx", SSELevel >= AVX) diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h index 70047731b17295..04b1d5d33ea231 100644 --- a/clang/lib/Basic/Targets/X86.h +++ b/clang/lib/Basic/Targets/X86.h @@ -160,6 +160,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo { bool HasAMXFP8 = false; bool HasAMXTRANSPOSE = false; bool HasAMXAVX512 = false; + bool HasAMXTF32 = false; bool HasSERIALIZE = false; bool HasTSXLDTRK = false; bool HasUSERMSR = false; diff --git 
a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt index 76366ca1f108e9..0ad9596ba9e257 100644 --- a/clang/lib/Headers/CMakeLists.txt +++ b/clang/lib/Headers/CMakeLists.txt @@ -151,6 +151,7 @@ set(x86_files amxfp16intrin.h amxfp8intrin.h amxintrin.h + amxtf32intrin.h amxtransposeintrin.h avx10_2_512bf16intrin.h avx10_2_512convertintrin.h diff --git a/clang/lib/Headers/amxtf32intrin.h b/clang/lib/Headers/amxtf32intrin.h new file mode 100644 index 00000000000000..f11b7c7499e2d5 --- /dev/null +++ b/clang/lib/Headers/amxtf32intrin.h @@ -0,0 +1,194 @@ +/*===------------- amxtf32intrin.h - AMX_TF32 intrinsics -*- C++ -*---------=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + *===------------------------------------------------------------------------=== + */ + +#ifndef __IMMINTRIN_H +#error "Never use <amxtf32intrin.h> directly; include <immintrin.h> instead." +#endif // __IMMINTRIN_H + +#ifndef __AMX_TF32INTRIN_H +#define __AMX_TF32INTRIN_H +#ifdef __x86_64__ + +#define __DEFAULT_FN_ATTRS_TF32 \ + __attribute__((__always_inline__, __nodebug__, __target__("amx-tf32"))) + +#define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE \ + __attribute__((__always_inline__, __nodebug__, \ + __target__("amx-tf32,amx-transpose"))) + +/// Do Matrix Multiplication of \a a and \a b, and then do Matrix Plus +/// with \a srcdst. +/// All the calculation is based on float32 but with the lower 13 bits set to 0. +/// +/// \headerfile <immintrin.h> +/// +/// \code +/// void _tile_mmultf32ps(constexpr int srcdst, constexpr int a, \ +/// constexpr int b); +/// \endcode +/// +/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction. +/// +/// \param srcdst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile. 
Max size is 1024 Bytes. +/// +/// \code{.operation} +/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) { +/// dword[12:0] := 0 +/// dword[31:13] := x[31:13] +/// return dword +/// } +/// +/// DEFINE silence_snan_fp32(x[31:0]) { +/// IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0) +/// x.fraction[22] := 1 +/// return x +/// } +/// +/// elements_a := a.colsb / 4 +/// elements_dest := srcdst.colsb / 4 +/// +/// FOR m = 0 TO (srcdst.rows-1) +/// tmp[511:0] := 0 +/// FOR k = 0 TO (elements_a-1) +/// FOR n = 0 TO (elements_dest-1) +/// af := silence_snan_fp32(a.row[m].fp32[k]) +/// bf := silence_snan_fp32(b.row[k].fp32[n]) +/// tmp.fp32[n] += zero_lower_mantissa_bits_fp32(af) +/// * zero_lower_mantissa_bits_fp32(bf) +/// ENDFOR +/// ENDFOR +/// +/// FOR n = 0 TO (elements_dest-1) +/// tmp.fp32[n] += srcdst.row[m].fp32[n] +/// ENDFOR +/// write_row_and_zero(srcdst, m, tmp, srcdst.colsb) +/// +/// ENDFOR +/// +/// zero_upper_rows(srcdst, srcdst.rows) +/// zero_tileconfig_start() +/// \endcode +#define _tile_mmultf32ps(srcdst, a, b) \ + __builtin_ia32_tmmultf32ps((srcdst), (a), (b)) + +/// \code +/// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a, \ +/// constexpr int b); +/// \endcode +/// +/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. +/// +/// \param srcdst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile. Max size is 1024 Bytes. 
+/// +/// \code{.operation} +/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) { +/// dword[12:0] := 0 +/// dword[31:13] := x[31:13] +/// return dword +/// } +/// +/// DEFINE silence_snan_fp32(x[31:0]) { +/// IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0) +/// x.fraction[22] := 1 +/// return x +/// } +/// +/// elements_dest := srcdst.colsb / 4 +/// +/// FOR m := 0 TO (srcdst.rows-1) +/// tmp[511:0] := 0 +/// FOR k := 0 TO (a.rows-1) +/// FOR n := 0 TO (elements_dest-1) +/// a1e := silence_snan_fp32(a.row[k].fp32[m]) +/// a2e := silence_snan_fp32(b.row[k].fp32[n]) +/// s1e := zero_lower_mantissa_bits_fp32(a1e) +/// s2e := zero_lower_mantissa_bits_fp32(a2e) +/// tmp.fp32[n] += s1e * s2e +/// ENDFOR +/// ENDFOR +/// +/// FOR n := 0 TO (elements_dest-1) +/// tmp.fp32[n] += srcdst.row[m].fp32[n] +/// ENDFOR +/// write_row_and_zero(srcdst, m, tmp, srcdst.colsb) +/// +/// ENDFOR +/// +/// zero_upper_rows(srcdst, srcdst.rows) +/// zero_tileconfig_start() +/// \endcode +#define _tile_tmmultf32ps(srcdst, a, b) \ + __builtin_ia32_ttmmultf32ps((srcdst), (a), (b)) + +static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32 +_tile_mmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024i dst, _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_tmmultf32ps_internal(m, n, k, dst, src1, src2); +} + +/// Do Matrix Multiplication of src0 and src1, and then do Matrix Plus with dst. +/// All the calculation is based on float32 but with the lower 13 bits set to 0. +/// +/// \headerfile <immintrin.h> +/// +/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. 
+__DEFAULT_FN_ATTRS_TF32 +static void __tile_mmultf32ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_mmultf32ps_internal(src0.row, src1.col, src0.col, dst->tile, + src0.tile, src1.tile); +} + +// dst = m x n (srcdest), src1 = k x m, src2 = k x n +static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE +_tile_tmmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024i dst, _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_ttmmultf32ps_internal(m, n, k, dst, src1, src2); +} + +/// Compute transpose and do Matrix Multiplication of src0 and src1, and then do +/// Matrix Plus with dst. All the calculation is based on float32 but with the +/// lower 13 bits set to 0. +/// +/// \headerfile <immintrin.h> +/// +/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. 
+__DEFAULT_FN_ATTRS_TF32_TRANSPOSE +static void __tile_tmmultf32ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_tmmultf32ps_internal(src0.row, src1.col, src0.col, + dst->tile, src0.tile, src1.tile); +} + +#endif // __x86_64__ +#endif // __AMX_TF32INTRIN_H diff --git a/clang/lib/Headers/immintrin.h b/clang/lib/Headers/immintrin.h index bc240e28d59142..5740da8136ca99 100644 --- a/clang/lib/Headers/immintrin.h +++ b/clang/lib/Headers/immintrin.h @@ -660,6 +660,10 @@ _storebe_i64(void * __P, long long __D) { #include <amxavx512intrin.h> #endif +#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_TF32__) +#include <amxtf32intrin.h> +#endif + #if !defined(__SCE__) || __has_feature(modules) || \ defined(__AVX512VP2INTERSECT__) #include <avx512vp2intersectintrin.h> diff --git a/clang/lib/Sema/SemaX86.cpp b/clang/lib/Sema/SemaX86.cpp index 1155a5edc73c34..d7c8ed351f410a 100644 --- a/clang/lib/Sema/SemaX86.cpp +++ b/clang/lib/Sema/SemaX86.cpp @@ -654,6 +654,8 @@ bool SemaX86::CheckBuiltinTileArguments(unsigned BuiltinID, CallExpr *TheCall) { case X86::BI__builtin_ia32_tdpbhf8ps: case X86::BI__builtin_ia32_tdphbf8ps: case X86::BI__builtin_ia32_tdphf8ps: + case X86::BI__builtin_ia32_tmmultf32ps: + case X86::BI__builtin_ia32_ttmmultf32ps: return CheckBuiltinTileRangeAndDuplicate(TheCall, {0, 1, 2}); case X86::BI__builtin_ia32_ttransposed: return CheckBuiltinTileArgumentsRange(TheCall, {0, 1}); diff --git a/clang/test/CodeGen/X86/amx_tf32.c b/clang/test/CodeGen/X86/amx_tf32.c new file mode 100644 index 00000000000000..661a9dfbc673b2 --- /dev/null +++ b/clang/test/CodeGen/X86/amx_tf32.c @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tile -target-feature +amx-tf32 \ +// RUN: -target-feature +amx-transpose -emit-llvm -o - -Wall -Werror -pedantic -Wno-gnu-statement-expression | FileCheck %s + +#include <immintrin.h> +#include <stddef.h> + +void test_tile_mmultf32ps(void) { + // 
CHECK-LABEL: @test_tile_mmultf32ps( + // CHECK: call void @llvm.x86.tmmultf32ps(i8 1, i8 2, i8 3) + _tile_mmultf32ps(1, 2, 3); +} + +void test_tile_tmmultf32ps(void) { + // CHECK-LABEL: @test_tile_tmmultf32ps( + // CHECK: call void @llvm.x86.ttmmultf32ps(i8 1, i8 2, i8 3) + _tile_tmmultf32ps(1, 2, 3); +} diff --git a/clang/test/CodeGen/X86/amx_tf32_api.c b/clang/test/CodeGen/X86/amx_tf32_api.c new file mode 100644 index 00000000000000..2ac8489e3e0baf --- /dev/null +++ b/clang/test/CodeGen/X86/amx_tf32_api.c @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown \ +// RUN: -target-feature +amx-tf32 -target-feature +amx-transpose \ +// RUN: -target-feature +amx-bf16 -target-feature +avx512f \ +// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s + +#include <immintrin.h> + +char buf[1024]; +#define STRIDE 32 + +char buf2[1024]; + +void test_tile_mmultf32ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_mmultf32ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.tmmultf32ps.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_mmultf32ps(&c, a, b); +} + +void test_tile_tmmultf32ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_tmmultf32ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.ttmmultf32ps.internal + //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_tmmultf32ps(&c, a, b); +} diff --git a/clang/test/CodeGen/X86/amx_tf32_errors.c b/clang/test/CodeGen/X86/amx_tf32_errors.c new file mode 100644 index 00000000000000..45021306921150 --- /dev/null +++ b/clang/test/CodeGen/X86/amx_tf32_errors.c @@ -0,0 +1,23 @@ +// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown \ +// RUN: -target-feature 
+amx-tf32 -target-feature +amx-transpose -verify + +#include <immintrin.h> +#include <stddef.h> + +void test_tile_mmultf32ps() { + _tile_mmultf32ps(16, 2, 3); // expected-error {{argument value 16 is outside the valid range [0, 7]}} + _tile_mmultf32ps(1, 26, 3); // expected-error {{argument value 26 is outside the valid range [0, 7]}} + _tile_mmultf32ps(1, 2, 36); // expected-error {{argument value 36 is outside the valid range [0, 7]}} + _tile_mmultf32ps(1, 1, 3); // expected-error {{tile arguments must refer to different tiles}} + _tile_mmultf32ps(1, 2, 1); // expected-error {{tile arguments must refer to different tiles}} + _tile_mmultf32ps(1, 3, 3); // expected-error {{tile arguments must refer to different tiles}} +} + +void test_tile_tmmultf32ps() { + _tile_tmmultf32ps(16, 2, 3); // expected-error {{argument value 16 is outside the valid range [0, 7]}} + _tile_tmmultf32ps(1, 26, 3); // expected-error {{argument value 26 is outside the valid range [0, 7]}} + _tile_tmmultf32ps(1, 2, 36); // expected-error {{argument value 36 is outside the valid range [0, 7]}} + _tile_tmmultf32ps(1, 1, 3); // expected-error {{tile arguments must refer to different tiles}} + _tile_tmmultf32ps(1, 2, 1); // expected-error {{tile arguments must refer to different tiles}} + _tile_tmmultf32ps(1, 2, 2); // expected-error {{tile arguments must refer to different tiles}} +} diff --git a/clang/test/CodeGen/X86/amx_tf32_inline_asm.c b/clang/test/CodeGen/X86/amx_tf32_inline_asm.c new file mode 100644 index 00000000000000..76d164737d88b6 --- /dev/null +++ b/clang/test/CodeGen/X86/amx_tf32_inline_asm.c @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tf32 -target-feature +amx-transpose -emit-llvm -o - -Wall -Werror -pedantic | FileCheck %s + +void f_tilemul(short a) +{ + //CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tmmultf32ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 
0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"() + __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t" + "tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t" + "tmmultf32ps %%tmm6, %%tmm0, %%tmm7 \n\t" + "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t" + ::: "memory", "tmm0", "tmm6", "tmm7"); + + //CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09ttmmultf32ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"() + __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t" + "tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t" + "ttmmultf32ps %%tmm6, %%tmm0, %%tmm7 \n\t" + "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t" + ::: "memory", "tmm0", "tmm6", "tmm7"); +} diff --git a/clang/test/Driver/x86-target-features.c b/clang/test/Driver/x86-target-features.c index 822c997f71744f..339f593dc760a8 100644 --- a/clang/test/Driver/x86-target-features.c +++ b/clang/test/Driver/x86-target-features.c @@ -318,6 +318,13 @@ // AMX-AVX512: "-target-feature" "+amx-avx512" // NO-AMX-AVX512: "-target-feature" "-amx-avx512" +// RUN: %clang -target x86_64-unknown-linux-gnu -mamx-tf32 %s \ +// RUN: -### -o %t.o 2>&1 | FileCheck -check-prefix=AMX-TF32 %s +// RUN: %clang -target x86_64-unknown-linux-gnu -mno-amx-tf32 %s \ +// RUN: -### -o %t.o 2>&1 | FileCheck -check-prefix=NO-AMX-TF32 %s +// AMX-TF32: "-target-feature" "+amx-tf32" +// NO-AMX-TF32: "-target-feature" "-amx-tf32" + // RUN: %clang --target=i386 -march=i386 -mhreset %s -### 2>&1 | FileCheck -check-prefix=HRESET %s // RUN: %clang --target=i386 -march=i386 -mno-hreset %s -### 2>&1 | FileCheck -check-prefix=NO-HRESET %s // HRESET: "-target-feature" "+hreset" diff --git a/clang/test/Preprocessor/x86_target_features.c b/clang/test/Preprocessor/x86_target_features.c index 8e4ddb1526626e..fa3d0038f05a93 100644 --- a/clang/test/Preprocessor/x86_target_features.c +++ 
b/clang/test/Preprocessor/x86_target_features.c @@ -570,6 +570,15 @@ // NO-AMX-AVX512-NOT: #define __AMX_AVX512__ 1 +// RUN: %clang -target x86_64-unknown-linux-gnu -march=x86-64 -mamx-tf32 -x c \ +// RUN: -E -dM -o - %s | FileCheck -check-prefix=AMX-TF32 %s +// AMX-TF32: #define __AMX_TF32__ 1 +// RUN: %clang -target x86_64-unknown-linux-gnu -march=x86-64 -mno-amx-tf32 -x c \ +// RUN: -E -dM -o - %s | FileCheck -check-prefix=NO-AMX-TF32 %s +// RUN: %clang -target x86_64-unknown-linux-gnu -march=x86-64 -mamx-tf32 -mno-amx-tile \ +// RUN: -x c -E -dM -o - %s | FileCheck -check-prefix=NO-AMX-TF32 %s +// NO-AMX-TF32-NOT: #define __AMX_TF32__ 1 + // RUN: %clang -target i386-unknown-unknown -march=atom -mavxvnni -x c -E -dM -o - %s | FileCheck -match-full-lines --check-prefix=AVXVNNI %s // AVXVNNI: #define __AVX2__ 1 diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td index 3003f9887e239c..ce519daacdb210 100644 --- a/llvm/include/llvm/IR/IntrinsicsX86.td +++ b/llvm/include/llvm/IR/IntrinsicsX86.td @@ -6101,6 +6101,25 @@ let TargetPrefix = "x86" in { Intrinsic<[llvm_v16i32_ty], [llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty, llvm_i32_ty], []>; + + def int_x86_tmmultf32ps : ClangBuiltin<"__builtin_ia32_tmmultf32ps">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], + [ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>; + def int_x86_ttmmultf32ps : ClangBuiltin<"__builtin_ia32_ttmmultf32ps">, + Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], + [ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>; + def int_x86_tmmultf32ps_internal : + ClangBuiltin<"__builtin_ia32_tmmultf32ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; + def int_x86_ttmmultf32ps_internal : + ClangBuiltin<"__builtin_ia32_ttmmultf32ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, 
llvm_x86amx_ty, + llvm_x86amx_ty], []>; } //===----------------------------------------------------------------------===// diff --git a/llvm/include/llvm/TargetParser/X86TargetParser.def b/llvm/include/llvm/TargetParser/X86TargetParser.def index 815556e374bef5..3b643563775688 100644 --- a/llvm/include/llvm/TargetParser/X86TargetParser.def +++ b/llvm/include/llvm/TargetParser/X86TargetParser.def @@ -267,6 +267,7 @@ X86_FEATURE (ZU, "zu") X86_FEATURE (AMX_FP8, "amx-fp8") X86_FEATURE (AMX_TRANSPOSE, "amx-transpose") X86_FEATURE (AMX_AVX512, "amx-avx512") +X86_FEATURE (AMX_TF32, "amx-tf32") // These features aren't really CPU features, but the frontend can set them. X86_FEATURE (RETPOLINE_EXTERNAL_THUNK, "retpoline-external-thunk") X86_FEATURE (RETPOLINE_INDIRECT_BRANCHES, "retpoline-indirect-branches") diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td index 59780ba5b99fcf..35bbffdb20942d 100644 --- a/llvm/lib/Target/X86/X86.td +++ b/llvm/lib/Target/X86/X86.td @@ -280,6 +280,9 @@ def FeatureAMXAVX512 : SubtargetFeature<"amx-avx512", "HasAMXAVX512", "true", "Support AMX-AVX512 instructions", [FeatureAMXTILE]>; +def FeatureAMXTF32 : SubtargetFeature<"amx-tf32", "HasAMXTF32", "true", + "Support AMX-TF32 instructions", + [FeatureAMXTILE]>; def FeatureCMPCCXADD : SubtargetFeature<"cmpccxadd", "HasCMPCCXADD", "true", "Support CMPCCXADD instructions">; def FeatureRAOINT : SubtargetFeature<"raoint", "HasRAOINT", "true", diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp index a6096e5032e89c..4f045d78f75fb2 100644 --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -755,7 +755,9 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTDPBF16PSV: - case X86::PTDPFP16PSV: { + case X86::PTDPFP16PSV: + case X86::PTMMULTF32PSV: + case X86::PTTMMULTF32PSV: { MI.untieRegOperand(4); for (unsigned i = 3; i > 0; --i) 
MI.removeOperand(i); @@ -769,6 +771,13 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break; case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break; + case X86::PTMMULTF32PSV: + Opc = X86::TMMULTF32PS; + break; + case X86::PTTMMULTF32PSV: + Opc = X86::TTMMULTF32PS; + break; + default: llvm_unreachable("Unexpected Opcode"); } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 253b768f34a07c..6140dffbe8ce65 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -37686,6 +37686,28 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); // The pseudo is gone now. return BB; } + case X86::PTMMULTF32PS: + case X86::PTTMMULTF32PS: { + const DebugLoc &DL = MI.getDebugLoc(); + unsigned Opc; + switch (MI.getOpcode()) { + default: + llvm_unreachable("Unexpected instruction!"); + case X86::PTMMULTF32PS: + Opc = X86::TMMULTF32PS; + break; + case X86::PTTMMULTF32PS: + Opc = X86::TTMMULTF32PS; + break; + } + MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); + MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef); + MI.eraseFromParent(); + return BB; + } } } diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td index b954c977f8c6c9..1b579c488c2f00 100644 --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -516,3 +516,55 @@ let Predicates = [HasAMXAVX512, HasAVX10_2_512, In64BitMode] in { TILE:$src3, GR32:$src4))]>; } } + +let Predicates = [HasAMXTF32, In64BitMode] in { + let SchedRW = [WriteSystem] in { + let Constraints = "$src1 = $dst" in { 
+ def TMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst), + (ins TILE:$src1, TILE:$src2, TILE:$src3), + "tmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", + []>, VEX, VVVV, T8, PD; + } + let Constraints = "$src4 = $dst" in { + def PTMMULTF32PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE:$dst, + (int_x86_tmmultf32ps_internal GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6))]>; + } + let usesCustomInserter = 1 in { + def PTMMULTF32PS : PseudoI<(outs), (ins u8imm:$src1, + u8imm:$src2, u8imm:$src3), + [(int_x86_tmmultf32ps timm:$src1, + timm:$src2, timm:$src3)]>; + } + } // SchedRW = [WriteSystem] +} // HasAMXTF32 + +let Predicates = [HasAMXTF32, HasAMXTRANSPOSE, In64BitMode] in { + let SchedRW = [WriteSystem] in { + let Constraints = "$src1 = $dst" in { + def TTMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst), + (ins TILE:$src1, TILE:$src2, TILE:$src3), + "ttmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", + []>, VEX, VVVV, T8, PS; + } + let Constraints = "$src4 = $dst" in { + def PTTMMULTF32PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE:$dst, + (int_x86_ttmmultf32ps_internal GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6))]>; + } + let usesCustomInserter = 1 in { + def PTTMMULTF32PS : PseudoI<(outs), (ins u8imm:$src1, + u8imm:$src2, u8imm:$src3), + [(int_x86_ttmmultf32ps timm:$src1, + timm:$src2, timm:$src3)]>; + } + } // SchedRW = [WriteSystem] +} // HasAMXTF32, HasAMXTRANSPOSE diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td index 2eb4e4fb941b29..a9ec5f660ff1d8 100644 --- a/llvm/lib/Target/X86/X86InstrPredicates.td +++ b/llvm/lib/Target/X86/X86InstrPredicates.td @@ -186,6 +186,7 @@ def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">; def HasAMXFP8 : 
Predicate<"Subtarget->hasAMXFP8()">; def HasAMXTRANSPOSE : Predicate<"Subtarget->hasAMXTRANSPOSE()">; def HasAMXAVX512 : Predicate<"Subtarget->hasAMXAVX512()">; +def HasAMXTF32 : Predicate<"Subtarget->hasAMXTF32()">; def HasUINTR : Predicate<"Subtarget->hasUINTR()">; def HasUSERMSR : Predicate<"Subtarget->hasUSERMSR()">; def HasCRC32 : Predicate<"Subtarget->hasCRC32()">; diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 08c065c39ee1e3..0e74cfa75e9606 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -241,7 +241,8 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, case Intrinsic::x86_tdpbusd_internal: case Intrinsic::x86_tdpbuud_internal: case Intrinsic::x86_tdpbf16ps_internal: - case Intrinsic::x86_tdpfp16ps_internal: { + case Intrinsic::x86_tdpfp16ps_internal: + case Intrinsic::x86_tmmultf32ps_internal: { switch (OpNo) { case 3: Row = II->getArgOperand(0); @@ -275,6 +276,23 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, Col = II->getArgOperand(1); break; } + case Intrinsic::x86_ttmmultf32ps_internal: { + switch (OpNo) { + case 3: + Row = II->getArgOperand(0); + Col = II->getArgOperand(1); + break; + case 4: + Row = getRowFromCol(II, II->getArgOperand(2), 4); + Col = getColFromRow(II, II->getArgOperand(0), 4); + break; + case 5: + Row = getRowFromCol(II, II->getArgOperand(2), 4); + Col = II->getArgOperand(1); + break; + } + break; + } } return std::make_pair(Row, Col); diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp index 1b2192e3891fc5..09418c9bb74d34 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -1076,7 +1076,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM, case X86::PTDPFP16PSV: case X86::PTCMMIMFP16PSV: case X86::PTCMMRLFP16PSV: - case X86::PTTRANSPOSEDV: { + case X86::PTTRANSPOSEDV: + 
case X86::PTMMULTF32PSV: + case X86::PTTMMULTF32PSV: { MachineOperand &MO1 = MI->getOperand(1); MachineOperand &MO2 = MI->getOperand(2); ShapeT Shape(&MO1, &MO2, MRI); diff --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp index a973aaaa4806e6..140e565e1686f2 100644 --- a/llvm/lib/TargetParser/Host.cpp +++ b/llvm/lib/TargetParser/Host.cpp @@ -1880,6 +1880,7 @@ const StringMap<bool> sys::getHostCPUFeatures() { !getX86CpuIDAndInfoEx(0x1e, 0x1, &EAX, &EBX, &ECX, &EDX); Features["amx-fp8"] = HasLeaf1E && ((EAX >> 4) & 1) && HasAMXSave; Features["amx-transpose"] = HasLeaf1E && ((EAX >> 5) & 1) && HasAMXSave; + Features["amx-tf32"] = HasLeaf1E && ((EAX >> 6) & 1) && HasAMXSave; Features["amx-avx512"] = HasLeaf1E && ((EAX >> 7) & 1) && HasAMXSave; bool HasLeaf24 = diff --git a/llvm/lib/TargetParser/X86TargetParser.cpp b/llvm/lib/TargetParser/X86TargetParser.cpp index eb55e6fc9134c8..6b53424833bd47 100644 --- a/llvm/lib/TargetParser/X86TargetParser.cpp +++ b/llvm/lib/TargetParser/X86TargetParser.cpp @@ -602,6 +602,7 @@ constexpr FeatureBitset ImpliedFeaturesAMX_FP8 = FeatureAMX_TILE; constexpr FeatureBitset ImpliedFeaturesAMX_TRANSPOSE = FeatureAMX_TILE; constexpr FeatureBitset ImpliedFeaturesAMX_AVX512 = FeatureAMX_TILE | FeatureAVX10_2_512; +constexpr FeatureBitset ImpliedFeaturesAMX_TF32 = FeatureAMX_TILE; constexpr FeatureBitset ImpliedFeaturesHRESET = {}; constexpr FeatureBitset ImpliedFeaturesPREFETCHI = {}; diff --git a/llvm/test/CodeGen/X86/amx-tf32-internal.ll b/llvm/test/CodeGen/X86/amx-tf32-internal.ll new file mode 100644 index 00000000000000..8094f990828bad --- /dev/null +++ b/llvm/test/CodeGen/X86/amx-tf32-internal.ll @@ -0,0 +1,47 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+avx512f, \ +; RUN: -mattr=+amx-tf32,+amx-transpose -verify-machineinstrs | FileCheck %s + +define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { +; 
CHECK-LABEL: test_amx: +; CHECK: # %bb.0: +; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0 +; CHECK-NEXT: vmovups %zmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $1, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: ldtilecfg -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, %ax +; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm0 +; CHECK-NEXT: tilezero %tmm1 +; CHECK-NEXT: tilezero %tmm2 +; CHECK-NEXT: tmmultf32ps %tmm1, %tmm0, %tmm2 +; CHECK-NEXT: ttmmultf32ps %tmm1, %tmm0, %tmm2 +; CHECK-NEXT: tilestored %tmm2, (%rdi,%rdx) +; CHECK-NEXT: tilerelease +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + + %a = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* %base, i64 %stride) + %b = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8) + %c = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8) + + %c1 = call x86_amx @llvm.x86.tmmultf32ps.internal(i16 8, i16 8, i16 8, x86_amx %c, x86_amx %a, x86_amx %b) + %c2 = call x86_amx @llvm.x86.ttmmultf32ps.internal(i16 8, i16 8, i16 8, x86_amx %c1, x86_amx %a, x86_amx %b) + + call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %c2) + ret void +} + +declare x86_amx @llvm.x86.tilezero.internal(i16, i16) +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) +declare x86_amx @llvm.x86.tileloaddt164.internal(i16, i16, i8*, i64) +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) + + +declare x86_amx @llvm.x86.tmmultf32ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.ttmmultf32ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) diff --git a/llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll b/llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll new file mode 100644 index 00000000000000..af1a7ae1029756 --- /dev/null 
+++ b/llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll @@ -0,0 +1,23 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -O0 -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-tf32,+amx-transpose -verify-machineinstrs | FileCheck %s + +define void @test_tmmultf32ps() { +; CHECK-LABEL: test_tmmultf32ps: +; CHECK: # %bb.0: +; CHECK-NEXT: tmmultf32ps %tmm3, %tmm2, %tmm1 +; CHECK-NEXT: retq + call void @llvm.x86.tmmultf32ps(i8 1, i8 2, i8 3) + ret void +} +declare void @llvm.x86.tmmultf32ps(i8 %A, i8 %B, i8 %C) + +define void @test_ttmmultf32ps() { +; CHECK-LABEL: test_ttmmultf32ps: +; CHECK: # %bb.0: +; CHECK-NEXT: ttmmultf32ps %tmm3, %tmm2, %tmm1 +; CHECK-NEXT: retq + call void @llvm.x86.ttmmultf32ps(i8 1, i8 2, i8 3) + ret void +} +declare void @llvm.x86.ttmmultf32ps(i8 %A, i8 %B, i8 %C) + diff --git a/llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt b/llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt new file mode 100644 index 00000000000000..f372c42982b1b6 --- /dev/null +++ b/llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt @@ -0,0 +1,19 @@ +# RUN: llvm-mc --disassemble %s -triple=x86_64 | FileCheck -check-prefix=ATT %s +# RUN: llvm-mc --disassemble %s -triple=x86_64 -x86-asm-syntax=intel --output-asm-variant=1 | FileCheck -check-prefix=INTEL %s + +# ATT: tmmultf32ps %tmm4, %tmm5, %tmm6 +# INTEL: tmmultf32ps tmm6, tmm5, tmm4 +0xc4,0xe2,0x59,0x48,0xf5 + +# ATT: tmmultf32ps %tmm1, %tmm2, %tmm3 +# INTEL: tmmultf32ps tmm3, tmm2, tmm1 +0xc4,0xe2,0x71,0x48,0xda + +# ATT: ttmmultf32ps %tmm4, %tmm5, %tmm6 +# INTEL: ttmmultf32ps tmm6, tmm5, tmm4 +0xc4,0xe2,0x58,0x48,0xf5 + +# ATT: ttmmultf32ps %tmm1, %tmm2, %tmm3 +# INTEL: ttmmultf32ps tmm3, tmm2, tmm1 +0xc4,0xe2,0x70,0x48,0xda + diff --git a/llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s new file mode 100644 index 00000000000000..b413597cd9da71 --- /dev/null +++ b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s @@ -0,0 +1,17 @@ 
+// RUN: llvm-mc -triple x86_64-unknown-unknown --show-encoding < %s | FileCheck %s + +// CHECK: tmmultf32ps %tmm4, %tmm5, %tmm6 +// CHECK: encoding: [0xc4,0xe2,0x59,0x48,0xf5] + tmmultf32ps %tmm4, %tmm5, %tmm6 + +// CHECK: tmmultf32ps %tmm1, %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x71,0x48,0xda] + tmmultf32ps %tmm1, %tmm2, %tmm3 + +// CHECK: ttmmultf32ps %tmm4, %tmm5, %tmm6 +// CHECK: encoding: [0xc4,0xe2,0x58,0x48,0xf5] + ttmmultf32ps %tmm4, %tmm5, %tmm6 + +// CHECK: ttmmultf32ps %tmm1, %tmm2, %tmm3 +// CHECK: encoding: [0xc4,0xe2,0x70,0x48,0xda] + ttmmultf32ps %tmm1, %tmm2, %tmm3 diff --git a/llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s new file mode 100644 index 00000000000000..98f55275716eb0 --- /dev/null +++ b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s @@ -0,0 +1,17 @@ +// RUN: llvm-mc -triple x86_64-unknown-unknown -x86-asm-syntax=intel -output-asm-variant=1 --show-encoding %s | FileCheck %s + +// CHECK: tmmultf32ps tmm6, tmm5, tmm4 +// CHECK: encoding: [0xc4,0xe2,0x59,0x48,0xf5] + tmmultf32ps tmm6, tmm5, tmm4 + +// CHECK: tmmultf32ps tmm3, tmm2, tmm1 +// CHECK: encoding: [0xc4,0xe2,0x71,0x48,0xda] + tmmultf32ps tmm3, tmm2, tmm1 + +// CHECK: ttmmultf32ps tmm6, tmm5, tmm4 +// CHECK: encoding: [0xc4,0xe2,0x58,0x48,0xf5] + ttmmultf32ps tmm6, tmm5, tmm4 + +// CHECK: ttmmultf32ps tmm3, tmm2, tmm1 +// CHECK: encoding: [0xc4,0xe2,0x70,0x48,0xda] + ttmmultf32ps tmm3, tmm2, tmm1 >From 0c10267f80acd5cb1d7571537fb9105881aa665d Mon Sep 17 00:00:00 2001 From: Feng Zou <feng....@intel.com> Date: Sun, 10 Nov 2024 11:26:51 +0800 Subject: [PATCH 2/3] Addressed comments. 
--- clang/lib/Headers/amxtf32intrin.h | 91 +----------------- clang/lib/Headers/amxtf32transposeintrin.h | 105 +++++++++++++++++++++ llvm/lib/Target/X86/X86ISelLowering.cpp | 30 ++---- llvm/test/CodeGen/X86/amx-tf32-internal.ll | 1 - 4 files changed, 116 insertions(+), 111 deletions(-) create mode 100644 clang/lib/Headers/amxtf32transposeintrin.h diff --git a/clang/lib/Headers/amxtf32intrin.h b/clang/lib/Headers/amxtf32intrin.h index f11b7c7499e2d5..73fff9486a77ca 100644 --- a/clang/lib/Headers/amxtf32intrin.h +++ b/clang/lib/Headers/amxtf32intrin.h @@ -15,13 +15,14 @@ #define __AMX_TF32INTRIN_H #ifdef __x86_64__ +#if !defined(__SCE__) || __has_feature(modules) || \ + (defined(__AMX_TF32__) && defined(__AMX_TRANSPOSE__)) +#include <amxtf32transposeintrin.h> +#endif + #define __DEFAULT_FN_ATTRS_TF32 \ __attribute__((__always_inline__, __nodebug__, __target__("amx-tf32"))) -#define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE \ - __attribute__((__always_inline__, __nodebug__, \ - __target__("amx-tf32,amx-transpose"))) - /// Do Matrix Multiplication of \a a and \a b, and then do Matrix Plus /// with \a srcdst. /// All the calculation is base on float32 but with the lower 13-bit set to 0. @@ -82,60 +83,6 @@ #define _tile_mmultf32ps(srcdst, a, b) \ __builtin_ia32_tmmultf32ps((srcdst), (a), (b)) -/// \code -/// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a, \ -/// constexpr int b); -/// \endcode -/// -/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. -/// -/// \param srcdst -/// The destination tile. Max size is 1024 Bytes. -/// \param a -/// The 1st source tile. Max size is 1024 Bytes. -/// \param b -/// The 2nd source tile. Max size is 1024 Bytes. 
-/// -/// \code{.operation} -/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) { -/// dword[12:0] := 0 -/// dword[31:13] := x[31:13] -/// return dword -/// } -/// -/// DEFINE silence_snan_fp32(x[31:0]) { -/// IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0) -/// x.fraction[22] := 1 -/// return x -/// } -/// -/// elements_dest:= srcdst.colsb/4 -/// -/// FOR m := 0 TO (srcdst.rows-1) -/// tmp[511:0] := 0 -/// FOR k := 0 TO (a.rows-1) -/// FOR n := 0 TO (elements_dest-1) -/// a1e := silence_snan_fp32(a.row[k].fp32[m]) -/// a2e := silence_snan_fp32(b.row[k].fp32[n]) -/// s1e := zero_lower_mantissa_bits_fp32(a1e) -/// s2e := zero_lower_mantissa_bits_fp32(a2e) -/// tmp.fp32[n] += s1e * s2e -/// ENDFOR -/// ENDFOR -/// -/// FOR n := 0 TO (elements_dest-1) -/// tmp.fp32[n] += srcdst.row[m].fp32[n] -/// ENDFOR -/// write_row_and_zero(srcdst, m, tmp, srcdst.colsb) -/// -/// ENDFOR -/// -/// zero_upper_rows(srcdst, srcdst.rows) -/// zero_tileconfig_start() -/// \endcode -#define _tile_tmmultf32ps(srcdst, a, b) \ - __builtin_ia32_ttmmultf32ps((srcdst), (a), (b)) - static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32 _tile_mmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, _tile1024i src1, _tile1024i src2) { @@ -162,33 +109,5 @@ static void __tile_mmultf32ps(__tile1024i *dst, __tile1024i src0, src0.tile, src1.tile); } -// dst = m x n (srcdest), src1 = k x m, src2 = k x n -static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE -_tile_tmmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k, - _tile1024i dst, _tile1024i src1, _tile1024i src2) { - return __builtin_ia32_ttmmultf32ps_internal(m, n, k, dst, src1, src2); -} - -/// Compute transpose and do Matrix Multiplication of src0 and src1, and then do -/// Matrix Plus with dst. All the calculation is base on float32 but with the -/// lower 13-bit set to 0. 
-/// -/// \headerfile <immintrin.h> -/// -/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. -/// -/// \param dst -/// The destination tile. Max size is 1024 Bytes. -/// \param src0 -/// The 1st source tile. Max size is 1024 Bytes. -/// \param src1 -/// The 2nd source tile. Max size is 1024 Bytes. -__DEFAULT_FN_ATTRS_TF32_TRANSPOSE -static void __tile_tmmultf32ps(__tile1024i *dst, __tile1024i src0, - __tile1024i src1) { - dst->tile = _tile_tmmultf32ps_internal(src0.row, src1.col, src0.col, - dst->tile, src0.tile, src1.tile); -} - #endif // __x86_64__ #endif // __AMX_TF32INTRIN_H diff --git a/clang/lib/Headers/amxtf32transposeintrin.h b/clang/lib/Headers/amxtf32transposeintrin.h new file mode 100644 index 00000000000000..d8e404780f5b7c --- /dev/null +++ b/clang/lib/Headers/amxtf32transposeintrin.h @@ -0,0 +1,105 @@ +/*===-- amxtf32transposeintrin.h - AMX_TF32 transpose intrinsics -*- C++ -*--=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + *===------------------------------------------------------------------------=== + */ +#ifndef __IMMINTRIN_H +#error \ + "Never use <amxtf32transposeintrin.h> directly; include <immintrin.h> instead." +#endif // __IMMINTRIN_H + +#ifndef __AMX_TF32TRANSPOSEINTRIN_H +#define __AMX_TF32TRANSPOSEINTRIN_H +#ifdef __x86_64__ + +#define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE \ + __attribute__((__always_inline__, __nodebug__, \ + __target__("amx-tf32,amx-transpose"))) + +/// \code +/// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a, \ +/// constexpr int b); +/// \endcode +/// +/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. +/// +/// \param srcdst +/// The destination tile. Max size is 1024 Bytes. +/// \param a +/// The 1st source tile. Max size is 1024 Bytes. +/// \param b +/// The 2nd source tile.
Max size is 1024 Bytes. +/// +/// \code{.operation} +/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) { +/// dword[12:0] := 0 +/// dword[31:13] := x[31:13] +/// return dword +/// } +/// +/// DEFINE silence_snan_fp32(x[31:0]) { +/// IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0) +/// x.fraction[22] := 1 +/// return x +/// } +/// +/// elements_dest:= srcdst.colsb/4 +/// +/// FOR m := 0 TO (srcdst.rows-1) +/// tmp[511:0] := 0 +/// FOR k := 0 TO (a.rows-1) +/// FOR n := 0 TO (elements_dest-1) +/// a1e := silence_snan_fp32(a.row[k].fp32[m]) +/// a2e := silence_snan_fp32(b.row[k].fp32[n]) +/// s1e := zero_lower_mantissa_bits_fp32(a1e) +/// s2e := zero_lower_mantissa_bits_fp32(a2e) +/// tmp.fp32[n] += s1e * s2e +/// ENDFOR +/// ENDFOR +/// +/// FOR n := 0 TO (elements_dest-1) +/// tmp.fp32[n] += srcdst.row[m].fp32[n] +/// ENDFOR +/// write_row_and_zero(srcdst, m, tmp, srcdst.colsb) +/// +/// ENDFOR +/// +/// zero_upper_rows(srcdst, srcdst.rows) +/// zero_tileconfig_start() +/// \endcode +#define _tile_tmmultf32ps(srcdst, a, b) \ + __builtin_ia32_ttmmultf32ps((srcdst), (a), (b)) + +// dst = m x n (srcdest), src1 = k x m, src2 = k x n +static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE +_tile_tmmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024i dst, _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_ttmmultf32ps_internal(m, n, k, dst, src1, src2); +} + +/// Compute transpose and do Matrix Multiplication of src0 and src1, and then do +/// Matrix Plus with dst. All calculations are based on float32, but with the +/// lower 13 bits set to 0. +/// +/// \headerfile <immintrin.h> +/// +/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes.
+__DEFAULT_FN_ATTRS_TF32_TRANSPOSE +static void __tile_tmmultf32ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_tmmultf32ps_internal(src0.row, src1.col, src0.col, + dst->tile, src0.tile, src1.tile); +} + +#endif // __x86_64__ +#endif // __AMX_TF32TRANSPOSEINTRIN_H diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 6140dffbe8ce65..8597ccd2a32f08 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -37470,7 +37470,9 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::PTDPBF8PS: case X86::PTDPBHF8PS: case X86::PTDPHBF8PS: - case X86::PTDPHF8PS: { + case X86::PTDPHF8PS: + case X86::PTMMULTF32PS: + case X86::PTTMMULTF32PS: { unsigned Opc; switch (MI.getOpcode()) { // clang-format off @@ -37485,7 +37487,9 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break; case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break; case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break; - // clang-format on + case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break; + case X86::PTTMMULTF32PS: Opc = X86::TTMMULTF32PS; break; + // clang-format on } MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); @@ -37686,28 +37690,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); // The pseudo is gone now. 
return BB; } - case X86::PTMMULTF32PS: - case X86::PTTMMULTF32PS: { - const DebugLoc &DL = MI.getDebugLoc(); - unsigned Opc; - switch (MI.getOpcode()) { - default: - llvm_unreachable("Unexpected instruction!"); - case X86::PTMMULTF32PS: - Opc = X86::TMMULTF32PS; - break; - case X86::PTTMMULTF32PS: - Opc = X86::TTMMULTF32PS; - break; - } - MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); - MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef); - MI.eraseFromParent(); - return BB; - } } } diff --git a/llvm/test/CodeGen/X86/amx-tf32-internal.ll b/llvm/test/CodeGen/X86/amx-tf32-internal.ll index 8094f990828bad..6d0f3c57c08d89 100644 --- a/llvm/test/CodeGen/X86/amx-tf32-internal.ll +++ b/llvm/test/CodeGen/X86/amx-tf32-internal.ll @@ -39,7 +39,6 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { declare x86_amx @llvm.x86.tilezero.internal(i16, i16) declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) -declare x86_amx @llvm.x86.tileloaddt164.internal(i16, i16, i8*, i64) declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) >From 556c19502b877c2ffa918047517553521bf287af Mon Sep 17 00:00:00 2001 From: Feng Zou <feng....@intel.com> Date: Sun, 10 Nov 2024 11:30:11 +0800 Subject: [PATCH 3/3] Format a line of code. 
--- llvm/lib/Target/X86/X86ISelLowering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 8597ccd2a32f08..b01a60e31ae048 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -37489,7 +37489,7 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break; case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break; case X86::PTTMMULTF32PS: Opc = X86::TTMMULTF32PS; break; - // clang-format on + // clang-format on } MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits