fhahn created this revision.
Herald added a subscriber: tschuett.
Herald added a project: clang.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D72778
Files:
clang/include/clang/Basic/Builtins.def
clang/include/clang/Sema/Sema.h
clang/lib/CodeGen/CGBuiltin.cpp
clang/lib/Sema/SemaChecking.cpp
clang/test/CodeGen/builtin-matrix.c
Index: clang/test/CodeGen/builtin-matrix.c
===================================================================
--- clang/test/CodeGen/builtin-matrix.c
+++ clang/test/CodeGen/builtin-matrix.c
@@ -251,4 +251,24 @@
}
// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]]
+void transpose1(dx5x5_t *a, dx5x5_t *b) {
+ *a = __builtin_matrix_transpose(*b);
+ // Expect a load of *b, one llvm.matrix.transpose call, and a store to *a.
+ // CHECK-LABEL: @transpose1(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+ // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+ // CHECK-NEXT: %3 = call <25 x double> @llvm.matrix.transpose.v25f64(<25 x double> %2, i32 5, i32 5)
+ // CHECK-NEXT: %4 = load [25 x double]*, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: %5 = bitcast [25 x double]* %4 to <25 x double>*
+ // CHECK-NEXT: store <25 x double> %3, <25 x double>* %5, align 8
+ // CHECK-NEXT: ret void
+}
+// CHECK: declare <25 x double> @llvm.matrix.transpose.v25f64(<25 x double>, i32 immarg, i32 immarg) [[READNONE]]
+
// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1618,6 +1618,7 @@
case Builtin::BI__builtin_matrix_add:
case Builtin::BI__builtin_matrix_subtract:
case Builtin::BI__builtin_matrix_multiply:
+ case Builtin::BI__builtin_matrix_transpose:
if (!getLangOpts().EnableMatrix) {
Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
return ExprError();
@@ -1633,6 +1634,8 @@
return SemaBuiltinMatrixEltwiseOverload(TheCall, TheCallResult);
case Builtin::BI__builtin_matrix_multiply:
return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
+ case Builtin::BI__builtin_matrix_transpose:
+ return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
default:
llvm_unreachable("All matrix builtins should be handled here!");
}
@@ -15470,3 +15473,60 @@
return CallResult;
}
+
+/// Checks a __builtin_matrix_transpose call: it needs one argument of matrix
+/// type (else a diagnostic is emitted and ExprError() returned). On success the
+/// call's type is set to a matrix type built with (C, R); returns CallResult.
+ExprResult Sema::SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
+ ExprResult CallResult) {
+ if (checkArgCount(*this, TheCall, 1))
+ return ExprError();
+
+ Expr *Arg = TheCall->getArg(0);
+
+ // Basic type checking: the argument must have a matrix type.
+ if (!Arg->getType()->isMatrixType()) {
+ Diag(Arg->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+ return ExprError();
+ }
+
+ MatrixType const *MType =
+ cast<MatrixType const>(Arg->getType().getCanonicalType());
+
+ unsigned R = MType->getNumRows();
+ unsigned C = MType->getNumColumns();
+
+ // Wrap a non-rvalue argument in an lvalue-to-rvalue cast.
+ if (!Arg->isRValue()) {
+ ExprResult Res = ImplicitCastExpr::Create(
+ Context, Arg->getType(), CK_LValueToRValue, Arg, nullptr, VK_RValue);
+ assert(!Res.isInvalid() && "Matrix Cast failed");
+ TheCall->setArg(0, Res.get());
+ }
+
+ Expr *Callee = TheCall->getCallee();
+ DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+ FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+ // The result keeps the element type; its dimensions are passed as (C, R).
+ QualType ReturnElementType = MType->getElementType();
+ QualType ResultType = Context.getMatrixType(ReturnElementType, C, R);
+
+ // Create a new DeclRefExpr to the same builtin decl, typed as BuiltinFnTy.
+ DeclRefExpr *NewDRE = DeclRefExpr::Create(
+ Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+ /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+ DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+ // Set the callee in the CallExpr.
+ // FIXME: This loses syntactic information.
+ QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+ ExprResult PromotedCall = ImpCastExprToType(NewDRE, CalleePtrTy,
+ CK_BuiltinFnToFnPtr);
+ TheCall->setCallee(PromotedCall.get());
+
+ // Give the call the result type computed above. This is arbitrary, but the
+ // codegen for these builtins is designed to handle it gracefully.
+ TheCall->setType(ResultType);
+ return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2366,6 +2366,15 @@
return RValue::get(Result);
}
+ case Builtin::BI__builtin_matrix_transpose: {
+ const MatrixType *MatrixTy = getMatrixTy(E->getArg(0)->getType());
+ Value *MatValue = EmitScalarExpr(E->getArg(0));
+ MatrixBuilder<CGBuilderTy> MB(Builder);
+ Value *Result = MB.CreateMatrixTranspose(
+ MatValue, MatrixTy->getNumRows(), MatrixTy->getNumColumns());
+ return RValue::get(Result);
+ }
+
case Builtin::BI__builtin_matrix_add: {
MatrixBuilder<CGBuilderTy> MB(Builder);
Value *Matrix1 = EmitScalarExpr(E->getArg(0));
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11620,6 +11620,9 @@
ExprResult CallResult);
ExprResult SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
ExprResult CallResult);
+ ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
+ ExprResult CallResult);
+
public:
enum FormatStringType {
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -578,6 +578,7 @@
BUILTIN(__builtin_matrix_subtract, "v.", "nt")
BUILTIN(__builtin_matrix_add, "v.", "nt")
BUILTIN(__builtin_matrix_multiply, "v.", "nt")
+BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
// "Overloaded" Atomic operator builtins. These are overloaded to support data
// types of i8, i16, i32, i64, and i128. The front-end sees calls to the
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits