LuoYuanke created this revision.
Herald added a subscriber: pengfei.
LuoYuanke requested review of this revision.
Herald added projects: clang, LLVM.
Herald added subscribers: llvm-commits, cfe-commits.

Introduce a new intrinsic to cast between vector and AMX tile types. Using
it instead of a plain bitcast prevents middle-end optimizations on the cast.
However, we sometimes still need those bitcast optimizations: for
inner_product in amx_cast.c, we have to handle
llvm.x86.vector.amx.cast.v256i32.x86amx ourselves.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D99152

Files:
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/test/CodeGen/X86/amx_cast.c
  llvm/include/llvm/IR/IntrinsicsX86.td

Index: llvm/include/llvm/IR/IntrinsicsX86.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsX86.td
+++ llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5085,6 +5085,8 @@
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+  def int_x86_vector_amx_cast :
+              Intrinsic<[llvm_any_ty], [llvm_any_ty], [IntrNoMem]>; // Overloaded vector <-> x86_amx cast (result and operand are llvm_any_ty); emitted by clang in place of bitcast.
 }
 
 //===----------------------------------------------------------------------===//
Index: clang/test/CodeGen/X86/amx_cast.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/X86/amx_cast.c
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 %s -O2 -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +avx512f  -target-feature +amx-int8  \
+// RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK
+
+#include <immintrin.h>
+
+char buf[1024]; // Source/destination memory for the tile load/store in test1 and test2.
+#define STRIDE 32 // Stride argument to __tile_loadd/__tile_stored (appears as i64 32 in the expected IR).
+
+char buf2[1024]; // NOTE(review): not referenced by any test in this file.
+
+void test1() { // Loads a tile from buf and stores it back; the load result must round-trip through the vector.amx.cast intrinsics, not bitcasts.
+//CHECK: %0 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #1
+//CHECK: %1 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %0) #1
+//CHECK: %2 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %1) #1
+//CHECK: tail call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %2) #1
+  __tile1024i a = {8, 8};
+  __tile1024i b = {8, 8}; // NOTE(review): b is never used in this test; consider dropping it.
+
+  __tile_loadd(&a, buf, STRIDE);
+  __tile_stored(buf, STRIDE, a);
+}
+
+void test2() { // Stores a tile that was never loaded; per the expected IR its data is a zeroinitializer <256 x i32> cast to x86_amx.
+//CHECK:  %0 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> zeroinitializer) #1
+//CHECK:  tail call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %0) #1
+  __tile1024i a = {8, 8};
+
+  __tile_stored(buf, STRIDE, a);
+}
+
+#define TILE_SZ 16 // Tiles below are TILE_SZ rows by TILE_SZ*sizeof(int) bytes (i16 16, i16 64 in the IR).
+void inner_product(int *A_mem, int *B_mem, int *C_mem, int M, int N, int K) { // For each (i, j) output tile: zero c, accumulate __tile_dpbssd over k tiles of A and B, then store c to C_mem.
+//CHECK: for.body6:                                        ; preds = %for.body6.lr.ph, %for.cond.cleanup9
+//CHECK:   %indvars.iv200 = phi i64 [ 0, %for.body6.lr.ph ], [ %indvars.iv.next201, %for.cond.cleanup9 ]
+//CHECK:   %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) #1
+//CHECK:   %2 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %1) #1
+//CHECK:   %3 = shl nsw i64 %indvars.iv200, 4
+//CHECK:   br i1 %cmp8163, label %for.body10.lr.ph, label %for.cond.cleanup9
+//CHECK: for.body10.lr.ph:                                 ; preds = %for.body6
+//CHECK:   %add.ptr19 = getelementptr inbounds i32, i32* %B_mem, i64 %3
+//CHECK:   br label %for.body10
+//CHECK: for.cond.cleanup9:                                ; preds = %for.body10, %for.body6
+//CHECK:   %c.sroa.8127.2.lcssa = phi <256 x i32> [ %2, %for.body6 ], [ %18, %for.body10 ]
+//CHECK:   %add.ptr31 = getelementptr inbounds i32, i32* %add.ptr28, i64 %3
+//CHECK:   %4 = bitcast i32* %add.ptr31 to i8*
+//CHECK:   %5 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %c.sroa.8127.2.lcssa) #1
+//CHECK:   tail call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* %4, i64 %mul24, x86_amx %5) #1
+//CHECK:   %indvars.iv.next201 = add nuw nsw i64 %indvars.iv200, 1
+//CHECK:   %exitcond205.not = icmp eq i64 %indvars.iv.next201, %wide.trip.count204
+//CHECK:   br i1 %exitcond205.not, label %for.cond.cleanup5, label %for.body6, !llvm.loop !4
+//CHECK: for.body10:                                       ; preds = %for.body10.lr.ph, %for.body10
+//CHECK:   %indvars.iv = phi i64 [ 0, %for.body10.lr.ph ], [ %indvars.iv.next, %for.body10 ]
+//CHECK:   %c.sroa.8127.2164 = phi <256 x i32> [ %2, %for.body10.lr.ph ], [ %18, %for.body10 ]
+//CHECK:   %6 = shl nsw i64 %indvars.iv, 4
+//CHECK:   %add.ptr14 = getelementptr inbounds i32, i32* %add.ptr, i64 %6
+//CHECK:   %7 = bitcast i32* %add.ptr14 to i8*
+//CHECK:   %8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %7, i64 %mul15) #1
+//CHECK:   %9 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %8) #1
+//CHECK:   %10 = mul nsw i64 %6, %conv23
+//CHECK:   %add.ptr22 = getelementptr inbounds i32, i32* %add.ptr19, i64 %10
+//CHECK:   %11 = bitcast i32* %add.ptr22 to i8*
+//CHECK:   %12 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %11, i64 %mul24) #1
+//CHECK:   %13 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %12) #1
+//CHECK:   %14 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %c.sroa.8127.2164) #1
+//CHECK:   %15 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %9) #1
+//CHECK:   %16 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %13) #1
+//CHECK:   %17 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %14, x86_amx %15, x86_amx %16) #1
+//CHECK:   %18 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %17) #1
+//CHECK:   %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+//CHECK:   %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+//CHECK:   br i1 %exitcond.not, label %for.cond.cleanup9, label %for.body10, !llvm.loop !5
+
+  const int m = M / TILE_SZ; // Tile counts: integer division, so any remainder of M/N/K beyond a whole tile is not processed.
+  const int n = N / TILE_SZ;
+  const int k = K / TILE_SZ;
+
+  for (int i = 0; i < m; i++)
+    for (int j = 0; j < n; j++) {
+      __tile1024i c = {TILE_SZ, TILE_SZ*sizeof(int)};
+      __tile_zero(&c);
+      for (int l = 0; l < k; l++) {
+        __tile1024i a = {TILE_SZ, TILE_SZ*sizeof(int)};
+        __tile1024i b = {TILE_SZ, TILE_SZ*sizeof(int)};
+        __tile_loadd(&a, A_mem+(i*TILE_SZ)*K+l*TILE_SZ, K*sizeof(int));
+        __tile_loadd(&b, B_mem+(l*TILE_SZ)*N+j*TILE_SZ, N*sizeof(int));
+        __tile_dpbssd(&c, a, b); // Updates c; the expected IR keeps c as <256 x i32> across iterations and casts it to x86_amx around tdpbssd.
+      }
+      __tile_stored(C_mem+(i*TILE_SZ)*M+j*TILE_SZ, N*sizeof(int), c); // NOTE(review): row offset uses M but the stride is N*sizeof(int); (i*TILE_SZ)*N looks intended. Fixing it changes the IR, so the expected IR above must be regenerated.
+    }
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -5139,7 +5139,12 @@
 
         assert(PTy->canLosslesslyBitCastTo(FTy->getParamType(i)) &&
                "Must be able to losslessly bit cast to param");
-        ArgValue = Builder.CreateBitCast(ArgValue, PTy);
+        if (PTy->isX86_AMXTy() || ArgValue->getType()->isX86_AMXTy())
+          ArgValue =
+              Builder.CreateIntrinsic(Intrinsic::x86_vector_amx_cast,
+                                      {PTy, ArgValue->getType()}, {ArgValue});
+        else
+          ArgValue = Builder.CreateBitCast(ArgValue, PTy);
       }
 
       Args.push_back(ArgValue);
@@ -5163,7 +5168,11 @@
 
       assert(V->getType()->canLosslesslyBitCastTo(RetTy) &&
              "Must be able to losslessly bit cast result type");
-      V = Builder.CreateBitCast(V, RetTy);
+      if (RetTy->isX86_AMXTy() || V->getType()->isX86_AMXTy())
+        V = Builder.CreateIntrinsic(Intrinsic::x86_vector_amx_cast,
+                                    {RetTy, V->getType()}, {V});
+      else
+        V = Builder.CreateBitCast(V, RetTy);
     }
 
     return RValue::get(V);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to