fhahn updated this revision to Diff 254746.
fhahn marked an inline comment as done.
fhahn added a comment.
Use 20 bits in MatrixTypeBitfields for each of the number of rows and the
number of columns. This leaves 24 bits for NumTypeBits, while providing a
large range for the number of rows/columns.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D72281/new/
https://reviews.llvm.org/D72281
Files:
clang/include/clang/AST/ASTContext.h
clang/include/clang/AST/RecursiveASTVisitor.h
clang/include/clang/AST/Type.h
clang/include/clang/AST/TypeLoc.h
clang/include/clang/AST/TypeProperties.td
clang/include/clang/Basic/Attr.td
clang/include/clang/Basic/DiagnosticSemaKinds.td
clang/include/clang/Basic/LangOptions.def
clang/include/clang/Basic/TypeNodes.td
clang/include/clang/Driver/Options.td
clang/include/clang/Sema/Sema.h
clang/include/clang/Serialization/TypeBitCodes.def
clang/lib/AST/ASTContext.cpp
clang/lib/AST/ASTStructuralEquivalence.cpp
clang/lib/AST/ExprConstant.cpp
clang/lib/AST/ItaniumMangle.cpp
clang/lib/AST/MicrosoftMangle.cpp
clang/lib/AST/Type.cpp
clang/lib/AST/TypePrinter.cpp
clang/lib/Basic/Targets/OSTargets.cpp
clang/lib/CodeGen/CGDebugInfo.cpp
clang/lib/CodeGen/CGDebugInfo.h
clang/lib/CodeGen/CGExpr.cpp
clang/lib/CodeGen/CodeGenFunction.cpp
clang/lib/CodeGen/CodeGenTypes.cpp
clang/lib/CodeGen/ItaniumCXXABI.cpp
clang/lib/Driver/ToolChains/Clang.cpp
clang/lib/Frontend/CompilerInvocation.cpp
clang/lib/Sema/SemaExpr.cpp
clang/lib/Sema/SemaLookup.cpp
clang/lib/Sema/SemaTemplate.cpp
clang/lib/Sema/SemaTemplateDeduction.cpp
clang/lib/Sema/SemaType.cpp
clang/lib/Sema/TreeTransform.h
clang/lib/Serialization/ASTReader.cpp
clang/lib/Serialization/ASTWriter.cpp
clang/test/CodeGen/matrix-type.c
clang/test/CodeGenCXX/matrix-type.cpp
clang/test/SemaCXX/matrix-type.cpp
clang/tools/libclang/CIndex.cpp
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -1786,6 +1786,8 @@
DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type)
DEFAULT_TYPELOC_IMPL(Vector, Type)
DEFAULT_TYPELOC_IMPL(ExtVector, VectorType)
+DEFAULT_TYPELOC_IMPL(Matrix, Type)
+DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, Type)
DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType)
DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType)
DEFAULT_TYPELOC_IMPL(Record, TagType)
Index: clang/test/SemaCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/matrix-type.cpp
@@ -0,0 +1,54 @@
+// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+
+using matrix_double_t = double __attribute__((matrix_type(6, 6)));
+using matrix_float_t = float __attribute__((matrix_type(6, 6)));
+using matrix_int_t = int __attribute__((matrix_type(6, 6)));
+
+void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
+ using matrix1_t = int __attribute__((matrix_type(Rows, 1))); // expected-error{{matrix_type attribute requires an integer constant}}
+ using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}}
+ using matrix3_t = int __attribute__((matrix_type(C, C))); // expected-error{{matrix_type attribute requires an integer constant}}
+ using matrix4_t = int __attribute__((matrix_type(-1, 1))); // expected-error{{matrix row size too large. Size is 4294967295 when the maximum allowed is 1048575}}
+ using matrix5_t = int __attribute__((matrix_type(1, -1))); // expected-error{{matrix column size too large. Size is 4294967295 when the maximum allowed is 1048575}}
+ using matrix6_t = int __attribute__((matrix_type(0, 1))); // expected-error{{zero matrix size}}
+ using matrix7_t = int __attribute__((matrix_type(1, 0))); // expected-error{{zero matrix size}}
+ using matrix7_t = int __attribute__((matrix_type(char, 0))); // expected-error{{expected '(' for function-style cast or type construction}}
+ using matrix8_t = int __attribute__((matrix_type(1048576, 1))); // expected-error{{matrix row size too large. Size is 1048576 when the maximum allowed is 1048575}}
+}
+
+struct S1 {};
+
+void matrix_unsupported_element_type() {
+ using matrix1_t = char *__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'char *'}}
+ using matrix2_t = S1 __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'S1'}}
+}
+
+template <typename T> // expected-note{{declared here}}
+void matrix_template_1() {
+ using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}}
+}
+
+template <class C> // expected-note{{declared here}}
+void matrix_template_2() {
+ using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}}
+}
+
+template <unsigned Rows, unsigned Cols>
+void matrix_template_3() {
+ using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero matrix size}}
+}
+
+void instantiate_template_3() {
+ matrix_template_3<1, 10>();
+ matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}}
+}
+
+template <int Rows, unsigned Cols>
+void matrix_template_4() {
+ using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{matrix row size too large. Size is 4294967293 when the maximum allowed is 1048575}}
+}
+
+void instantiate_template_4() {
+ matrix_template_4<2, 10>();
+ matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}}
+}
Index: clang/test/CodeGenCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCXX/matrix-type.cpp
@@ -0,0 +1,176 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+ // CHECK-LABEL: define void @_Z10load_storePDm5_5_dS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+ // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+ // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+ // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+ // CHECK-LABEL: define void @_Z17parameter_passingDm3_3_fPS_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float], align 4
+ // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4
+ // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4
+ // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4
+ // CHECK-NEXT: ret void
+ *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+ // CHECK-LABEL: define <9 x float> @_Z13return_matrixPDm3_3_f(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>*
+ // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4
+ // CHECK-NEXT: ret <9 x float> %2
+ return *a;
+}
+
+struct Matrix {
+ char Tmp1;
+ fx3x4_t Data;
+ float Tmp2;
+};
+
+void matrix_struct_pointers(Matrix *a, Matrix *b) {
+ // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b->Data = a->Data;
+}
+
+void matrix_struct_reference(Matrix &a, Matrix &b) {
+ // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b.Data = a.Data;
+}
+
+class MatrixClass {
+public:
+ int Tmp1;
+ fx3x4_t Data;
+ long Tmp2;
+};
+
+void matrix_class_reference(MatrixClass &a, MatrixClass &b) {
+ // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %class.MatrixClass*, align 8
+ // CHECK-NEXT: %b.addr = alloca %class.MatrixClass*, align 8
+ // CHECK-NEXT: store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8
+ // CHECK-NEXT: store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b.Data = a.Data;
+}
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+class MatrixClassTemplate {
+public:
+ using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols)));
+ int Tmp1;
+ MatrixTy Data;
+ long Tmp2;
+};
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) {
+ b.Data = a.Data;
+}
+
+MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) {
+ // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %Data.addr = alloca float*, align 8
+ // CHECK-NEXT: %Arg = alloca %class.MatrixClassTemplate, align 8
+ // CHECK-NEXT: store float* %Data, float** %Data.addr, align 8
+ // CHECK-NEXT: %0 = load float*, float** %Data.addr, align 8
+ // CHECK-NEXT: %1 = bitcast float* %0 to [150 x float]*
+ // CHECK-NEXT: %2 = bitcast [150 x float]* %1 to <150 x float>*
+ // CHECK-NEXT: %3 = load <150 x float>, <150 x float>* %2, align 4
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+ // CHECK-NEXT: store <150 x float> %3, <150 x float>* %4, align 4
+ // CHECK-NEXT: call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result)
+ // CHECK-NEXT: ret void
+
+ // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b)
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %class.MatrixClassTemplate*, align 8
+ // CHECK-NEXT: %b.addr = alloca %class.MatrixClassTemplate*, align 8
+ // CHECK-NEXT: store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8
+ // CHECK-NEXT: store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [150 x float]* %Data to <150 x float>*
+ // CHECK-NEXT: %2 = load <150 x float>, <150 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+ // CHECK-NEXT: store <150 x float> %2, <150 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+
+ MatrixClassTemplate<float, 10, 15> Result, Arg;
+ Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data);
+ matrix_template_reference(Arg, Result);
+ return Result;
+}
Index: clang/test/CodeGen/matrix-type.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/matrix-type.c
@@ -0,0 +1,79 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+ // CHECK-LABEL: define void @load_store(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+ // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+ // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+ // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+ // CHECK-LABEL: define void @parameter_passing(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float], align 4
+ // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4
+ // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4
+ // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4
+ // CHECK-NEXT: ret void
+ *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+ // CHECK-LABEL: define <9 x float> @return_matrix
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>*
+ // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4
+ // CHECK-NEXT: ret <9 x float> %2
+ return *a;
+}
+
+typedef struct {
+ char Tmp1;
+ fx3x4_t Data;
+ float Tmp2;
+} Matrix;
+
+void matrix_struct(Matrix *a, Matrix *b) {
+ // CHECK-LABEL: define void @matrix_struct(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b->Data = a->Data;
+}
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -288,6 +288,15 @@
Record.AddSourceLocation(TL.getNameLoc());
}
+/// Emit a matrix TypeLoc; the name location is the only location recorded.
+void TypeLocWriter::VisitMatrixTypeLoc(MatrixTypeLoc TL) {
+  const SourceLocation NameLoc = TL.getNameLoc();
+  Record.AddSourceLocation(NameLoc);
+}
+
+/// Emit a dependent-sized matrix TypeLoc; the name location is the only
+/// location recorded.
+void TypeLocWriter::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  const SourceLocation NameLoc = TL.getNameLoc();
+  Record.AddSourceLocation(NameLoc);
+}
+
void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
Record.AddSourceLocation(TL.getLocalRangeBegin());
Record.AddSourceLocation(TL.getLParenLoc());
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -6525,6 +6525,15 @@
TL.setNameLoc(readSourceLocation());
}
+/// Restore a matrix TypeLoc by reading its single source location (the name
+/// location).
+void TypeLocReader::VisitMatrixTypeLoc(MatrixTypeLoc TL) {
+  const SourceLocation NameLoc = readSourceLocation();
+  TL.setNameLoc(NameLoc);
+}
+
+/// Restore a dependent-sized matrix TypeLoc by reading its single source
+/// location (the name location).
+void TypeLocReader::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  const SourceLocation NameLoc = readSourceLocation();
+  TL.setNameLoc(NameLoc);
+}
+
void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
TL.setLocalRangeBegin(readSourceLocation());
TL.setLParenLoc(readSourceLocation());
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -894,6 +894,16 @@
Expr *SizeExpr,
SourceLocation AttributeLoc);
+ /// Build a new matrix type given the element type and constant dimensions.
+ QualType RebuildMatrixType(QualType ElementType, unsigned NumRows,
+ unsigned NumColumns);
+
+ /// Build a new matrix type given the element type and dimensions that are
+ /// supplied as expressions (which may be dependent).
+ QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+ Expr *ColumnExpr,
+ SourceLocation AttributeLoc);
+
/// Build a new DependentAddressSpaceType or return the pointee
/// type variable with the correct address space (retrieved from
/// AddrSpaceExpr) applied to it. The former will be returned in cases
@@ -5160,6 +5170,65 @@
return Result;
}
+/// Transform a matrix type with constant dimensions. Only the element type is
+/// transformed; the row and column counts are carried over unchanged.
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformMatrixType(TypeLocBuilder &TLB,
+                                                     MatrixTypeLoc TL) {
+  const MatrixType *OldTy = TL.getTypePtr();
+  QualType NewElemTy = getDerived().TransformType(OldTy->getElementType());
+  if (NewElemTy.isNull())
+    return QualType();
+
+  // Reuse the existing type unless a rebuild is forced or the element type
+  // actually changed.
+  QualType NewTy = TL.getType();
+  const bool NeedsRebuild =
+      getDerived().AlwaysRebuild() || NewElemTy != OldTy->getElementType();
+  if (NeedsRebuild) {
+    NewTy = getDerived().RebuildMatrixType(NewElemTy, OldTy->getNumRows(),
+                                           OldTy->getNumColumns());
+    if (NewTy.isNull())
+      return QualType();
+  }
+
+  MatrixTypeLoc ResultLoc = TLB.push<MatrixTypeLoc>(NewTy);
+  ResultLoc.setNameLoc(TL.getNameLoc());
+
+  return NewTy;
+}
+
+/// Transform a matrix type whose dimensions are given by expressions.
+///
+/// The element type and both dimension expressions are transformed. If any of
+/// those transformations fails, a null QualType is returned. The rebuilt type
+/// may or may not still be a DependentSizedMatrixType, so the TypeLoc pushed
+/// onto \p TLB is chosen to match the resulting type.
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformDependentSizedMatrixType(
+    TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) {
+  const DependentSizedMatrixType *T = TL.getTypePtr();
+
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull())
+    return QualType();
+
+  // The dimension expressions are transformed in a constant-evaluated context.
+  EnterExpressionEvaluationContext ConstantContext(
+      SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+
+  // Bail out on a failed transformation: an invalid ExprResult would otherwise
+  // hand a null expression to RebuildDependentSizedMatrixType.
+  ExprResult Rows = getDerived().TransformExpr(T->getRowExpr());
+  if (Rows.isInvalid())
+    return QualType();
+
+  ExprResult Cols = getDerived().TransformExpr(T->getColumnExpr());
+  if (Cols.isInvalid())
+    return QualType();
+
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() ||
+      Rows.get() != T->getRowExpr() || Cols.get() != T->getColumnExpr()) {
+    Result = getDerived().RebuildDependentSizedMatrixType(
+        ElementType, Rows.get(), Cols.get(), T->getAttributeLoc());
+
+    if (Result.isNull())
+      return QualType();
+  }
+
+  // Push the TypeLoc kind that matches the type we ended up with.
+  if (isa<DependentSizedMatrixType>(Result)) {
+    DependentSizedMatrixTypeLoc NewTL =
+        TLB.push<DependentSizedMatrixTypeLoc>(Result);
+    NewTL.setNameLoc(TL.getNameLoc());
+  } else {
+    MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+    NewTL.setNameLoc(TL.getNameLoc());
+  }
+  return Result;
+}
+
template <typename Derived>
QualType TreeTransform<Derived>::TransformDependentAddressSpaceType(
TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) {
@@ -13662,6 +13731,21 @@
return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc);
}
+/// Default rebuild of a constant-dimension matrix type: request it from the
+/// ASTContext.
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildMatrixType(QualType ElementType,
+                                                   unsigned NumRows,
+                                                   unsigned NumColumns) {
+  ASTContext &Ctx = SemaRef.Context;
+  return Ctx.getMatrixType(ElementType, NumRows, NumColumns);
+}
+
+/// Default rebuild of a matrix type with expression dimensions: forward to
+/// Sema::BuildMatrixType.
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType(
+    QualType ElementType, Expr *RowExpr, Expr *ColumnExpr,
+    SourceLocation AttributeLoc) {
+  QualType Rebuilt =
+      SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr, AttributeLoc);
+  return Rebuilt;
+}
+
template<typename Derived>
QualType TreeTransform<Derived>::RebuildFunctionProtoType(
QualType T,
Index: clang/lib/Sema/SemaType.cpp
===================================================================
--- clang/lib/Sema/SemaType.cpp
+++ clang/lib/Sema/SemaType.cpp
@@ -2440,7 +2440,8 @@
unsigned TypeSize = static_cast<unsigned>(Context.getTypeSize(CurType));
if (VectorSize == 0) {
- Diag(AttrLoc, diag::err_attribute_zero_size) << SizeExpr->getSourceRange();
+ Diag(AttrLoc, diag::err_attribute_zero_size)
+ << SizeExpr->getSourceRange() << "vector";
return QualType();
}
@@ -2454,7 +2455,7 @@
if (VectorType::isVectorSizeTooLarge(VectorSize / TypeSize)) {
Diag(AttrLoc, diag::err_attribute_size_too_large)
// Display sizes in error messages in bytes.
- << SizeExpr->getSourceRange()
+ << SizeExpr->getSourceRange() << "vector"
<< static_cast<unsigned>(VecSize.getZExtValue())
<< (VectorType::getMaxNumElements() * (TypeSize / 8));
return QualType();
@@ -2498,13 +2499,13 @@
if (vectorSize == 0) {
Diag(AttrLoc, diag::err_attribute_zero_size)
- << ArraySize->getSourceRange();
+ << ArraySize->getSourceRange() << "vector";
return QualType();
}
if (VectorType::isVectorSizeTooLarge(vectorSize)) {
Diag(AttrLoc, diag::err_attribute_size_too_large)
- << ArraySize->getSourceRange() << vectorSize
+ << ArraySize->getSourceRange() << "vector" << vectorSize
<< VectorType::getMaxNumElements();
return QualType();
}
@@ -2515,6 +2516,98 @@
return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
}
+/// Build a matrix type with element type \p T; dependent dimensions yield a
+/// dependent-sized matrix type. Otherwise the operands are checked and any
+/// error is diagnosed at \p AttrLoc, returning a null QualType.
+QualType Sema::BuildMatrixType(QualType T, Expr *NumRows, Expr *NumCols,
+ SourceLocation AttrLoc) {
+ assert(Context.getLangOpts().EnableMatrix &&
+ "Should never build a matrix type when it is disabled");
+
+ if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
+ NumRows->isValueDependent() || NumCols->isValueDependent()) {
+ return Context.getDependentSizedMatrixType(T, NumRows, NumCols, AttrLoc);
+ }
+
+ unsigned MatrixRows = 0;
+ unsigned MatrixColumns = 0;
+
+ { // Handle parameter error checking
+ // Element type must be an integer, real floating-point or dependent type.
+ if (!(T->isIntegerType() || T->isRealFloatingType() ||
+ T->isDependentType())) {
+ Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << T;
+ return QualType();
+ }
+
+ // NOTE(review): the APSInts start out 32 bits wide; confirm this is intended.
+ llvm::APSInt ValueRows(32), ValueColumns(32);
+
+ bool const RowsIsInteger =
+ NumRows->isIntegerConstantExpr(ValueRows, Context);
+ bool const ColumnsIsInteger =
+ NumCols->isIntegerConstantExpr(ValueColumns, Context);
+
+ auto const RowRange = NumRows->getSourceRange();
+ auto const ColRange = NumCols->getSourceRange();
+
+ // Neither dimension is an integer constant expression: report both ranges.
+ if (!RowsIsInteger && !ColumnsIsInteger) {
+ Diag(AttrLoc, diag::err_attribute_argument_type)
+ << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+ << ColRange;
+ return QualType();
+ }
+
+ // Only one dimension is not an integer constant expression: report it alone.
+ if (!RowsIsInteger) {
+ Diag(AttrLoc, diag::err_attribute_argument_type)
+ << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+ return QualType();
+ }
+
+ // NOTE(review): the original note here read "Getting the wrong source range" - verify.
+ if (!ColumnsIsInteger) {
+ Diag(AttrLoc, diag::err_attribute_argument_type)
+ << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+ return QualType();
+ }
+
+ MatrixRows = static_cast<unsigned>(ValueRows.getZExtValue());
+ MatrixColumns = static_cast<unsigned>(ValueColumns.getZExtValue());
+
+ // Reject zero dimensions first, then dimensions above the allowed maximum.
+ if (MatrixRows == 0 && MatrixColumns == 0) {
+ Diag(AttrLoc, diag::err_attribute_zero_size)
+ << "matrix" << RowRange << ColRange;
+ return QualType();
+ }
+ if (MatrixRows == 0) {
+ Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+ return QualType();
+ }
+ if (MatrixColumns == 0) {
+ Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+ return QualType();
+ }
+
+ if (MatrixType::isDimensionTooLarge(MatrixRows)) {
+ Diag(AttrLoc, diag::err_attribute_size_too_large)
+ << RowRange << "matrix row" << MatrixRows
+ << MatrixType::getMaxElementsPerDimension();
+ return QualType();
+ }
+
+ if (MatrixType::isDimensionTooLarge(MatrixColumns)) {
+ Diag(AttrLoc, diag::err_attribute_size_too_large)
+ << ColRange << "matrix column" << MatrixColumns
+ << MatrixType::getMaxElementsPerDimension();
+ return QualType();
+ }
+ }
+ return Context.getMatrixType(T, MatrixRows, MatrixColumns);
+}
+
bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
if (T->isArrayType() || T->isFunctionType()) {
Diag(Loc, diag::err_func_returning_array_function)
@@ -7642,6 +7735,71 @@
}
}
+/// Return the expression for dimension argument \p ArgIdx of the matrix_type
+/// attribute \p Attr. An identifier argument is resolved as an unqualified
+/// id-expression; nullptr is returned if that resolution fails.
+static Expr *getMatrixDimensionExpr(Sema &S, const ParsedAttr &Attr,
+                                    unsigned ArgIdx) {
+  if (!Attr.isArgIdent(ArgIdx)) {
+    assert(Attr.isArgExpr(ArgIdx) &&
+           "Argument to should either be an identity or expression");
+    return Attr.getArgAsExpr(ArgIdx);
+  }
+
+  CXXScopeSpec SS;
+  SourceLocation TemplateKeywordLoc;
+  UnqualifiedId Id;
+  Id.setIdentifier(Attr.getArgAsIdent(ArgIdx)->Ident, Attr.getLoc());
+  ExprResult Dim = S.ActOnIdExpression(S.getCurScope(), SS, TemplateKeywordLoc,
+                                       Id, false, false);
+  // TODO: maybe a good error message would be nice here
+  return Dim.isInvalid() ? nullptr : Dim.get();
+}
+
+/// HandleMatrixTypeAttr - Process the "matrix_type" attribute, which takes the
+/// number of rows and the number of columns (like ext_vector_type takes a
+/// size). On success \p CurType is replaced by the matrix type; on any error
+/// \p CurType is left unchanged.
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+                                 Sema &S) {
+  if (!S.getLangOpts().EnableMatrix) {
+    S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+    return;
+  }
+
+  if (Attr.getNumArgs() != 2) {
+    S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+        << Attr << 2;
+    return;
+  }
+
+  // Both dimensions go through the same helper so that rows and columns can
+  // never be stored into the wrong variable (the column identifier path used
+  // to overwrite the row expression and leave the column expression null).
+  Expr *RowsExpr = getMatrixDimensionExpr(S, Attr, 0);
+  if (!RowsExpr)
+    return;
+
+  Expr *ColsExpr = getMatrixDimensionExpr(S, Attr, 1);
+  if (!ColsExpr)
+    return;
+
+  // Create Matrix Type
+  QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc());
+  if (!T.isNull())
+    CurType = T;
+}
+
static void HandleLifetimeBoundAttr(TypeProcessingState &State,
QualType &CurType,
ParsedAttr &Attr) {
@@ -7793,6 +7951,11 @@
break;
}
+ case ParsedAttr::AT_MatrixType:
+ HandleMatrixTypeAttr(type, attr, state.getSema());
+ attr.setUsedAsTypeAttr();
+ break;
+
MS_TYPE_ATTRS_CASELIST:
if (!handleMSPointerTypeQualifierAttr(state, attr, type))
attr.setUsedAsTypeAttr();
Index: clang/lib/Sema/SemaTemplateDeduction.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateDeduction.cpp
+++ clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2054,6 +2054,89 @@
return Sema::TDK_NonDeducedMismatch;
}
+ // (clang extension)
+ //
+ // T __attribute__((matrix_type(<integral constant>, <integral
+ // constant>)))
+ // TODO: Allow deduction from matrix type to vector type
+ // TODO: Decide on deduction from vector type to matrix type
+ case Type::Matrix: {
+ const MatrixType *MatrixParam = cast<MatrixType>(Param);
+ // Matrix-DepSizedMatrix deduction
+ if (const DependentSizedMatrixType *MatrixArg =
+ dyn_cast<DependentSizedMatrixType>(Arg)) {
+ // can't check number of elements since the argument is dependent
+ return DeduceTemplateArgumentsByTypeMatch(
+ S, TemplateParams, MatrixParam->getElementType(),
+ MatrixArg->getElementType(), Info, Deduced, TDF);
+ }
+ // Matrix-Matrix deduction
+ if (const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg)) {
+ // Check that the dimensions are the same
+ if (MatrixParam->getNumRows() != MatrixArg->getNumRows() ||
+ MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) {
+ return Sema::TDK_NonDeducedMismatch;
+ }
+ // Perform deduction on element types
+ return DeduceTemplateArgumentsByTypeMatch(
+ S, TemplateParams, MatrixParam->getElementType(),
+ MatrixArg->getElementType(), Info, Deduced, TDF);
+ }
+ return Sema::TDK_NonDeducedMismatch;
+ }
+
+ case Type::DependentSizedMatrix: {
+ const DependentSizedMatrixType *MatrixParam =
+ cast<DependentSizedMatrixType>(Param);
+ // DepSizedMatrix - Matrix deduction. Any other argument type (including a
+ // DepSizedMatrix argument) falls through to TDK_NonDeducedMismatch below.
+ if (const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg)) {
+ // Do deduction on the element types
+ if (Sema::TemplateDeductionResult Result =
+ DeduceTemplateArgumentsByTypeMatch(
+ S, TemplateParams, MatrixParam->getElementType(),
+ MatrixArg->getElementType(), Info, Deduced, TDF)) {
+ return Result;
+ }
+
+ // Deduce matrix size if possible
+ NonTypeTemplateParmDecl *RowExprTemplateParam =
+ getDeducedParameterFromExpr(Info, MatrixParam->getRowExpr());
+ NonTypeTemplateParmDecl *ColumnExprTemplateParam =
+ getDeducedParameterFromExpr(Info, MatrixParam->getColumnExpr());
+
+ // TODO: Allow one to fail and the other to succeed in the deduction
+ // Unless both rows and columns are deducible, skip size deduction entirely
+ if (!RowExprTemplateParam || !ColumnExprTemplateParam) {
+ return Sema::TDK_Success;
+ }
+
+ // Sizes are passed as 'int'-typed values; unsigned might make more sense
+ llvm::APSInt ArgRows(S.Context.getTypeSize(S.Context.IntTy));
+ ArgRows = MatrixArg->getNumRows();
+
+ // Deduce Rows
+ {
+ Sema::TemplateDeductionResult Res = DeduceNonTypeTemplateArgument(
+ S, TemplateParams, RowExprTemplateParam, ArgRows, S.Context.IntTy,
+ true, Info, Deduced);
+ if (Res != Sema::TDK_Success) {
+ return Res;
+ }
+ }
+
+ // Deduce Columns
+ llvm::APSInt ArgColumns(S.Context.getTypeSize(S.Context.IntTy));
+ ArgColumns = MatrixArg->getNumColumns();
+
+ // The column deduction result is the result of the whole deduction
+ return DeduceNonTypeTemplateArgument(
+ S, TemplateParams, ColumnExprTemplateParam, ArgColumns,
+ S.Context.IntTy, true, Info, Deduced);
+ }
+ return Sema::TDK_NonDeducedMismatch;
+ }
+
// (clang extension)
//
// T __attribute__(((address_space(N))))
@@ -5695,6 +5778,24 @@
break;
}
+ case Type::Matrix: {
+ const MatrixType *MatType = cast<MatrixType>(T);
+ MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+ Depth, Used);
+ break;
+ }
+
+ case Type::DependentSizedMatrix: {
+ const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T);
+ MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+ Depth, Used);
+ MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth,
+ Used);
+ MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced,
+ Depth, Used);
+ break;
+ }
+
case Type::FunctionProto: {
const FunctionProtoType *Proto = cast<FunctionProtoType>(T);
MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth,
Index: clang/lib/Sema/SemaTemplate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplate.cpp
+++ clang/lib/Sema/SemaTemplate.cpp
@@ -5867,6 +5867,11 @@
return Visit(T->getElementType());
}
+bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType(
+ const DependentSizedMatrixType *T) {
+ return Visit(T->getElementType());
+}
+
bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType(
const DependentAddressSpaceType *T) {
return Visit(T->getPointeeType());
@@ -5885,6 +5890,10 @@
return Visit(T->getElementType());
}
+bool UnnamedLocalNoLinkageFinder::VisitMatrixType(const MatrixType *T) {
+ return Visit(T->getElementType());
+}
+
bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType(
const FunctionProtoType* T) {
for (const auto &A : T->param_types()) {
Index: clang/lib/Sema/SemaLookup.cpp
===================================================================
--- clang/lib/Sema/SemaLookup.cpp
+++ clang/lib/Sema/SemaLookup.cpp
@@ -2966,6 +2966,7 @@
// These are fundamental types.
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
case Type::Complex:
break;
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -4248,6 +4248,7 @@
case Type::Complex:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
case Type::Record:
case Type::Enum:
case Type::Elaborated:
Index: clang/lib/Frontend/CompilerInvocation.cpp
===================================================================
--- clang/lib/Frontend/CompilerInvocation.cpp
+++ clang/lib/Frontend/CompilerInvocation.cpp
@@ -3315,6 +3315,8 @@
Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers);
Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
+ Opts.EnableMatrix = Args.hasArg(OPT_fenable_matrix);
+
Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags);
if (Arg *A = Args.getLastArg(OPT_msign_return_address_EQ)) {
@@ -3570,7 +3572,7 @@
InputArgList Args = Opts.ParseArgs(CommandLineArgs, MissingArgIndex,
MissingArgCount, IncludedFlagsBitmask);
LangOptions &LangOpts = *Res.getLangOpts();
-
+ //
// Check for missing argument error.
if (MissingArgCount) {
Diags.Report(diag::err_drv_missing_argument)
Index: clang/lib/Driver/ToolChains/Clang.cpp
===================================================================
--- clang/lib/Driver/ToolChains/Clang.cpp
+++ clang/lib/Driver/ToolChains/Clang.cpp
@@ -4553,6 +4553,13 @@
if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false))
CmdArgs.push_back("-fdefault-calling-conv=stdcall");
+ if (Args.hasArg(options::OPT_fenable_matrix)) {
+ // enable-matrix is needed by both the LangOpts and by LLVM.
+ CmdArgs.push_back("-fenable-matrix");
+ CmdArgs.push_back("-mllvm");
+ CmdArgs.push_back("-enable-matrix");
+ }
+
CodeGenOptions::FramePointerKind FPKeepKind =
getFramePointerKind(Args, RawTriple);
const char *FPKeepKindStr = nullptr;
Index: clang/lib/CodeGen/ItaniumCXXABI.cpp
===================================================================
--- clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3222,6 +3222,7 @@
// GCC treats vector and complex types as fundamental types.
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
case Type::Complex:
case Type::Atomic:
// FIXME: GCC treats block pointers as fundamental types?!
@@ -3457,6 +3458,7 @@
case Type::Builtin:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
case Type::Complex:
case Type::BlockPointer:
// Itanium C++ ABI 2.9.5p4:
Index: clang/lib/CodeGen/CodeGenTypes.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenTypes.cpp
+++ clang/lib/CodeGen/CodeGenTypes.cpp
@@ -84,6 +84,13 @@
/// a type. For example, the scalar representation for _Bool is i1, but the
/// memory representation is usually i8 or i32, depending on the target.
llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
+ if (T->isMatrixType()) {
+ const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+ const MatrixType *MT = cast<MatrixType>(Ty);
+ return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+ MT->getNumRows() * MT->getNumColumns());
+ }
+
llvm::Type *R = ConvertType(T);
// If this is a non-bool type, don't map it.
@@ -648,6 +655,12 @@
VT->getNumElements());
break;
}
+ case Type::Matrix: {
+ const MatrixType *MT = cast<MatrixType>(Ty);
+ ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()),
+ MT->getNumRows() * MT->getNumColumns());
+ break;
+ }
case Type::FunctionNoProto:
case Type::FunctionProto:
ResultType = ConvertFunctionTypeInternal(T);
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -268,6 +268,7 @@
case Type::MemberPointer:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Enum:
@@ -2018,6 +2019,7 @@
case Type::Complex:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
case Type::Record:
case Type::Enum:
case Type::Elaborated:
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -145,8 +145,19 @@
Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align,
const Twine &Name, Address *Alloca) {
- return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
- /*ArraySize=*/nullptr, Alloca);
+ Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
+ /*ArraySize=*/nullptr, Alloca);
+
+ if (Ty->isMatrixType()) {
+ auto *ArrayTy = cast<llvm::ArrayType>(Result.getType()->getElementType());
+ auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+ ArrayTy->getNumElements());
+
+ Result = Address(
+ Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()),
+ Result.getAlignment());
+ }
+ return Result;
}
Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align,
@@ -1759,6 +1770,20 @@
}
}
+ if (Ty->isMatrixType()) {
+ auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+ cast<llvm::PointerType>(Addr.getPointer()->getType())
+ ->getElementType());
+ if (ArrayTy) {
+ auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+ ArrayTy->getNumElements());
+
+ Addr = Address(
+ Builder.CreateBitCast(Addr.getPointer(), VectorTy->getPointerTo()),
+ Addr.getAlignment());
+ }
+ }
+
Value = EmitToMemory(Value, Ty);
LValue AtomicLValue =
@@ -1812,6 +1837,20 @@
if (LV.isSimple()) {
assert(!LV.getType()->isFunctionType());
+ if (LV.getType()->isMatrixType()) {
+ auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+ cast<llvm::PointerType>(LV.getPointer(*this)->getType())
+ ->getElementType());
+ if (ArrayTy) {
+ auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+ ArrayTy->getNumElements());
+
+ LV.setAddress(Address(Builder.CreateBitCast(LV.getPointer(*this),
+ VectorTy->getPointerTo()),
+ LV.getAlignment()));
+ }
+ }
+
// Everything needs a load.
return RValue::get(EmitLoadOfScalar(LV, Loc));
}
Index: clang/lib/CodeGen/CGDebugInfo.h
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.h
+++ clang/lib/CodeGen/CGDebugInfo.h
@@ -190,6 +190,7 @@
llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit);
llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F);
+ llvm::DIType *CreateType(const MatrixType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit);
Index: clang/lib/CodeGen/CGDebugInfo.cpp
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.cpp
+++ clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2713,6 +2713,23 @@
return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray);
}
+llvm::DIType *CGDebugInfo::CreateType(const MatrixType *Ty,
+ llvm::DIFile *Unit) {
+ llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit);
+ uint64_t Size = CGM.getContext().getTypeSize(Ty);
+ uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext());
+
+ // Number of Columns, followed by rows
+ llvm::SmallVector<llvm::Metadata *, 2> Subscripts;
+ Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns()));
+ Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows()));
+ llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts);
+
+ // FIXME: Create another debug type for matrices
+ // For the time being, it treats it like a 2D array
+ return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray);
+}
+
llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) {
uint64_t Size;
uint32_t Align;
@@ -3106,6 +3123,8 @@
case Type::ExtVector:
case Type::Vector:
return CreateType(cast<VectorType>(Ty), Unit);
+ case Type::Matrix:
+ return CreateType(cast<MatrixType>(Ty), Unit);
case Type::ObjCObjectPointer:
return CreateType(cast<ObjCObjectPointerType>(Ty), Unit);
case Type::ObjCObject:
Index: clang/lib/Basic/Targets/OSTargets.cpp
===================================================================
--- clang/lib/Basic/Targets/OSTargets.cpp
+++ clang/lib/Basic/Targets/OSTargets.cpp
@@ -133,6 +133,9 @@
Builder.defineMacro("__MACH__");
PlatformMinVersion = VersionTuple(Maj, Min, Rev);
+
+ if (Opts.EnableMatrix)
+ Builder.defineMacro("__MATRIX_EXTENSION__", "1");
}
static void addMinGWDefines(const llvm::Triple &Triple, const LangOptions &Opts,
Index: clang/lib/AST/TypePrinter.cpp
===================================================================
--- clang/lib/AST/TypePrinter.cpp
+++ clang/lib/AST/TypePrinter.cpp
@@ -254,6 +254,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
+ case Type::DependentSizedMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Paren:
@@ -718,6 +720,37 @@
OS << ")))";
}
+void TypePrinter::printMatrixBefore(const MatrixType *T, raw_ostream &OS) {
+ // TODO: Fix the spacing between the element type and the __attribute__
+ printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ OS << T->getNumRows() << ", " << T->getNumColumns();
+ OS << ")))";
+}
+
+void TypePrinter::printMatrixAfter(const MatrixType *T, raw_ostream &OS) {
+ printAfter(T->getElementType(), OS);
+}
+
+void TypePrinter::printDependentSizedMatrixBefore(
+ const DependentSizedMatrixType *T, raw_ostream &OS) {
+ printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ if (T->getRowExpr()) {
+ T->getRowExpr()->printPretty(OS, nullptr, Policy);
+ }
+ OS << ", ";
+ if (T->getColumnExpr()) {
+ T->getColumnExpr()->printPretty(OS, nullptr, Policy);
+ }
+ OS << ")))";
+}
+
+void TypePrinter::printDependentSizedMatrixAfter(
+ const DependentSizedMatrixType *T, raw_ostream &OS) {
+ printAfter(T->getElementType(), OS);
+}
+
void
FunctionProtoType::printExceptionSpecification(raw_ostream &OS,
const PrintingPolicy &Policy)
Index: clang/lib/AST/Type.cpp
===================================================================
--- clang/lib/AST/Type.cpp
+++ clang/lib/AST/Type.cpp
@@ -282,6 +282,45 @@
AddrSpaceExpr->Profile(ID, Context, true);
}
+MatrixType::MatrixType(QualType matrixType, unsigned nRows, unsigned nColumns,
+ QualType canonType)
+ : MatrixType(Matrix, matrixType, nRows, nColumns, canonType) {}
+
+MatrixType::MatrixType(TypeClass tc, QualType matrixType, unsigned nRows,
+ unsigned nColumns, QualType canonType)
+ : Type(tc, canonType, matrixType->getDependence()),
+ ElementType(matrixType) {
+ MatrixTypeBits.NumRows = nRows;
+ MatrixTypeBits.NumColumns = nColumns;
+}
+
+DependentSizedMatrixType::DependentSizedMatrixType(
+ const ASTContext &CTX, QualType ElementType, QualType CanonicalType,
+ Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc)
+ : Type(DependentSizedMatrix, CanonicalType,
+ TypeDependence::Dependent | TypeDependence::Instantiation |
+ (ElementType->isVariablyModifiedType()
+ ? TypeDependence::VariablyModified
+ : TypeDependence::None) |
+ (ElementType->containsUnexpandedParameterPack() ||
+ (RowExpr &&
+ RowExpr->containsUnexpandedParameterPack()) ||
+ (ColumnExpr &&
+ ColumnExpr->containsUnexpandedParameterPack())
+ ? TypeDependence::UnexpandedPack
+ : TypeDependence::None)),
+ Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr),
+ ElementType(ElementType), loc(loc) {}
+
+void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID,
+ const ASTContext &CTX,
+ QualType ElementType, Expr *RowExpr,
+ Expr *ColumnExpr) {
+ ID.AddPointer(ElementType.getAsOpaquePtr());
+ RowExpr->Profile(ID, CTX, true);
+ ColumnExpr->Profile(ID, CTX, true);
+}
+
VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType,
VectorKind vecKind)
: VectorType(Vector, vecType, nElements, canonType, vecKind) {}
@@ -938,6 +977,16 @@
return Ctx.getExtVectorType(elementType, T->getNumElements());
}
+  QualType VisitMatrixType(const MatrixType *T) {
+    // Transform the element type; a null result propagates as a null type.
+    QualType NewElt = recurse(T->getElementType());
+    if (NewElt.isNull())
+      return {};
+    // Rebuild only when the element type changed; otherwise reuse T as-is.
+    if (NewElt.getAsOpaquePtr() != T->getElementType().getAsOpaquePtr())
+      return Ctx.getMatrixType(NewElt, T->getNumRows(), T->getNumColumns());
+    return QualType(T, 0);
+  }
+
QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
QualType returnType = recurse(T->getReturnType());
if (returnType.isNull())
@@ -1757,6 +1806,14 @@
return Visit(T->getElementType());
}
+ Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) {
+ return Visit(T->getElementType());
+ }
+
+ Type *VisitMatrixType(const MatrixType *T) {
+ return Visit(T->getElementType());
+ }
+
Type *VisitFunctionProtoType(const FunctionProtoType *T) {
if (Syntactic && T->hasTrailingReturn())
return const_cast<FunctionProtoType*>(T);
@@ -3688,6 +3745,8 @@
case Type::Vector:
case Type::ExtVector:
return Cache::get(cast<VectorType>(T)->getElementType());
+ case Type::Matrix:
+ return Cache::get(cast<MatrixType>(T)->getElementType());
case Type::FunctionNoProto:
return Cache::get(cast<FunctionType>(T)->getReturnType());
case Type::FunctionProto: {
@@ -3773,6 +3832,8 @@
case Type::Vector:
case Type::ExtVector:
return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType());
+ case Type::Matrix:
+ return computeTypeLinkageInfo(cast<MatrixType>(T)->getElementType());
case Type::FunctionNoProto:
return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType());
case Type::FunctionProto: {
@@ -3936,6 +3997,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
+ case Type::DependentSizedMatrix:
case Type::DependentAddressSpace:
case Type::FunctionProto:
case Type::FunctionNoProto:
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -2755,6 +2755,23 @@
<< Range;
}
+void MicrosoftCXXNameMangler::mangleType(const MatrixType *T, Qualifiers quals,
+ SourceRange Range) {
+ DiagnosticsEngine &Diags = Context.getDiags();
+ unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
+ "Cannot mangle this matrix type yet");
+ Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
+void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T,
+ Qualifiers quals, SourceRange Range) {
+ DiagnosticsEngine &Diags = Context.getDiags();
+ unsigned DiagID = Diags.getCustomDiagID(
+ DiagnosticsEngine::Error,
+ "Cannot mangle this dependent-sized matrix type yet");
+ Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T,
Qualifiers, SourceRange Range) {
DiagnosticsEngine &Diags = Context.getDiags();
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -2065,6 +2065,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
+ case Type::DependentSizedMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Paren:
@@ -3327,6 +3329,20 @@
mangleType(T->getElementType());
}
+void CXXNameMangler::mangleType(const MatrixType *T) {
+ Out << "Dm" << T->getNumRows() << "_" << T->getNumColumns() << '_';
+ mangleType(T->getElementType());
+}
+
+void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) {
+ Out << "Dm";
+ mangleExpression(T->getRowExpr());
+ Out << '_';
+ mangleExpression(T->getColumnExpr());
+ Out << '_';
+ mangleType(T->getElementType());
+}
+
void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) {
SplitQualType split = T->getPointeeType().split();
mangleQualifiers(split.Quals, T);
Index: clang/lib/AST/ExprConstant.cpp
===================================================================
--- clang/lib/AST/ExprConstant.cpp
+++ clang/lib/AST/ExprConstant.cpp
@@ -10307,6 +10307,7 @@
case Type::BlockPointer:
case Type::Vector:
case Type::ExtVector:
+ case Type::Matrix:
case Type::ObjCObject:
case Type::ObjCInterface:
case Type::ObjCObjectPointer:
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,39 @@
break;
}
+ case Type::DependentSizedMatrix: {
+ const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1);
+ const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2);
+ // Rows
+ if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(),
+ Mat2->getRowExpr())) {
+ return false;
+ }
+ // Columns
+ if (!IsStructurallyEquivalent(Context, Mat1->getColumnExpr(),
+ Mat2->getColumnExpr())) {
+ return false;
+ }
+ // Element Type
+ if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+ Mat2->getElementType())) {
+ return false;
+ }
+ return true;
+ }
+
+ case Type::Matrix: {
+ const MatrixType *Mat1 = cast<MatrixType>(T1);
+ const MatrixType *Mat2 = cast<MatrixType>(T2);
+ if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+ Mat2->getElementType()))
+ return false;
+ if (Mat1->getNumRows() != Mat2->getNumRows() ||
+ Mat1->getNumColumns() != Mat2->getNumColumns())
+ return false;
+ break;
+ }
+
case Type::FunctionProto: {
const auto *Proto1 = cast<FunctionProtoType>(T1);
const auto *Proto2 = cast<FunctionProtoType>(T2);
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -1929,6 +1929,18 @@
break;
}
+  case Type::Matrix: {
+    const auto *MT = cast<MatrixType>(T);
+    TypeInfo ElementInfo = getTypeInfo(MT->getElementType());
+    // The matrix type is intended to be ABI compatible with arrays with respect
+    // to alignment and size. We use LLVM's array type for storage.
+    Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns();
+    // Like an array, a matrix is aligned like its element type. Use the
+    // element's alignment, not its width: the two can differ on some targets.
+    Align = ElementInfo.Align;
+    break;
+  }
+
case Type::Builtin:
switch (cast<BuiltinType>(T)->getKind()) {
default: llvm_unreachable("Unknown builtin type!");
@@ -3350,6 +3362,8 @@
case Type::DependentVector:
case Type::ExtVector:
case Type::DependentSizedExtVector:
+ case Type::Matrix:
+ case Type::DependentSizedMatrix:
case Type::DependentAddressSpace:
case Type::ObjCObject:
case Type::ObjCInterface:
@@ -3761,6 +3775,76 @@
return QualType(New, 0);
}
+/// getMatrixType - Return the unique reference to a matrix type of the
+/// specified element type and size (NumRows x NumColumns). ElementTy must be
+/// a built-in integer or floating point type.
+QualType ASTContext::getMatrixType(QualType ElementTy, unsigned NumRows,
+ unsigned NumColumns) const {
+ llvm::FoldingSetNodeID ID;
+ MatrixType::Profile(ID, ElementTy, NumRows, NumColumns, Type::Matrix);
+ // Return the existing node if this exact type was already created.
+ void *InsertPos = nullptr;
+ if (MatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos)) {
+ return QualType(MTP, 0);
+ }
+ // A non-canonical element type needs the canonical matrix type built first.
+ QualType Canonical;
+ if (!ElementTy.isCanonical()) {
+ Canonical = getMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns);
+ // Look up again: presumably the recursive call may invalidate InsertPos.
+ MatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+ assert(!NewIP && "Matrix type shouldn't already exist in the map");
+ (void)NewIP;
+ }
+ // Canonical is left default-constructed when ElementTy is already canonical.
+ auto *New = new (*this, TypeAlignment)
+ MatrixType(ElementTy, NumRows, NumColumns, Canonical);
+ MatrixTypes.InsertNode(New, InsertPos);
+ Types.push_back(New);
+ return QualType(New, 0);
+}
+
+/// getDependentSizedMatrixType - Return a dependent-sized matrix type for the
+/// given row/column expressions. MatrixElementType must be a builtin type.
+QualType ASTContext::getDependentSizedMatrixType(QualType MatrixElementType,
+ Expr *RowExpr,
+ Expr *ColumnExpr,
+ SourceLocation AttrLoc) const {
+ llvm::FoldingSetNodeID ID;
+ DependentSizedMatrixType::Profile(
+ ID, *this, getCanonicalType(MatrixElementType), RowExpr, ColumnExpr);
+ // ID was profiled with the canonical element type; look up that node.
+ void *InsertPos = nullptr;
+ DependentSizedMatrixType *Canon =
+ DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+ DependentSizedMatrixType *New;
+ if (Canon) {
+ // Already have a canonical version of the matrix type.
+ // Use it as the canonical type for the newly-built type.
+ New = new (*this, TypeAlignment)
+ DependentSizedMatrixType(*this, MatrixElementType, QualType(Canon, 0),
+ RowExpr, ColumnExpr, AttrLoc);
+ } else {
+ QualType CanonicalMatrixElementType = getCanonicalType(MatrixElementType);
+ if (CanonicalMatrixElementType == MatrixElementType) {
+ New = new (*this, TypeAlignment) DependentSizedMatrixType(
+ *this, MatrixElementType, QualType(), RowExpr, ColumnExpr, AttrLoc);
+ DependentSizedMatrixType *CanonCheck =
+ DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+ assert(!CanonCheck && "Dependent-sized matrix canonical type broken");
+ (void)CanonCheck;
+ DependentSizedMatrixTypes.InsertNode(New, InsertPos);
+ } else {
+ QualType Canon = getDependentSizedMatrixType(
+ CanonicalMatrixElementType, RowExpr, ColumnExpr, SourceLocation());
+ New = new (*this, TypeAlignment) DependentSizedMatrixType(
+ *this, MatrixElementType, Canon, RowExpr, ColumnExpr, AttrLoc);
+ }
+ }
+ Types.push_back(New);
+ return QualType(New, 0);
+}
+
QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType,
Expr *AddrSpaceExpr,
SourceLocation AttrLoc) const {
@@ -7275,6 +7359,11 @@
*NotEncodedT = T;
return;
+ case Type::Matrix:
+ if (NotEncodedT)
+ *NotEncodedT = T;
+ return;
+
// We could see an undeduced auto type here during error recovery.
// Just ignore it.
case Type::Auto:
@@ -8153,6 +8242,15 @@
LHS->getNumElements() == RHS->getNumElements();
}
+/// areCompatMatrixTypes - Return true if the two specified matrix types are
+/// compatible: same element type, number of rows and number of columns.
+static bool areCompatMatrixTypes(const MatrixType *LHS, const MatrixType *RHS) {
+ assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified());
+ return LHS->getElementType() == RHS->getElementType() &&
+ LHS->getNumRows() == RHS->getNumRows() &&
+ LHS->getNumColumns() == RHS->getNumColumns();
+}
+
bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
QualType SecondVec) {
assert(FirstVec->isVectorType() && "FirstVec should be a vector type");
@@ -9349,6 +9447,11 @@
RHSCan->castAs<VectorType>()))
return LHS;
return {};
+ case Type::Matrix:
+ if (areCompatMatrixTypes(LHSCan->castAs<MatrixType>(),
+ RHSCan->castAs<MatrixType>()))
+ return LHS;
+ return {};
case Type::ObjCObject: {
// Check if the types are assignment compatible.
// FIXME: This should be type compatibility, e.g. whether
Index: clang/include/clang/Serialization/TypeBitCodes.def
===================================================================
--- clang/include/clang/Serialization/TypeBitCodes.def
+++ clang/include/clang/Serialization/TypeBitCodes.def
@@ -58,5 +58,7 @@
TYPE_BIT_CODE(DependentAddressSpace, DEPENDENT_ADDRESS_SPACE, 47)
TYPE_BIT_CODE(DependentVector, DEPENDENT_SIZED_VECTOR, 48)
TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49)
+TYPE_BIT_CODE(Matrix, MATRIX, 50)
+TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 51)
#undef TYPE_BIT_CODE
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1626,6 +1626,9 @@
QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc);
QualType BuildExtVectorType(QualType T, Expr *ArraySize,
SourceLocation AttrLoc);
+ QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns,
+ SourceLocation AttrLoc);
+
QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace,
SourceLocation AttrLoc);
Index: clang/include/clang/Driver/Options.td
===================================================================
--- clang/include/clang/Driver/Options.td
+++ clang/include/clang/Driver/Options.td
@@ -1992,6 +1992,10 @@
def fno_strict_return : Flag<["-"], "fno-strict-return">, Group<f_Group>,
Flags<[CC1Option]>;
+def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
+ Flags<[CC1Option]>,
+ HelpText<"Enable matrix data type and related builtin functions">;
+
def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">,
Group<f_Group>, Flags<[CC1Option]>,
HelpText<"Treat editor placeholders as valid source code">;
Index: clang/include/clang/Basic/TypeNodes.td
===================================================================
--- clang/include/clang/Basic/TypeNodes.td
+++ clang/include/clang/Basic/TypeNodes.td
@@ -65,10 +65,12 @@
def VariableArrayType : TypeNode<ArrayType>;
def DependentSizedArrayType : TypeNode<ArrayType>, AlwaysDependent;
def DependentSizedExtVectorType : TypeNode<Type>, AlwaysDependent;
+def DependentSizedMatrixType : TypeNode<Type>, AlwaysDependent;
def DependentAddressSpaceType : TypeNode<Type>, AlwaysDependent;
def VectorType : TypeNode<Type>;
def DependentVectorType : TypeNode<Type>, AlwaysDependent;
def ExtVectorType : TypeNode<VectorType>;
+def MatrixType : TypeNode<Type>;
def FunctionType : TypeNode<Type, 1>;
def FunctionProtoType : TypeNode<FunctionType>;
def FunctionNoProtoType : TypeNode<FunctionType>;
Index: clang/include/clang/Basic/LangOptions.def
===================================================================
--- clang/include/clang/Basic/LangOptions.def
+++ clang/include/clang/Basic/LangOptions.def
@@ -351,6 +351,8 @@
LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors")
+LANGOPT(EnableMatrix, 1, 0, "Enable or disable the builtin matrix type")
+
COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None,
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2764,6 +2764,7 @@
def err_attribute_too_few_arguments : Error<
"%0 attribute takes at least %1 argument%s1">;
def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
+def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
def err_attribute_bad_neon_vector_size : Error<
"Neon vector size must be 64 or 128 bits">;
def err_attribute_requires_positive_integer : Error<
@@ -2867,8 +2868,8 @@
"init methods must return an object pointer type, not %0">;
def err_attribute_invalid_size : Error<
"vector size not an integral multiple of component size">;
-def err_attribute_zero_size : Error<"zero vector size">;
-def err_attribute_size_too_large : Error<"vector size too large. Size is %0 when the maximum allowed is %1">;
+def err_attribute_zero_size : Error<"zero %0 size">;
+def err_attribute_size_too_large : Error<"%0 size too large. Size is %1 when the maximum allowed is %2">;
def err_typecheck_vector_not_convertable_implict_truncation : Error<
"cannot convert between %select{scalar|vector}0 type %1 and vector type"
" %2 as implicit conversion would cause truncation">;
@@ -10685,6 +10686,9 @@
"%select{non-pointer|function pointer|void pointer}0 argument to "
"'__builtin_launder' is not allowed">;
+def err_builtin_matrix_disabled: Error<
+ "Builtin matrix support is disabled. Pass -fenable-matrix to enable it.">;
+
def err_preserve_field_info_not_field : Error<
"__builtin_preserve_field_info argument %0 not a field access">;
def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -2468,6 +2468,15 @@
let Documentation = [Undocumented];
}
+def MatrixType : TypeAttr {
+ let Spellings = [Clang<"matrix_type">];
+ let Subjects = SubjectList<[TypedefName], ErrorDiag>;
+ let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">];
+ let Documentation = [Undocumented];
+ let ASTNode = 0;
+ let PragmaAttributeSupport = 0;
+}
+
def Visibility : InheritableAttr {
let Clone = 0;
let Spellings = [GCC<"visibility">];
Index: clang/include/clang/AST/TypeProperties.td
===================================================================
--- clang/include/clang/AST/TypeProperties.td
+++ clang/include/clang/AST/TypeProperties.td
@@ -224,6 +224,41 @@
}]>;
}
+let Class = MatrixType in {
+ def : Property<"elementType", QualType> {
+ let Read = [{ node->getElementType() }];
+ }
+ def : Property<"numRows", UInt32> {
+ let Read = [{ node->getNumRows() }];
+ }
+ def : Property<"numColumns", UInt32> {
+ let Read = [{ node->getNumColumns() }];
+ }
+
+ def : Creator<[{
+ return ctx.getMatrixType(elementType, numRows, numColumns);
+ }]>;
+}
+
+let Class = DependentSizedMatrixType in {
+ def : Property<"elementType", QualType> {
+ let Read = [{ node->getElementType() }];
+ }
+ def : Property<"rows", ExprRef> {
+ let Read = [{ node->getRowExpr() }];
+ }
+ def : Property<"columns", ExprRef> {
+ let Read = [{ node->getColumnExpr() }];
+ }
+ def : Property<"attributeLoc", SourceLocation> {
+ let Read = [{ node->getAttributeLoc() }];
+ }
+
+ def : Creator<[{
+ return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc);
+ }]>;
+}
+
let Class = FunctionType in {
def : Property<"returnType", QualType> {
let Read = [{ node->getReturnType() }];
Index: clang/include/clang/AST/TypeLoc.h
===================================================================
--- clang/include/clang/AST/TypeLoc.h
+++ clang/include/clang/AST/TypeLoc.h
@@ -1774,6 +1774,18 @@
DependentSizedExtVectorType> {
};
+// Same as VectorType: FIXME: attribute locations.
+class MatrixTypeLoc
+ : public InheritingConcreteTypeLoc<TypeSpecTypeLoc, MatrixTypeLoc,
+ MatrixType> {};
+
+// Same as VectorType: FIXME: attribute locations. Also look into making this
+// a subtype of the MatrixTypeLoc
+class DependentSizedMatrixTypeLoc
+ : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
+ DependentSizedMatrixTypeLoc,
+ DependentSizedMatrixType> {};
+
// FIXME: location of the '_Complex' keyword.
class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
ComplexTypeLoc,
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -1657,6 +1657,19 @@
enum { MaxNumElements = (1 << (29 - NumTypeBits)) - 1 };
};
+ class MatrixTypeBitfields {
+ friend class MatrixType;
+
+ unsigned : NumTypeBits;
+
+ // Number of rows and columns. Using 20 bits each allows supporting very
+ // large matrices, while keeping 24 bits to accommodate NumTypeBits.
+ unsigned NumRows : 20;
+ unsigned NumColumns : 20;
+
+ enum { MaxElementsPerDimension = (1 << 20) - 1 };
+ };
+
class AttributedTypeBitfields {
friend class AttributedType;
@@ -1766,6 +1779,7 @@
TypeWithKeywordBitfields TypeWithKeywordBits;
ElaboratedTypeBitfields ElaboratedTypeBits;
VectorTypeBitfields VectorTypeBits;
+ MatrixTypeBitfields MatrixTypeBits;
SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
DependentTemplateSpecializationTypeBitfields
@@ -2024,6 +2038,7 @@
bool isComplexIntegerType() const; // GCC _Complex integer type.
bool isVectorType() const; // GCC vector type.
bool isExtVectorType() const; // Extended vector type.
+ bool isMatrixType() const;
bool isDependentAddressSpaceType() const; // value-dependent address space qualifier
bool isObjCObjectPointerType() const; // pointer to ObjC object
bool isObjCRetainableType() const; // ObjC object or block pointer
@@ -3400,6 +3415,117 @@
}
};
+/// MatrixType - This type is created using
+/// __attribute__((matrix_type(rows, columns))), where "rows" is the
+/// number of rows and "columns" is the number of columns.
+class MatrixType : public Type, public llvm::FoldingSetNode {
+protected:
+ friend class ASTContext;
+
+ QualType ElementType;
+
+ // MatrixElementType: The type of the elements in the matrix
+ // NRows: Number of rows
+ // NColumns: Number of columns
+ // CanonElementType: Canonical element type (if the matrix type is not
+ // canonical)
+ MatrixType(QualType MatrixElementType, unsigned NRows, unsigned NColumns,
+ QualType CanonElementType);
+
+ // typeClass: The typeclass (defined in TypeNodes.td)
+ // MatrixElementType: The type of the elements in the matrix
+ // NRows: The number of rows
+ // NColumns: The number of columns
+ // CanonElementType: Canonical element type (if the matrix type is not canonical)
+ MatrixType(TypeClass typeClass, QualType MatrixElementType, unsigned NRows,
+ unsigned NColumns, QualType CanonElementType);
+
+public:
+ // The type of the elements being stored in the matrix
+ QualType getElementType() const { return ElementType; }
+
+ // The number of rows in the matrix
+ unsigned getNumRows() const { return MatrixTypeBits.NumRows; }
+
+ // The number of columns in the matrix
+ unsigned getNumColumns() const { return MatrixTypeBits.NumColumns; }
+
+ unsigned getNumElementsFlattened() const {
+ return MatrixTypeBits.NumRows * MatrixTypeBits.NumColumns;
+ }
+
+ // Check whether a single dimension (row or column count) exceeds what the
+ // bit-field storage can hold
+ static bool isDimensionTooLarge(unsigned NumElements) {
+ return NumElements > MatrixTypeBitfields::MaxElementsPerDimension;
+ }
+
+ static unsigned getMaxElementsPerDimension() {
+ return MatrixTypeBitfields::MaxElementsPerDimension;
+ }
+
+ bool isSugared() const { return false; }
+ QualType desugar() const { return QualType(this, 0); }
+
+ void Profile(llvm::FoldingSetNodeID &ID) {
+ Profile(ID, getElementType(), getNumRows(), getNumColumns(),
+ getTypeClass());
+ }
+
+ static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType,
+ unsigned NumRows, unsigned NumColumns,
+ TypeClass TypeClass) {
+ ID.AddPointer(ElementType.getAsOpaquePtr());
+ ID.AddInteger(NumRows);
+ ID.AddInteger(NumColumns);
+ ID.AddInteger(TypeClass);
+ }
+
+ // Only Matrix: DependentSizedMatrixType is not derived from MatrixType.
+ static bool classof(const Type *T) {
+ return T->getTypeClass() == Matrix;
+ }
+};
+
+/// DependentSizedMatrixType - Represents a matrix type where the type
+/// and size is dependent on a template.
+///
+class DependentSizedMatrixType : public Type, public llvm::FoldingSetNode {
+ friend class ASTContext;
+
+ const ASTContext &Context;
+ Expr *RowExpr;
+ Expr *ColumnExpr;
+
+ /// The element type of the matrix
+ QualType ElementType;
+
+ SourceLocation loc;
+
+ DependentSizedMatrixType(const ASTContext &Context, QualType ElementType,
+ QualType CanonicalType, Expr *RowExpr,
+ Expr *ColumnExpr, SourceLocation loc);
+
+public:
+ QualType getElementType() const { return ElementType; }
+ Expr *getRowExpr() const { return RowExpr; }
+ Expr *getColumnExpr() const { return ColumnExpr; }
+ SourceLocation getAttributeLoc() const { return loc; }
+
+ bool isSugared() const { return false; }
+ QualType desugar() const { return QualType(this, 0); }
+
+ static bool classof(const Type *T) {
+ return T->getTypeClass() == DependentSizedMatrix;
+ }
+
+ void Profile(llvm::FoldingSetNodeID &ID) {
+ Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr());
+ }
+
+ static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
+ QualType ElementType, Expr *RowExpr, Expr *ColumnExpr);
+};
+
/// FunctionType - C99 6.7.5.3 - Function Declarators. This is the common base
/// class of FunctionNoProtoType and FunctionProtoType.
class FunctionType : public Type {
@@ -6557,6 +6683,10 @@
return isa<ExtVectorType>(CanonicalType);
}
+inline bool Type::isMatrixType() const {
+ return isa<MatrixType>(CanonicalType);
+}
+
inline bool Type::isDependentAddressSpaceType() const {
return isa<DependentAddressSpaceType>(CanonicalType);
}
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1006,6 +1006,16 @@
DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); })
+DEF_TRAVERSE_TYPE(MatrixType, { TRY_TO(TraverseType(T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(DependentSizedMatrixType, {
+ if (T->getRowExpr())
+ TRY_TO(TraverseStmt(T->getRowExpr()));
+ if (T->getColumnExpr())
+ TRY_TO(TraverseStmt(T->getColumnExpr()));
+ TRY_TO(TraverseType(T->getElementType()));
+})
+
DEF_TRAVERSE_TYPE(FunctionNoProtoType,
{ TRY_TO(TraverseType(T->getReturnType())); })
@@ -1254,6 +1264,21 @@
TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
})
+// Same as VectorType: FIXME: MatrixTypeLoc is unfinished
+DEF_TRAVERSE_TYPELOC(MatrixType, {
+ TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
+DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, {
+ if (TL.getTypePtr()->getRowExpr()) {
+ TRY_TO(TraverseStmt(TL.getTypePtr()->getRowExpr()));
+ }
+ if (TL.getTypePtr()->getColumnExpr()) {
+ TRY_TO(TraverseStmt(TL.getTypePtr()->getColumnExpr()));
+ }
+ TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
DEF_TRAVERSE_TYPELOC(FunctionNoProtoType,
{ TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); })
Index: clang/include/clang/AST/ASTContext.h
===================================================================
--- clang/include/clang/AST/ASTContext.h
+++ clang/include/clang/AST/ASTContext.h
@@ -193,6 +193,8 @@
DependentAddressSpaceTypes;
mutable llvm::FoldingSet<VectorType> VectorTypes;
mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes;
+ mutable llvm::FoldingSet<MatrixType> MatrixTypes;
+ mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes;
mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes;
mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&>
FunctionProtoTypes;
@@ -1309,6 +1311,21 @@
Expr *SizeExpr,
SourceLocation AttrLoc) const;
+ /// Return the unique reference to the matrix type with the specified element
+ /// type (passed as \p MatrixType) and \p NumRows x \p NumColumns dimensions.
+ ///
+ /// \pre \p MatrixType must be a built-in type.
+ QualType getMatrixType(QualType MatrixType, unsigned NumRows,
+ unsigned NumColumns) const;
+
+ /// Return the unique reference to the matrix type of the specified element
+ /// type whose row and column counts are given by the expressions
+ /// \p RowExpr and \p ColumnExpr; \p AttrLoc is the attribute's location.
+ /// \pre \p MatrixElementType must be a built-in type.
+ QualType getDependentSizedMatrixType(QualType MatrixElementType,
+ Expr *RowExpr, Expr *ColumnExpr,
+ SourceLocation AttrLoc) const;
+
QualType getDependentAddressSpaceType(QualType PointeeType,
Expr *AddrSpaceExpr,
SourceLocation AttrLoc) const;
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits