fhahn created this revision.
Herald added a subscriber: tschuett.
Herald added a project: clang.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D72781

Files:
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/test/CodeGen/builtin-matrix.c

Index: clang/test/CodeGen/builtin-matrix.c
===================================================================
--- clang/test/CodeGen/builtin-matrix.c
+++ clang/test/CodeGen/builtin-matrix.c
@@ -271,4 +271,23 @@
 }
 // CHECK: declare <25 x double> @llvm.matrix.transpose.v25f64(<25 x double>, i32 immarg, i32 immarg) [[READNONE]]
 
+void column_load1(dx5x5_t *a, double *b) {
+  *a = __builtin_matrix_column_load(b, 5, 5, 10);
+
+  // CHECK-LABEL: @column_load1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca double*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store double* %b, double** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load double*, double** %b.addr, align 8
+  // CHECK-NEXT:    %matrix = call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* %0, i32 10, i32 5, i32 5)
+  // CHECK-NEXT:    %1 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %1 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %matrix, <25 x double>* %2, align 8
+  // CHECK-NEXT:    ret void
+}
+// CHECK: declare <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double*, i32, i32 immarg, i32 immarg) [[READONLY:#[0-9]]]
+
 // CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
+// CHECK: attributes [[READONLY]] = { nounwind readonly willreturn }
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1619,6 +1619,7 @@
   case Builtin::BI__builtin_matrix_subtract:
   case Builtin::BI__builtin_matrix_multiply:
   case Builtin::BI__builtin_matrix_transpose:
+  case Builtin::BI__builtin_matrix_column_load:
     if (!getLangOpts().EnableMatrix) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1636,6 +1637,8 @@
       return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_transpose:
       return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_column_load:
+      return SemaBuiltinMatrixColumnLoadOverload(TheCall, TheCallResult);
     default:
       llvm_unreachable("All matrix builtins should be handled here!");
     }
@@ -15530,3 +15533,121 @@
   TheCall->setType(ResultType);
   return CallResult;
 }
+
+ExprResult Sema::SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
+                                                     ExprResult CallResult) {
+  // Must have exactly four operands
+  // 1: Pointer to data
+  // 2: Rows (constant)
+  // 3: Columns (constant)
+  // 4: Stride
+
+  // Operands have very similar semantics to glVertexAttribPointer from OpenGL.
+  // Instead of the attribute index, it is a pointer to the memory that is being
+  // loaded from. Instead of size, we need the rows and columns. Note that these
+  // must be constant to construct the matrix type.
+
+  if (checkArgCount(*this, TheCall, 4))
+    return ExprError();
+
+  Expr *DataExpr = TheCall->getArg(0);
+  Expr *RowsExpr = TheCall->getArg(1);
+  Expr *ColsExpr = TheCall->getArg(2);
+  Expr *StrideExpr = TheCall->getArg(3);
+
+  unsigned Rows = 0;
+  unsigned Cols = 0;
+
+  if (!(DataExpr->getType()->isPointerType() ||
+        DataExpr->getType()->isArrayType())) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+        << 0 << 0;
+  }
+
+  bool ArgError = false;
+  // get the matrix dimensions
+  {
+    llvm::APSInt Value(32);
+    SourceLocation RowColErrorPos;
+
+    if (!RowsExpr->isIntegerConstantExpr(Value, Context, &RowColErrorPos)) {
+      Diag(RowsExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+          << 0 << 1;
+      ArgError = true;
+    } else
+      Rows = Value.getZExtValue();
+
+    if (!ColsExpr->isIntegerConstantExpr(Value, Context, &RowColErrorPos)) {
+      Diag(ColsExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+          << 1 << 1;
+      ArgError = true;
+    } else
+      Cols = Value.getZExtValue();
+  }
+  if (!StrideExpr->getType()->isIntegralType(Context)) {
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << 3 << 1;
+    ArgError = true;
+  }
+  if (ArgError)
+    return ExprError();
+
+  QualType ElementType;
+
+  if (const PointerType *PTy = dyn_cast<PointerType>(DataExpr->getType())) {
+    ElementType = PTy->getPointeeType();
+  } else if (const ArrayType *ATy = dyn_cast<ArrayType>(DataExpr->getType())) {
+    ElementType = ATy->getElementType();
+  } else {
+    llvm_unreachable("Pointer Expression must be a pointer or an array");
+    return ExprError();
+  }
+  ElementType.removeLocalConst();
+
+  if (!ElementType->isIntegralType(Context) && !ElementType->isFloatingType()) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+        << 0 << 1;
+    return ExprError();
+  }
+
+  // TODO: Check this, it seems weird to have to cast an l-value pointer to an
+  // r-value. I guess it must be materialized as a pointer before we can use it.
+  if (!DataExpr->isRValue()) {
+    ExprResult CastExprResult = ImplicitCastExpr::Create(
+        Context, DataExpr->getType(), CK_LValueToRValue, DataExpr, nullptr,
+        VK_RValue);
+    assert(!CastExprResult.isInvalid() &&
+           "Pointer failed to be casted to an R-value");
+    DataExpr = CastExprResult.get();
+    TheCall->setArg(0, DataExpr);
+  }
+
+  QualType ReturnType = Context.getMatrixType(ElementType, Rows, Cols);
+
+  llvm::SmallVector<QualType, 5> ParameterTypes = {
+      DataExpr->getType().withConst(), RowsExpr->getType().withConst(),
+      ColsExpr->getType().withConst(), StrideExpr->getType().withConst()};
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  // Create a new DeclRefExpr to refer to the new decl.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall =
+      ImpCastExprToType(NewDRE, CalleePtrTy, CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // Change the result type of the call to match the original value type. This
+  // is arbitrary, but the codegen for these builtins is designed to handle it
+  // gracefully.
+  TheCall->setType(ReturnType);
+
+  return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2340,6 +2340,39 @@
     V = Builder.CreateFCmpUNO(V, V, "cmp");
     return RValue::get(Builder.CreateZExt(V, ConvertType(E->getType())));
   }
+  case Builtin::BI__builtin_matrix_column_load: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    // Emit everything that isn't dependent on the first parameter type
+    Value *Stride = EmitScalarExpr(E->getArg(3));
+    const MatrixType *ResultTy = getMatrixTy(E->getType());
+
+    // If it's an address we need to emit the pointer
+    // otherwise, emit the array
+    Value *Result = nullptr;
+    if (const PointerType *PTy =
+            dyn_cast<PointerType>(E->getArg(0)->getType())) {
+      Address Src = EmitPointerWithAlignment(E->getArg(0));
+      EmitNonNullArgCheck(RValue::get(Src.getPointer()),
+                          E->getArg(0)->getType(), E->getArg(0)->getExprLoc(),
+                          FD, 0);
+      Result = MB.CreateMatrixColumnwiseLoad(
+          Src.getPointer(), ResultTy->getNumRows(), ResultTy->getNumColumns(),
+          Stride, "matrix");
+    } else if (const ArrayType *ATy =
+                   dyn_cast<ArrayType>(E->getArg(0)->getType())) {
+      Address Src = EmitArrayToPointerDecay(E->getArg(0));
+      EmitNonNullArgCheck(RValue::get(Src.getPointer()),
+                          E->getArg(0)->getType(), E->getArg(0)->getExprLoc(),
+                          FD, 0);
+      Result = MB.CreateMatrixColumnwiseLoad(
+          Src.getPointer(), ResultTy->getNumRows(), ResultTy->getNumColumns(),
+          Stride, "matrix");
+    } else {
+      llvm_unreachable(
+          "CGBuiltin.cpp: First argument must either be a pointer or an array");
+    }
+    return RValue::get(Result);
+  }
   case Builtin::BI__builtin_matrix_insert: {
     MatrixBuilder<CGBuilderTy> MB(Builder);
     Value *MatValue = EmitScalarExpr(E->getArg(0));
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11623,6 +11623,8 @@
   ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
                                                 ExprResult CallResult);
 
+  ExprResult SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
+                                                 ExprResult CallResult);
 
 public:
   enum FormatStringType {
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -579,6 +579,7 @@
 BUILTIN(__builtin_matrix_add, "v.", "nt")
 BUILTIN(__builtin_matrix_multiply, "v.", "nt")
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_load, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D72781: [Ma... Florian Hahn via Phabricator via cfe-commits

Reply via email to