fhahn created this revision.
Herald added a reviewer: martong.
Herald added subscribers: tschuett, arphaman.
Herald added a project: clang.

This patch adds a matrix type to Clang as described in
"Matrix Support in Clang" on cfe-dev [1]. The patch is not intended for
review yet; it is only meant to give an idea of what the implementation
would look like.

One aspect in particular I would appreciate feedback on is how to best
ensure matrix type values are aligned the same as pointers to the
element type, while using LLVM's vector type to lower operations.

The main problem is struct layout, where LLVM's vector type has a
larger alignment than desired.

To work around that fact, the patch uses array types as storage types for
matrix values, but vector types in other contexts. After loading/before
storing, we bitcast between array type and vector type. Alternatively
we could opt for generating packed LLVM structs.

The builtins will be added in separate, follow-on patches.

[1] http://lists.llvm.org/pipermail/cfe-dev/2019-December/064141.html


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D72281

Files:
  clang/include/clang/AST/ASTContext.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/Type.h
  clang/include/clang/AST/TypeLoc.h
  clang/include/clang/AST/TypeProperties.td
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Basic/LangOptions.def
  clang/include/clang/Basic/TypeNodes.td
  clang/include/clang/Driver/Options.td
  clang/include/clang/Sema/Sema.h
  clang/include/clang/Serialization/TypeBitCodes.def
  clang/lib/AST/ASTContext.cpp
  clang/lib/AST/ASTStructuralEquivalence.cpp
  clang/lib/AST/ExprConstant.cpp
  clang/lib/AST/ItaniumMangle.cpp
  clang/lib/AST/MicrosoftMangle.cpp
  clang/lib/AST/Type.cpp
  clang/lib/AST/TypePrinter.cpp
  clang/lib/CodeGen/CGDebugInfo.cpp
  clang/lib/CodeGen/CGDebugInfo.h
  clang/lib/CodeGen/CGExpr.cpp
  clang/lib/CodeGen/CodeGenFunction.cpp
  clang/lib/CodeGen/CodeGenTypes.cpp
  clang/lib/CodeGen/ItaniumCXXABI.cpp
  clang/lib/Driver/ToolChains/Clang.cpp
  clang/lib/Frontend/CompilerInvocation.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaLookup.cpp
  clang/lib/Sema/SemaTemplate.cpp
  clang/lib/Sema/SemaTemplateDeduction.cpp
  clang/lib/Sema/SemaType.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/CodeGen/matrix-type.c
  clang/test/CodeGenCXX/matrix-type.cpp
  clang/test/SemaCXX/matrix-type.cpp
  clang/tools/libclang/CIndex.cpp

Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -1792,6 +1792,8 @@
 DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type)
 DEFAULT_TYPELOC_IMPL(Vector, Type)
 DEFAULT_TYPELOC_IMPL(ExtVector, VectorType)
+DEFAULT_TYPELOC_IMPL(Matrix, Type)
+DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, Type)
 DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(Record, TagType)
Index: clang/test/SemaCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/matrix-type.cpp
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+
+using matrix_double_t = double __attribute__((matrix_type(6, 6)));
+using matrix_float_t = float __attribute__((matrix_type(6, 6)));
+using matrix_int_t = int __attribute__((matrix_type(6, 6)));
+
+void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
+  using matrix1_t = int __attribute__((matrix_type(Rows, 1)));    // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix3_t = int __attribute__((matrix_type(C, C)));       // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix4_t = int __attribute__((matrix_type(-1, 1)));      // expected-error{{vector size too large}}
+  using matrix5_t = int __attribute__((matrix_type(1, -1)));      // expected-error{{vector size too large}}
+  using matrix6_t = int __attribute__((matrix_type(0, 1)));       // expected-error{{zero vector size}}
+  using matrix7_t = int __attribute__((matrix_type(1, 0)));       // expected-error{{zero vector size}}
+  using matrix7_t = int __attribute__((matrix_type(char, 0)));    // expected-error{{expected '(' for function-style cast or type construction}}
+}
+
+struct S1 {};
+
+void matrix_unsupported_element_type() {
+  using matrix1_t = char *__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'char *'}}
+  using matrix2_t = S1 __attribute__((matrix_type(1, 1)));    // expected-error{{invalid matrix element type 'S1'}}
+}
+
+template <typename T> // expected-note{{declared here}}
+void matrix_template_1() {
+  using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}}
+}
+
+template <class C> // expected-note{{declared here}}
+void matrix_template_2() {
+  using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}}
+}
+
+template <unsigned Rows, unsigned Cols>
+void matrix_template_3() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero vector size}}
+}
+
+void instantiate_template_3() {
+  matrix_template_3<1, 10>();
+  matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}}
+}
+
+template <int Rows, unsigned Cols>
+void matrix_template_4() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{vector size too large}}
+}
+
+void instantiate_template_4() {
+  matrix_template_4<2, 10>();
+  matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}}
+}
Index: clang/test/CodeGenCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCXX/matrix-type.cpp
@@ -0,0 +1,176 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+  // CHECK-LABEL:  define void @_Z10load_storePDm5_5_dS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+  // CHECK-LABEL: define void @_Z17parameter_passingDm3_3_fPS_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+  // CHECK-LABEL: define <9 x float> @_Z13return_matrixPDm3_3_f(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a;
+}
+
+struct Matrix {
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+};
+
+void matrix_struct_pointers(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}
+
+void matrix_struct_reference(Matrix &a, Matrix &b) {
+  // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+class MatrixClass {
+public:
+  int Tmp1;
+  fx3x4_t Data;
+  long Tmp2;
+};
+
+void matrix_class_reference(MatrixClass &a, MatrixClass &b) {
+  // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+class MatrixClassTemplate {
+public:
+  using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols)));
+  int Tmp1;
+  MatrixTy Data;
+  long Tmp2;
+};
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) {
+  b.Data = a.Data;
+}
+
+MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) {
+  // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret %agg.result, float* %Data)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Data.addr = alloca float*, align 8
+  // CHECK-NEXT:    %Arg = alloca %class.MatrixClassTemplate, align 8
+  // CHECK-NEXT:    store float* %Data, float** %Data.addr, align 8
+  // CHECK-NEXT:    %0 = load float*, float** %Data.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast float* %0 to [150 x float]*
+  // CHECK-NEXT:    %2 = bitcast [150 x float]* %1 to <150 x float>*
+  // CHECK-NEXT:    %3 = load <150 x float>, <150 x float>* %2, align 4
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %3, <150 x float>* %4, align 4
+  // CHECK-NEXT:    call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [150 x float]* %Data to <150 x float>*
+  // CHECK-NEXT:    %2 = load <150 x float>, <150 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %2, <150 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+
+  MatrixClassTemplate<float, 10, 15> Result, Arg;
+  Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data);
+  matrix_template_reference(Arg, Result);
+  return Result;
+}
Index: clang/test/CodeGen/matrix-type.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/matrix-type.c
@@ -0,0 +1,79 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+  // CHECK-LABEL:  define void @load_store(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+  // CHECK-LABEL: define void @parameter_passing(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+  // CHECK-LABEL: define <9 x float> @return_matrix
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a;
+}
+
+typedef struct {
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+} Matrix;
+
+void matrix_struct(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @matrix_struct(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -288,6 +288,15 @@
   Record.AddSourceLocation(TL.getNameLoc());
 }
 
+void TypeLocWriter::VisitMatrixTypeLoc(MatrixTypeLoc TL) {
+  Record.AddSourceLocation(TL.getNameLoc());
+}
+
+void TypeLocWriter::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  Record.AddSourceLocation(TL.getNameLoc());
+}
+
 void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   Record.AddSourceLocation(TL.getLocalRangeBegin());
   Record.AddSourceLocation(TL.getLParenLoc());
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -6522,6 +6522,15 @@
   TL.setNameLoc(readSourceLocation());
 }
 
+void TypeLocReader::VisitMatrixTypeLoc(MatrixTypeLoc TL) {
+  TL.setNameLoc(readSourceLocation());
+}
+
+void TypeLocReader::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  TL.setNameLoc(readSourceLocation());
+}
+
 void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   TL.setLocalRangeBegin(readSourceLocation());
   TL.setLParenLoc(readSourceLocation());
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -868,6 +868,16 @@
                                               Expr *SizeExpr,
                                               SourceLocation AttributeLoc);
 
+  /// Build a new matrix type given the element type and dimensions.
+  QualType RebuildMatrixType(QualType ElementType, unsigned NumRows,
+                             unsigned NumColumns);
+
+  /// Build a new matrix type given the type and dependently-defined
+  /// dimensions.
+  QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                           Expr *ColumnExpr,
+                                           SourceLocation AttributeLoc);
+
   /// Build a new DependentAddressSpaceType or return the pointee
   /// type variable with the correct address space (retrieved from
   /// AddrSpaceExpr) applied to it. The former will be returned in cases
@@ -5015,6 +5025,65 @@
   return Result;
 }
 
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformMatrixType(TypeLocBuilder &TLB,
+                                                     MatrixTypeLoc TL) {
+  const MatrixType *T = TL.getTypePtr();
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull())
+    return QualType();
+
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) {
+    Result = getDerived().RebuildMatrixType(ElementType, T->getNumRows(),
+                                            T->getNumColumns());
+    if (Result.isNull())
+      return QualType();
+  }
+
+  MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+  NewTL.setNameLoc(TL.getNameLoc());
+
+  return Result;
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformDependentSizedMatrixType(
+    TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) {
+  const DependentSizedMatrixType *T = TL.getTypePtr();
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull())
+    return QualType();
+
+  EnterExpressionEvaluationContext Unevaluated(
+      SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+  ExprResult Rows = getDerived().TransformExpr(T->getRowExpr());
+  ExprResult Cols = getDerived().TransformExpr(T->getColumnExpr());
+  // A failed transform yields a null Expr from get(); never rebuild with it.
+  if (Rows.isInvalid() || Cols.isInvalid())
+    return QualType();
+
+  QualType Result = TL.getType();
+  // TODO: Finish this
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() ||
+      Rows.get() != T->getRowExpr() || Cols.get() != T->getColumnExpr()) {
+    Result = getDerived().RebuildDependentSizedMatrixType(
+        ElementType, Rows.get(), Cols.get(), T->getAttributeLoc());
+    if (Result.isNull())
+      return QualType();
+  }
+
+  if (isa<DependentSizedMatrixType>(Result)) {
+    DependentSizedMatrixTypeLoc NewTL =
+        TLB.push<DependentSizedMatrixTypeLoc>(Result);
+    NewTL.setNameLoc(TL.getNameLoc());
+  } else {
+    MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+    NewTL.setNameLoc(TL.getNameLoc());
+  }
+  return Result;
+}
+
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformDependentAddressSpaceType(
     TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) {
@@ -13084,6 +13153,21 @@
   return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc);
 }
 
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildMatrixType(QualType ElementType,
+                                                   unsigned NumRows,
+                                                   unsigned NumColumns) {
+  return SemaRef.Context.getMatrixType(ElementType, NumRows, NumColumns);
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType(
+    QualType ElementType, Expr *RowExpr, Expr *ColumnExpr,
+    SourceLocation AttributeLoc) {
+  return SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr,
+                                 AttributeLoc);
+}
+
 template<typename Derived>
 QualType TreeTransform<Derived>::RebuildFunctionProtoType(
     QualType T,
Index: clang/lib/Sema/SemaType.cpp
===================================================================
--- clang/lib/Sema/SemaType.cpp
+++ clang/lib/Sema/SemaType.cpp
@@ -2467,6 +2467,101 @@
   return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
 }
 
+/// \brief Build a matrix type with element type \p T and the given dimensions.
+///
+/// Diagnoses invalid arguments and returns a null QualType on failure.
+QualType Sema::BuildMatrixType(QualType T, Expr *NumRows, Expr *NumCols,
+                               SourceLocation AttrLoc) {
+  assert(Context.getLangOpts().EnableMatrix &&
+         "Should never build a matrix type when it is disabled");
+
+  if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
+      NumRows->isValueDependent() || NumCols->isValueDependent()) {
+    return Context.getDependentSizedMatrixType(T, NumRows, NumCols, AttrLoc);
+  }
+
+  unsigned MatrixRows = 0;
+  unsigned MatrixColumns = 0;
+
+  { // Validate the element type and both dimension arguments
+    // The element type must be an integer, real floating or dependent type
+    if (!(T->isIntegerType() || T->isRealFloatingType() ||
+          T->isDependentType())) {
+      Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << T;
+      return QualType();
+    }
+
+    // Should this be kept at 32bit even though we're deprecating it?
+    llvm::APSInt ValueRows(32), ValueColumns(32);
+
+    bool const RowsIsInteger =
+        NumRows->isIntegerConstantExpr(ValueRows, Context);
+    bool const ColumnsIsInteger =
+        NumCols->isIntegerConstantExpr(ValueColumns, Context);
+
+    auto const RowRange = NumRows->getSourceRange();
+    auto const ColRange = NumCols->getSourceRange();
+
+    // Neither dimension is an integer constant expression
+    if (!RowsIsInteger && !ColumnsIsInteger) {
+      Diag(AttrLoc, diag::err_attribute_argument_type)
+          << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+          << ColRange;
+      return QualType();
+    }
+
+    // Only the row count is not an integer constant expression
+    if (!RowsIsInteger) {
+      Diag(AttrLoc, diag::err_attribute_argument_type)
+          << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+      return QualType();
+    }
+
+    // Only the column count is not an integer constant expression
+    if (!ColumnsIsInteger) {
+      Diag(AttrLoc, diag::err_attribute_argument_type)
+          << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+      return QualType();
+    }
+
+    MatrixRows = static_cast<unsigned>(ValueRows.getZExtValue());
+    MatrixColumns = static_cast<unsigned>(ValueColumns.getZExtValue());
+
+    // Check the matrix size: reject zero and too-large dimensions
+    if (MatrixRows == 0 && MatrixColumns == 0) {
+      Diag(AttrLoc, diag::err_attribute_zero_size)
+          << "matrix" << RowRange << ColRange;
+      return QualType();
+    }
+    if (MatrixRows == 0) {
+      Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+      return QualType();
+    }
+    if (MatrixColumns == 0) {
+      Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+      return QualType();
+    }
+
+    if (VectorType::isVectorSizeTooLarge(MatrixRows) &&
+        VectorType::isVectorSizeTooLarge(MatrixColumns)) {
+      Diag(AttrLoc, diag::err_attribute_size_too_large)
+          << "matrix" << RowRange << ColRange;
+      return QualType();
+    }
+
+    if (VectorType::isVectorSizeTooLarge(MatrixRows)) {
+      Diag(AttrLoc, diag::err_attribute_size_too_large) << "matrix" << RowRange;
+      return QualType();
+    }
+
+    if (VectorType::isVectorSizeTooLarge(MatrixColumns)) {
+      Diag(AttrLoc, diag::err_attribute_size_too_large) << "matrix" << ColRange;
+      return QualType();
+    }
+  }
+  return Context.getMatrixType(T, MatrixRows, MatrixColumns);
+}
+
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
   if (T->isArrayType() || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
@@ -7415,6 +7510,71 @@
   }
 }
 
+/// HandleMatrixTypeAttr - Apply "matrix_type(rows, cols)" to CurType.
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+                                 Sema &S) {
+  if (!S.getLangOpts().EnableMatrix) {
+    S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+    return;
+  }
+
+  if (Attr.getNumArgs() != 2) {
+    S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+        << Attr << 2;
+    return;
+  }
+
+  Expr *rowsExpr = nullptr;
+  Expr *colsExpr = nullptr;
+
+  // TODO: Refactor parameter extraction into separate function
+  // Number of rows: look up an identifier argument, else take the expression
+  if (Attr.isArgIdent(0)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc());
+    ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS,
+                                          TemplateKeywordLoc, id, false, false);
+
+    if (Rows.isInvalid()) {
+      // TODO: maybe a good error message would be nice here
+      return;
+    }
+    rowsExpr = Rows.get();
+  } else {
+    assert(Attr.isArgExpr(0) &&
+           "Argument to should either be an identity or expression");
+    rowsExpr = Attr.getArgAsExpr(0);
+  }
+
+  // Number of columns: same handling, but the result goes into colsExpr
+  if (Attr.isArgIdent(1)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc());
+    ExprResult Columns = S.ActOnIdExpression(
+        S.getCurScope(), SS, TemplateKeywordLoc, id, false, false);
+
+    if (Columns.isInvalid()) {
+      // TODO: a good error message would be nice here
+      return;
+    }
+    colsExpr = Columns.get();
+  } else {
+    assert(Attr.isArgExpr(1) &&
+           "Argument to should either be an identity or expression");
+    colsExpr = Attr.getArgAsExpr(1);
+  }
+
+  // Build the matrix type; CurType is left unchanged if that fails
+  QualType T = S.BuildMatrixType(CurType, rowsExpr, colsExpr, Attr.getLoc());
+  if (!T.isNull()) {
+    CurType = T;
+  }
+}
+
 static void HandleLifetimeBoundAttr(TypeProcessingState &State,
                                     QualType &CurType,
                                     ParsedAttr &Attr) {
@@ -7561,6 +7721,11 @@
       break;
     }
 
+    case ParsedAttr::AT_MatrixType:
+      HandleMatrixTypeAttr(type, attr, state.getSema());
+      attr.setUsedAsTypeAttr();
+      break;
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();
Index: clang/lib/Sema/SemaTemplateDeduction.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateDeduction.cpp
+++ clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2043,6 +2043,89 @@
       return Sema::TDK_NonDeducedMismatch;
     }
 
+    //     (clang extension)
+    //
+    //     T __attribute__((matrix_type(<integral constant>, <integral
+    //     constant>)))
+    //     TODO: Allow deduction from matrix type to vector type
+    //     TODO: Decide on deduction from vector type to matrix type
+    case Type::Matrix: {
+      // Param has constant dimensions; only its element type takes part in
+      // deduction. Pick the argument's element type, rejecting mismatches.
+      const auto *ParamMat = cast<MatrixType>(Param);
+      QualType ArgElemTy;
+      if (const auto *DepArg = dyn_cast<DependentSizedMatrixType>(Arg)) {
+        // Dependent argument: its dimensions cannot be compared here.
+        ArgElemTy = DepArg->getElementType();
+      } else if (const auto *ArgMat = dyn_cast<MatrixType>(Arg)) {
+        // Constant argument: both dimensions must match exactly.
+        const bool SameRows = ParamMat->getNumRows() == ArgMat->getNumRows();
+        const bool SameCols =
+            ParamMat->getNumColumns() == ArgMat->getNumColumns();
+        if (!SameRows || !SameCols)
+          return Sema::TDK_NonDeducedMismatch;
+        ArgElemTy = ArgMat->getElementType();
+      } else {
+        // Arg is not a matrix type at all.
+        return Sema::TDK_NonDeducedMismatch;
+      }
+      return DeduceTemplateArgumentsByTypeMatch(
+          S, TemplateParams, ParamMat->getElementType(), ArgElemTy, Info,
+          Deduced, TDF);
+    }
+
+    case Type::DependentSizedMatrix: {
+      const DependentSizedMatrixType *MatrixParam =
+          cast<DependentSizedMatrixType>(Param);
+      // NOTE(review): DepSizedMatrix - DepSizedMatrix deduction is not handled
+      // yet; only DepSizedMatrix - Matrix deduction is implemented below.
+      if (const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg)) {
+        // Do deduction on the element types; a nonzero result is returned as-is
+        if (Sema::TemplateDeductionResult Result =
+                DeduceTemplateArgumentsByTypeMatch(
+                    S, TemplateParams, MatrixParam->getElementType(),
+                    MatrixArg->getElementType(), Info, Deduced, TDF)) {
+          return Result;
+        }
+
+        // Deduce matrix size if possible
+        NonTypeTemplateParmDecl *RowExprTemplateParam =
+            getDeducedParameterFromExpr(Info, MatrixParam->getRowExpr());
+        NonTypeTemplateParmDecl *ColumnExprTemplateParam =
+            getDeducedParameterFromExpr(Info, MatrixParam->getColumnExpr());
+
+        // TODO: Allow one to fail and the other to succeed in the deduction
+        // Can't deduce either rows or columns, just say everything is fine
+        if (!RowExprTemplateParam || !ColumnExprTemplateParam) {
+          return Sema::TDK_Success;
+        }
+
+        // Rows are passed as an IntTy value. Unsigned might make more sense
+        llvm::APSInt ArgRows(S.Context.getTypeSize(S.Context.IntTy));
+        ArgRows = MatrixArg->getNumRows();
+
+        // Deduce Rows
+        {
+          Sema::TemplateDeductionResult Res = DeduceNonTypeTemplateArgument(
+              S, TemplateParams, RowExprTemplateParam, ArgRows, S.Context.IntTy,
+              true, Info, Deduced);
+          if (Res != Sema::TDK_Success) {
+            return Res;
+          }
+        }
+
+        // Columns are passed as an IntTy value as well
+        llvm::APSInt ArgColumns(S.Context.getTypeSize(S.Context.IntTy));
+        ArgColumns = MatrixArg->getNumColumns();
+
+        // Deduce columns
+        return DeduceNonTypeTemplateArgument(
+            S, TemplateParams, ColumnExprTemplateParam, ArgColumns,
+            S.Context.IntTy, true, Info, Deduced);
+      }
+      return Sema::TDK_NonDeducedMismatch;
+    }
+
     //     (clang extension)
     //
     //     T __attribute__(((address_space(N))))
@@ -5607,6 +5690,24 @@
     break;
   }
 
+  case Type::Matrix: {
+    const MatrixType *MatType = cast<MatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth,
+                               Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
   case Type::FunctionProto: {
     const FunctionProtoType *Proto = cast<FunctionProtoType>(T);
     MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth,
Index: clang/lib/Sema/SemaTemplate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplate.cpp
+++ clang/lib/Sema/SemaTemplate.cpp
@@ -5568,6 +5568,11 @@
   return Visit(T->getElementType());
 }
 
+bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType(
+    const DependentSizedMatrixType *T) {
+  return Visit(T->getElementType()); // Only the element type is inspected.
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType(
     const DependentAddressSpaceType *T) {
   return Visit(T->getPointeeType());
@@ -5586,6 +5591,10 @@
   return Visit(T->getElementType());
 }
 
+bool UnnamedLocalNoLinkageFinder::VisitMatrixType(const MatrixType *T) {
+  return Visit(T->getElementType()); // Only the element type is inspected.
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType(
                                                   const FunctionProtoType* T) {
   for (const auto &A : T->param_types()) {
Index: clang/lib/Sema/SemaLookup.cpp
===================================================================
--- clang/lib/Sema/SemaLookup.cpp
+++ clang/lib/Sema/SemaLookup.cpp
@@ -2959,6 +2959,7 @@
     // These are fundamental types.
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::Complex:
       break;
 
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -4144,6 +4144,7 @@
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:
Index: clang/lib/Frontend/CompilerInvocation.cpp
===================================================================
--- clang/lib/Frontend/CompilerInvocation.cpp
+++ clang/lib/Frontend/CompilerInvocation.cpp
@@ -3283,7 +3283,8 @@
   }
 
   Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers);
-  Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
+  Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
+  Opts.EnableMatrix = Args.hasArg(OPT_fenable_matrix);
 }
 
 static bool isStrictlyPreprocessorAction(frontend::ActionKind Action) {
@@ -3498,7 +3499,7 @@
   InputArgList Args = Opts.ParseArgs(CommandLineArgs, MissingArgIndex,
                                      MissingArgCount, IncludedFlagsBitmask);
   LangOptions &LangOpts = *Res.getLangOpts();
-
+  //
   // Check for missing argument error.
   if (MissingArgCount) {
     Diags.Report(diag::err_drv_missing_argument)
Index: clang/lib/Driver/ToolChains/Clang.cpp
===================================================================
--- clang/lib/Driver/ToolChains/Clang.cpp
+++ clang/lib/Driver/ToolChains/Clang.cpp
@@ -4364,6 +4364,13 @@
   if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false))
     CmdArgs.push_back("-fdefault-calling-conv=stdcall");
 
+  if (Args.hasArg(options::OPT_fenable_matrix)) {
+    // enable-matrix is needed by both the LangOpts and by LLVM.
+    CmdArgs.push_back("-fenable-matrix");
+    CmdArgs.push_back("-mllvm");
+    CmdArgs.push_back("-enable-matrix");
+  }
+
   CodeGenOptions::FramePointerKind FPKeepKind =
                   getFramePointerKind(Args, RawTriple);
   const char *FPKeepKindStr = nullptr;
Index: clang/lib/CodeGen/ItaniumCXXABI.cpp
===================================================================
--- clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3200,6 +3200,7 @@
   // GCC treats vector and complex types as fundamental types.
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
   case Type::Complex:
   case Type::Atomic:
   // FIXME: GCC treats block pointers as fundamental types?!
@@ -3435,6 +3436,7 @@
   case Type::Builtin:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
   case Type::Complex:
   case Type::BlockPointer:
     // Itanium C++ ABI 2.9.5p4:
Index: clang/lib/CodeGen/CodeGenTypes.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenTypes.cpp
+++ clang/lib/CodeGen/CodeGenTypes.cpp
@@ -84,6 +84,13 @@
 /// a type.  For example, the scalar representation for _Bool is i1, but the
 /// memory representation is usually i8 or i32, depending on the target.
 llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
+  if (T->isMatrixType()) {
+    const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+    const MatrixType *MT = cast<MatrixType>(Ty);
+    return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+                                MT->getNumRows() * MT->getNumColumns());
+  }
+
   llvm::Type *R = ConvertType(T);
 
   // If this is a non-bool type, don't map it.
@@ -609,6 +616,12 @@
                                        VT->getNumElements());
     break;
   }
+  case Type::Matrix: {
+    const MatrixType *MT = cast<MatrixType>(Ty);
+    ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()),
+                                       MT->getNumRows() * MT->getNumColumns());
+    break;
+  }
   case Type::FunctionNoProto:
   case Type::FunctionProto:
     ResultType = ConvertFunctionTypeInternal(T);
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -259,6 +259,7 @@
     case Type::MemberPointer:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Enum:
@@ -1969,6 +1970,7 @@
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -144,8 +144,19 @@
 
 Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align,
                                        const Twine &Name, Address *Alloca) {
-  return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
-                          /*ArraySize=*/nullptr, Alloca);
+  Address Tmp = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
+                                 /*ArraySize=*/nullptr, Alloca);
+  if (!Ty->isMatrixType())
+    return Tmp;
+
+  // Matrix values are stored as arrays but operated on as vectors, so hand
+  // back the slot reinterpreted as a pointer to the matching vector type.
+  auto *StorageTy = cast<llvm::ArrayType>(Tmp.getType()->getElementType());
+  auto *ValueTy = llvm::VectorType::get(StorageTy->getElementType(),
+                                        StorageTy->getNumElements());
+  auto *Cast =
+      Builder.CreateBitCast(Tmp.getPointer(), ValueTy->getPointerTo());
+  return Address(Cast, Tmp.getAlignment());
 }
 
 Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align,
@@ -1740,6 +1751,20 @@
     }
   }
 
+  if (Ty->isMatrixType()) {
+    auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+        cast<llvm::PointerType>(Addr.getPointer()->getType())
+            ->getElementType());
+    if (ArrayTy) {
+      auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                             ArrayTy->getNumElements());
+
+      Addr = Address(
+          Builder.CreateBitCast(Addr.getPointer(), VectorTy->getPointerTo()),
+          Addr.getAlignment());
+    }
+  }
+
   Value = EmitToMemory(Value, Ty);
 
   LValue AtomicLValue =
@@ -1793,6 +1818,20 @@
   if (LV.isSimple()) {
     assert(!LV.getType()->isFunctionType());
 
+    if (LV.getType()->isMatrixType()) {
+      auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+          cast<llvm::PointerType>(LV.getPointer(*this)->getType())
+              ->getElementType());
+      if (ArrayTy) {
+        auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                               ArrayTy->getNumElements());
+
+        LV.setAddress(Address(Builder.CreateBitCast(LV.getPointer(*this),
+                                                    VectorTy->getPointerTo()),
+                              LV.getAlignment()));
+      }
+    }
+
     // Everything needs a load.
     return RValue::get(EmitLoadOfScalar(LV, Loc));
   }
Index: clang/lib/CodeGen/CGDebugInfo.h
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.h
+++ clang/lib/CodeGen/CGDebugInfo.h
@@ -188,6 +188,7 @@
   llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit);
 
   llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F);
+  llvm::DIType *CreateType(const MatrixType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit);
Index: clang/lib/CodeGen/CGDebugInfo.cpp
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.cpp
+++ clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2642,6 +2642,23 @@
   return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray);
 }
 
+llvm::DIType *CGDebugInfo::CreateType(const MatrixType *Ty,
+                                      llvm::DIFile *Unit) {
+  auto &Ctx = CGM.getContext();
+  llvm::DIType *ElemDI = getOrCreateType(Ty->getElementType(), Unit);
+  const uint64_t TotalSize = Ctx.getTypeSize(Ty);
+  const uint32_t Alignment = getTypeAlignIfRequired(Ty, Ctx);
+
+  // FIXME: Create another debug type for matrices. For the time being a
+  // matrix is described as a 2D array: columns first, followed by rows.
+  llvm::SmallVector<llvm::Metadata *, 2> Dims = {
+      DBuilder.getOrCreateSubrange(0, Ty->getNumColumns()),
+      DBuilder.getOrCreateSubrange(0, Ty->getNumRows())};
+
+  return DBuilder.createArrayType(TotalSize, Alignment, ElemDI,
+                                  DBuilder.getOrCreateArray(Dims));
+}
+
 llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) {
   uint64_t Size;
   uint32_t Align;
@@ -3035,6 +3052,8 @@
   case Type::ExtVector:
   case Type::Vector:
     return CreateType(cast<VectorType>(Ty), Unit);
+  case Type::Matrix:
+    return CreateType(cast<MatrixType>(Ty), Unit);
   case Type::ObjCObjectPointer:
     return CreateType(cast<ObjCObjectPointerType>(Ty), Unit);
   case Type::ObjCObject:
Index: clang/lib/AST/TypePrinter.cpp
===================================================================
--- clang/lib/AST/TypePrinter.cpp
+++ clang/lib/AST/TypePrinter.cpp
@@ -254,6 +254,8 @@
     case Type::DependentSizedExtVector:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
+    case Type::DependentSizedMatrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Paren:
@@ -718,6 +720,37 @@
   OS << ")))";
 }
 
+// Prints the element type followed by the matrix_type attribute spelling.
+void TypePrinter::printMatrixBefore(const MatrixType *T, raw_ostream &OS) {
+  // TODO: Fix the spacing between the element type and the __attribute__
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(" << T->getNumRows() << ", "
+     << T->getNumColumns() << "))) ";
+}
+
+void TypePrinter::printMatrixAfter(const MatrixType *T, raw_ostream &OS) {
+  printAfter(T->getElementType(), OS); // Only the element type's suffix.
+}
+
+// Prints the element type followed by the matrix_type attribute; a dimension
+// whose expression is null is left empty.
+void TypePrinter::printDependentSizedMatrixBefore(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  if (auto *Rows = T->getRowExpr())
+    Rows->printPretty(OS, nullptr, Policy);
+  OS << ", ";
+  if (auto *Cols = T->getColumnExpr())
+    Cols->printPretty(OS, nullptr, Policy);
+  OS << "))) ";
+}
+
+void TypePrinter::printDependentSizedMatrixAfter(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  printAfter(T->getElementType(), OS); // Only the element type's suffix.
+}
+
 void
 FunctionProtoType::printExceptionSpecification(raw_ostream &OS,
                                                const PrintingPolicy &Policy)
Index: clang/lib/AST/Type.cpp
===================================================================
--- clang/lib/AST/Type.cpp
+++ clang/lib/AST/Type.cpp
@@ -260,6 +260,42 @@
   SizeExpr->Profile(ID, Context, true);
 }
 
+MatrixType::MatrixType(QualType matrixType, unsigned nRows, unsigned nColumns,
+                       QualType canonType)
+    : MatrixType(Matrix, matrixType, nRows, nColumns, canonType) {} // tc=Matrix
+
+MatrixType::MatrixType(TypeClass tc, QualType matrixType, unsigned nRows,
+                       unsigned nColumns, QualType canonType)
+    : Type(tc, canonType, matrixType->isDependentType(), // flags come from
+           matrixType->isInstantiationDependentType(),   // the element type
+           matrixType->isVariablyModifiedType(),
+           matrixType->containsUnexpandedParameterPack()),
+      ElementType(matrixType) { // despite its name, matrixType is the element
+  MatrixTypeBits.NumRows = nRows;
+  MatrixTypeBits.NumColumns = nColumns;
+}
+
+DependentSizedMatrixType::DependentSizedMatrixType(
+    const ASTContext &CTX, QualType ElementType, QualType CanonicalType,
+    Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc)
+    : Type(DependentSizedMatrix, CanonicalType, /*Dependent=*/true,
+           /*InstantiationDependent=*/true, // both set unconditionally
+           ElementType->isVariablyModifiedType(),
+           (ElementType->containsUnexpandedParameterPack() || // pack anywhere:
+            (RowExpr && RowExpr->containsUnexpandedParameterPack()) || // null
+            (ColumnExpr && ColumnExpr->containsUnexpandedParameterPack()))),
+      Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), // exprs skipped
+      ElementType(ElementType), loc(loc) {}
+
+void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID,
+                                       const ASTContext &CTX,
+                                       QualType ElementType, Expr *RowExpr,
+                                       Expr *ColumnExpr) {
+  ID.AddPointer(ElementType.getAsOpaquePtr()); // element type by identity
+  RowExpr->Profile(ID, CTX, true); // NOTE(review): no null check here, unlike
+  ColumnExpr->Profile(ID, CTX, true); // the constructor; confirm never null.
+}
+
 DependentAddressSpaceType::DependentAddressSpaceType(
     const ASTContext &Context, QualType PointeeType, QualType can,
     Expr *AddrSpaceExpr, SourceLocation loc)
@@ -953,6 +989,16 @@
     return Ctx.getExtVectorType(elementType, T->getNumElements());
   }
 
+  QualType VisitMatrixType(const MatrixType *T) {
+    // Rebuild the matrix only if transforming the element type changed it.
+    const QualType NewElemTy = recurse(T->getElementType());
+    if (NewElemTy.isNull())
+      return {};
+    if (NewElemTy.getAsOpaquePtr() != T->getElementType().getAsOpaquePtr())
+      return Ctx.getMatrixType(NewElemTy, T->getNumRows(), T->getNumColumns());
+    return QualType(T, 0);
+  }
+
   QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
     QualType returnType = recurse(T->getReturnType());
     if (returnType.isNull())
@@ -1770,6 +1816,14 @@
       return Visit(T->getElementType());
     }
 
+    Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) {
+      return Visit(T->getElementType()); // Forwards to the element type.
+    }
+
+    Type *VisitMatrixType(const MatrixType *T) {
+      return Visit(T->getElementType()); // Forwards to the element type.
+    }
+
     Type *VisitFunctionProtoType(const FunctionProtoType *T) {
       if (Syntactic && T->hasTrailingReturn())
         return const_cast<FunctionProtoType*>(T);
@@ -3689,6 +3743,8 @@
   case Type::Vector:
   case Type::ExtVector:
     return Cache::get(cast<VectorType>(T)->getElementType());
+  case Type::Matrix:
+    return Cache::get(cast<MatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return Cache::get(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3774,6 +3830,8 @@
   case Type::Vector:
   case Type::ExtVector:
     return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType());
+  case Type::Matrix:
+    return computeTypeLinkageInfo(cast<MatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3935,6 +3993,8 @@
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::FunctionProto:
   case Type::FunctionNoProto:
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -2759,6 +2759,23 @@
     << Range;
 }
 
+// No Microsoft mangling exists for matrix types yet; emit an error instead.
+void MicrosoftCXXNameMangler::mangleType(const MatrixType *T, Qualifiers quals,
+                                         SourceRange Range) {
+  const unsigned ID = Context.getDiags().getCustomDiagID(
+      DiagnosticsEngine::Error, "Cannot mangle this matrix type yet");
+  Context.getDiags().Report(Range.getBegin(), ID) << Range;
+}
+
+// No Microsoft mangling for dependent-sized matrices yet; emit an error.
+void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  const unsigned ID = Context.getDiags().getCustomDiagID(
+      DiagnosticsEngine::Error,
+      "Cannot mangle this dependent-sized matrix type yet");
+  Context.getDiags().Report(Range.getBegin(), ID) << Range;
+}
+
 void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T,
                                          Qualifiers, SourceRange Range) {
   DiagnosticsEngine &Diags = Context.getDiags();
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -1987,6 +1987,8 @@
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
+  case Type::DependentSizedMatrix:
   case Type::FunctionProto:
   case Type::FunctionNoProto:
   case Type::Paren:
@@ -3249,6 +3251,20 @@
   mangleType(T->getElementType());
 }
 
+void CXXNameMangler::mangleType(const MatrixType *T) {
+  Out << "Dm" << T->getNumRows() << "_" << T->getNumColumns() << '_';
+  mangleType(T->getElementType()); // Result: Dm<rows>_<columns>_<element type>
+}
+
+void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) {
+  Out << "Dm"; // Same "Dm" prefix as the constant-size matrix mangling.
+  mangleExpression(T->getRowExpr()); // Row count as a mangled expression.
+  Out << '_';
+  mangleExpression(T->getColumnExpr()); // Column count, likewise.
+  Out << '_';
+  mangleType(T->getElementType()); // Element type comes last.
+}
+
 void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) {
   SplitQualType split = T->getPointeeType().split();
   mangleQualifiers(split.Quals, T);
Index: clang/lib/AST/ExprConstant.cpp
===================================================================
--- clang/lib/AST/ExprConstant.cpp
+++ clang/lib/AST/ExprConstant.cpp
@@ -10133,6 +10133,7 @@
   case Type::BlockPointer:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
   case Type::ObjCObject:
   case Type::ObjCInterface:
   case Type::ObjCObjectPointer:
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -623,6 +623,38 @@
     break;
   }
 
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1);
+    const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2);
+    // Rows
+    if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(),
+                                  Mat2->getRowExpr())) {
+      return false;
+    }
+    // Columns
+    if (!IsStructurallyEquivalent(Context, Mat1->getColumnExpr(),
+                                  Mat2->getColumnExpr())) {
+      return false;
+    }
+    // Element Type
+    if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType())) {
+      return false;
+    }
+    return true;
+  }
+
+  case Type::Matrix: {
+    const auto *Mat1 = cast<MatrixType>(T1), *Mat2 = cast<MatrixType>(T2);
+    // Equivalent only if the element type and both dimensions agree.
+    if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType()) ||
+        Mat1->getNumRows() != Mat2->getNumRows() ||
+        Mat1->getNumColumns() != Mat2->getNumColumns())
+      return false;
+    break;
+  }
+
   case Type::FunctionProto: {
     const auto *Proto1 = cast<FunctionProtoType>(T1);
     const auto *Proto2 = cast<FunctionProtoType>(T2);
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -1887,6 +1887,18 @@
     break;
   }
 
+  case Type::Matrix: {
+    const auto *MT = cast<MatrixType>(T);
+    TypeInfo ElementInfo = getTypeInfo(MT->getElementType());
+    // The matrix type is intended to be ABI compatible with arrays with respect
+    // to alignment and size. We use LLVM's array type for storage.
+    Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns();
+    // Like an array, a matrix is aligned as its element type is; the element
+    // width is not an alignment and must not be used in its place.
+    Align = ElementInfo.Align;
+    break;
+  }
+
   case Type::Builtin:
     switch (cast<BuiltinType>(T)->getKind()) {
     default: llvm_unreachable("Unknown builtin type!");
@@ -3303,6 +3315,8 @@
   case Type::DependentVector:
   case Type::ExtVector:
   case Type::DependentSizedExtVector:
+  case Type::Matrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::ObjCObject:
   case Type::ObjCInterface:
@@ -3692,6 +3706,76 @@
   return QualType(New, 0);
 }
 
+/// getMatrixType - Return the unique reference to a matrix type of the
+/// specified element type and size. ElementTy must be a built-in integer or
+/// floating point type.
+QualType ASTContext::getMatrixType(QualType ElementTy, unsigned NumRows,
+                                   unsigned NumColumns) const {
+  llvm::FoldingSetNodeID ID;
+  MatrixType::Profile(ID, ElementTy, NumRows, NumColumns, Type::Matrix);
+
+  void *InsertPos = nullptr;
+  if (MatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos)) {
+    return QualType(MTP, 0);
+  }
+  // A non-canonical element type needs a canonical matrix type to point at.
+  QualType Canonical;
+  if (!ElementTy.isCanonical()) {
+    Canonical = getMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns);
+    // The recursive call inserts a node, so refresh the insert position.
+    MatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+    assert(!NewIP && "Matrix type shouldn't already exist in the map");
+    (void)NewIP;
+  }
+
+  auto *New = new (*this, TypeAlignment)
+      MatrixType(ElementTy, NumRows, NumColumns, Canonical);
+  MatrixTypes.InsertNode(New, InsertPos);
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
+// getDependentSizedMatrixType - Return a unique reference to the
+// dependent matrix type. MatrixElementType must be a builtin type.
+QualType ASTContext::getDependentSizedMatrixType(QualType MatrixElementType,
+                                                 Expr *RowExpr,
+                                                 Expr *ColumnExpr,
+                                                 SourceLocation AttrLoc) const {
+  llvm::FoldingSetNodeID ID;
+  DependentSizedMatrixType::Profile(
+      ID, *this, getCanonicalType(MatrixElementType), RowExpr, ColumnExpr);
+  // Look up an existing node profiled on the canonical element type.
+  void *InsertPos = nullptr;
+  DependentSizedMatrixType *Canon =
+      DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+  DependentSizedMatrixType *New;
+  if (Canon) {
+    // Already have a canonical version of the matrix type
+    // Use it as the canonical type for newly-built types
+    New = new (*this, TypeAlignment)
+        DependentSizedMatrixType(*this, MatrixElementType, QualType(Canon, 0),
+                                 RowExpr, ColumnExpr, AttrLoc);
+  } else {
+    QualType CanonicalMatrixElementType = getCanonicalType(MatrixElementType);
+    if (CanonicalMatrixElementType == MatrixElementType) { // already canonical
+      New = new (*this, TypeAlignment) DependentSizedMatrixType(
+          *this, MatrixElementType, QualType(), RowExpr, ColumnExpr, AttrLoc);
+      DependentSizedMatrixType *CanonCheck =
+          DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+      assert(!CanonCheck && "Dependent-sized matrix canonical type broken");
+      (void)CanonCheck;
+      DependentSizedMatrixTypes.InsertNode(New, InsertPos); // New is uniqued
+    } else { // build the canonical type first, then point New at it
+      QualType Canon = getDependentSizedMatrixType(
+          CanonicalMatrixElementType, RowExpr, ColumnExpr, SourceLocation());
+      New = new (*this, TypeAlignment) DependentSizedMatrixType(
+          *this, MatrixElementType, Canon, RowExpr, ColumnExpr, AttrLoc);
+    }
+  }
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
 QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType,
                                                   Expr *AddrSpaceExpr,
                                                   SourceLocation AttrLoc) const {
@@ -7177,6 +7261,11 @@
       *NotEncodedT = T;
     return;
 
+  case Type::Matrix:
+    if (NotEncodedT)
+      *NotEncodedT = T;
+    return;
+
   // We could see an undeduced auto type here during error recovery.
   // Just ignore it.
   case Type::Auto:
@@ -8002,6 +8091,15 @@
          LHS->getNumElements() == RHS->getNumElements();
 }
 
+/// areCompatMatrixTypes - Return true if the two specified matrix types are
+/// compatible: same element type, same number of rows and of columns.
+static bool areCompatMatrixTypes(const MatrixType *LHS, const MatrixType *RHS) {
+  assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified());
+  return LHS->getElementType() == RHS->getElementType() &&
+         LHS->getNumRows() == RHS->getNumRows() &&
+         LHS->getNumColumns() == RHS->getNumColumns();
+}
+
 bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
                                           QualType SecondVec) {
   assert(FirstVec->isVectorType() && "FirstVec should be a vector type");
@@ -9198,6 +9296,11 @@
                              RHSCan->castAs<VectorType>()))
       return LHS;
     return {};
+  case Type::Matrix:
+    if (areCompatMatrixTypes(LHSCan->castAs<MatrixType>(),
+                             RHSCan->castAs<MatrixType>()))
+      return LHS;
+    return {};
   case Type::ObjCObject: {
     // Check if the types are assignment compatible.
     // FIXME: This should be type compatibility, e.g. whether
Index: clang/include/clang/Serialization/TypeBitCodes.def
===================================================================
--- clang/include/clang/Serialization/TypeBitCodes.def
+++ clang/include/clang/Serialization/TypeBitCodes.def
@@ -58,5 +58,7 @@
 TYPE_BIT_CODE(DependentAddressSpace, DEPENDENT_ADDRESS_SPACE, 47)
 TYPE_BIT_CODE(DependentVector, DEPENDENT_SIZED_VECTOR, 48)
 TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49)
+TYPE_BIT_CODE(Matrix, MATRIX, 50)
+TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 51)
 
 #undef TYPE_BIT_CODE
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1537,6 +1537,9 @@
   QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc);
   QualType BuildExtVectorType(QualType T, Expr *ArraySize,
                               SourceLocation AttrLoc);
+  QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns,
+                           SourceLocation AttrLoc);
+
   QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace,
                                  SourceLocation AttrLoc);
 
Index: clang/include/clang/Driver/Options.td
===================================================================
--- clang/include/clang/Driver/Options.td
+++ clang/include/clang/Driver/Options.td
@@ -1951,6 +1951,10 @@
 def fno_strict_return : Flag<["-"], "fno-strict-return">, Group<f_Group>,
   Flags<[CC1Option]>;
 
+def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
+    Flags<[CC1Option]>,
+    HelpText<"Enable matrix data type and related builtin functions">;
+
 def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">,
   Group<f_Group>, Flags<[CC1Option]>,
   HelpText<"Treat editor placeholders as valid source code">;
Index: clang/include/clang/Basic/TypeNodes.td
===================================================================
--- clang/include/clang/Basic/TypeNodes.td
+++ clang/include/clang/Basic/TypeNodes.td
@@ -65,10 +65,12 @@
 def VariableArrayType : TypeNode<ArrayType>;
 def DependentSizedArrayType : TypeNode<ArrayType>, AlwaysDependent;
 def DependentSizedExtVectorType : TypeNode<Type>, AlwaysDependent;
+def DependentSizedMatrixType : TypeNode<Type>, AlwaysDependent;
 def DependentAddressSpaceType : TypeNode<Type>, AlwaysDependent;
 def VectorType : TypeNode<Type>;
 def DependentVectorType : TypeNode<Type>, AlwaysDependent;
 def ExtVectorType : TypeNode<VectorType>;
+def MatrixType : TypeNode<Type>;
 def FunctionType : TypeNode<Type, 1>;
 def FunctionProtoType : TypeNode<FunctionType>;
 def FunctionNoProtoType : TypeNode<FunctionType>;
Index: clang/include/clang/Basic/LangOptions.def
===================================================================
--- clang/include/clang/Basic/LangOptions.def
+++ clang/include/clang/Basic/LangOptions.def
@@ -343,6 +343,8 @@
 
 LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors")
 
+LANGOPT(EnableMatrix, 1, 0, "Enable or disable the builtin matrix type")
+
 #undef LANGOPT
 #undef COMPATIBLE_LANGOPT
 #undef BENIGN_LANGOPT
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2641,6 +2641,7 @@
 def err_attribute_too_few_arguments : Error<
   "%0 attribute takes at least %1 argument%s1">;
 def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
+def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
 def err_attribute_bad_neon_vector_size : Error<
   "Neon vector size must be 64 or 128 bits">;
 def err_attribute_requires_positive_integer : Error<
@@ -10239,6 +10240,9 @@
   "%select{non-pointer|function pointer|void pointer}0 argument to "
   "'__builtin_launder' is not allowed">;
 
+def err_builtin_matrix_disabled: Error<
+  "Builtin matrix support is disabled. Pass -fenable-matrix to enable it.">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -2372,6 +2372,15 @@
   let Documentation = [Undocumented];
 }
 
+def MatrixType : TypeAttr {
+  let Spellings = [Clang<"matrix_type">];
+  let Subjects = SubjectList<[TypedefName], ErrorDiag>;
+  let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">];
+  let Documentation = [Undocumented];
+  let ASTNode = 0;
+  let PragmaAttributeSupport = 0;
+}
+
 def Visibility : InheritableAttr {
   let Clone = 0;
   let Spellings = [GCC<"visibility">];
Index: clang/include/clang/AST/TypeProperties.td
===================================================================
--- clang/include/clang/AST/TypeProperties.td
+++ clang/include/clang/AST/TypeProperties.td
@@ -224,6 +224,41 @@
   }]>;
 }
 
+let Class = MatrixType in {
+  def : Property<"elementType", QualType> {
+    let Read = [{ node->getElementType() }];
+  }
+  def : Property<"numRows", UInt32> {
+    let Read = [{ node->getNumRows() }];
+  }
+  def : Property<"numColumns", UInt32> {
+    let Read = [{ node->getNumColumns() }];
+  }
+
+  def : Creator<[{
+    return ctx.getMatrixType(elementType, numRows, numColumns);
+  }]>;
+}
+
+let Class = DependentSizedMatrixType in {
+  def : Property<"elementType", QualType> {
+    let Read = [{ node->getElementType() }];
+  }
+  def : Property<"rows", ExprRef> {
+    let Read = [{ node->getRowExpr() }];
+  }
+  def : Property<"columns", ExprRef> {
+    let Read = [{ node->getColumnExpr() }];
+  }
+  def : Property<"attributeLoc", SourceLocation> {
+    let Read = [{ node->getAttributeLoc() }];
+  }
+
+  def : Creator<[{
+    return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc);
+  }]>;
+}
+
 let Class = FunctionType in {
   def : Property<"returnType", QualType> {
     let Read = [{ node->getReturnType() }];
Index: clang/include/clang/AST/TypeLoc.h
===================================================================
--- clang/include/clang/AST/TypeLoc.h
+++ clang/include/clang/AST/TypeLoc.h
@@ -1767,6 +1767,18 @@
                                      DependentSizedExtVectorType> {
 };
 
+// Same as VectorType: FIXME: attribute locations.
+class MatrixTypeLoc
+    : public InheritingConcreteTypeLoc<TypeSpecTypeLoc, MatrixTypeLoc,
+                                       MatrixType> {};
+
+// Same as VectorType: FIXME: attribute locations.  Also look into making this
+// a subtype of the MatrixTypeLoc
+class DependentSizedMatrixTypeLoc
+    : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
+                                       DependentSizedMatrixTypeLoc,
+                                       DependentSizedMatrixType> {};
+
 // FIXME: location of the '_Complex' keyword.
 class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
                                                         ComplexTypeLoc,
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -1666,6 +1666,19 @@
     enum { MaxNumElements = (1 << (29 - NumTypeBits)) - 1 };
   };
 
+  class MatrixTypeBitfields {
+    friend class MatrixType;
+
+    unsigned : NumTypeBits;
+
+    // Number of rows and columns
+    unsigned NumRows : 29 - NumTypeBits;
+    unsigned NumColumns : 29 - NumTypeBits;
+
+    enum { MaxNumRows = (1 << (29 - NumTypeBits)) - 1 };
+    enum { MaxNumColumns = (1 << (29 - NumTypeBits)) - 1 };
+  };
+
   class AttributedTypeBitfields {
     friend class AttributedType;
 
@@ -1766,6 +1779,7 @@
     TypeWithKeywordBitfields TypeWithKeywordBits;
     ElaboratedTypeBitfields ElaboratedTypeBits;
     VectorTypeBitfields VectorTypeBits;
+    MatrixTypeBitfields MatrixTypeBits;
     SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
     TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
     DependentTemplateSpecializationTypeBitfields
@@ -2029,6 +2043,7 @@
   bool isComplexIntegerType() const;            // GCC _Complex integer type.
   bool isVectorType() const;                    // GCC vector type.
   bool isExtVectorType() const;                 // Extended vector type.
+  bool isMatrixType() const;
   bool isDependentAddressSpaceType() const;     // value-dependent address space qualifier
   bool isObjCObjectPointerType() const;         // pointer to ObjC object
   bool isObjCRetainableType() const;            // ObjC object or block pointer
@@ -3407,6 +3422,114 @@
   }
 };
 
+/// MatrixType - This type is created using
+/// __attribute__((matrix_type(rows, columns))), where "rows" is the
+/// number of rows and "columns" is the number of columns.
+class MatrixType : public Type, public llvm::FoldingSetNode {
+protected:
+  friend class ASTContext;
+
+  QualType ElementType;
+
+  // MatrixElementType:   The type of the elements in the matrix
+  // NRows:               Number of rows
+  // NColumns:            Number of columns
+  // CanonElementType:    Canonical element type (if the matrix type is not
+  // canonical)
+  MatrixType(QualType MatrixElementType, unsigned NRows, unsigned NColumns,
+             QualType CanonElementType);
+
+  // typeClass:           The typeclass (defined in TypeNodes.td)
+  // MatrixElementType:   The type of elements in the matrix
+  // NRows:               The number of rows
+  // NColumns:            The number of columns
+  // CanonElementType:    Canonical type (if the matrix type is not canonical)
+  MatrixType(TypeClass typeClass, QualType MatrixElementType, unsigned NRows,
+             unsigned NColumns, QualType CanonElementType);
+
+public:
+  // The type of the elements being stored in the matrix
+  QualType getElementType() const { return ElementType; }
+
+  // The number of rows in the matrix
+  unsigned getNumRows() const { return MatrixTypeBits.NumRows; }
+
+  // The number of columns in the matrix
+  unsigned getNumColumns() const { return MatrixTypeBits.NumColumns; }
+
+  unsigned getNumElementsFlattened() const {
+    return MatrixTypeBits.NumRows * MatrixTypeBits.NumColumns;
+  }
+
+  // Check if the dimensions of the matrix fit in data storage type
+  static bool tooBig(unsigned NumRows, unsigned NumColumns) {
+    return NumRows > MatrixTypeBitfields::MaxNumRows ||
+           NumColumns > MatrixTypeBitfields::MaxNumColumns;
+  }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, getElementType(), getNumRows(), getNumColumns(),
+            getTypeClass());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType,
+                      unsigned NumRows, unsigned NumColumns,
+                      TypeClass TypeClass) {
+    ID.AddPointer(ElementType.getAsOpaquePtr());
+    ID.AddInteger(NumRows);
+    ID.AddInteger(NumColumns);
+    ID.AddInteger(TypeClass);
+  }
+
+  static bool classof(const Type *T) {
+    // DependentSizedMatrixType derives from Type, not from MatrixType.
+    return T->getTypeClass() == Matrix;
+  }
+};
+
+/// DependentSizedMatrixType - Represents a matrix type where the type
+/// and size are dependent on a template.
+///
+class DependentSizedMatrixType : public Type, public llvm::FoldingSetNode {
+  friend class ASTContext;
+
+  const ASTContext &Context;
+  Expr *RowExpr;
+  Expr *ColumnExpr;
+
+  /// The element type of the matrix
+  QualType ElementType;
+
+  SourceLocation loc;
+
+  DependentSizedMatrixType(const ASTContext &Context, QualType ElementType,
+                           QualType CanonicalType, Expr *RowExpr,
+                           Expr *ColumnExpr, SourceLocation loc);
+
+public:
+  QualType getElementType() const { return ElementType; }
+  Expr *getRowExpr() const { return RowExpr; }
+  Expr *getColumnExpr() const { return ColumnExpr; }
+  SourceLocation getAttributeLoc() const { return loc; }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == DependentSizedMatrix;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
+                      QualType ElementType, Expr *RowExpr, Expr *ColumnExpr);
+};
+
 /// FunctionType - C99 6.7.5.3 - Function Declarators.  This is the common base
 /// class of FunctionNoProtoType and FunctionProtoType.
 class FunctionType : public Type {
@@ -6573,6 +6696,10 @@
   return isa<ExtVectorType>(CanonicalType);
 }
 
+inline bool Type::isMatrixType() const {
+  return isa<MatrixType>(CanonicalType);
+}
+
 inline bool Type::isDependentAddressSpaceType() const {
   return isa<DependentAddressSpaceType>(CanonicalType);
 }
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1000,6 +1000,16 @@
 
 DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); })
 
+DEF_TRAVERSE_TYPE(MatrixType, { TRY_TO(TraverseType(T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(DependentSizedMatrixType, {
+  if (T->getRowExpr())
+    TRY_TO(TraverseStmt(T->getRowExpr()));
+  if (T->getColumnExpr())
+    TRY_TO(TraverseStmt(T->getColumnExpr()));
+  TRY_TO(TraverseType(T->getElementType()));
+})
+
 DEF_TRAVERSE_TYPE(FunctionNoProtoType,
                   { TRY_TO(TraverseType(T->getReturnType())); })
 
@@ -1235,6 +1245,21 @@
   TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
 })
 
+// Same as VectorType: FIXME: MatrixTypeLoc is unfinished
+DEF_TRAVERSE_TYPELOC(MatrixType, {
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
+DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, {
+  if (TL.getTypePtr()->getRowExpr()) {
+    TRY_TO(TraverseStmt(TL.getTypePtr()->getRowExpr()));
+  }
+  if (TL.getTypePtr()->getColumnExpr()) {
+    TRY_TO(TraverseStmt(TL.getTypePtr()->getColumnExpr()));
+  }
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
 DEF_TRAVERSE_TYPELOC(FunctionNoProtoType,
                      { TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); })
 
Index: clang/include/clang/AST/ASTContext.h
===================================================================
--- clang/include/clang/AST/ASTContext.h
+++ clang/include/clang/AST/ASTContext.h
@@ -187,6 +187,8 @@
       DependentAddressSpaceTypes;
   mutable llvm::FoldingSet<VectorType> VectorTypes;
   mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes;
+  mutable llvm::FoldingSet<MatrixType> MatrixTypes;
+  mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes;
   mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes;
   mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&>
     FunctionProtoTypes;
@@ -1380,6 +1382,21 @@
                                           Expr *SizeExpr,
                                           SourceLocation AttrLoc) const;
 
+  /// Return the unique reference to the matrix type of the specified element
+  /// type and dimensions (\p NumRows x \p NumColumns).
+  ///
+  /// \pre \p MatrixElementType must be a built-in type.
+  QualType getMatrixType(QualType MatrixElementType, unsigned NumRows,
+                         unsigned NumColumns) const;
+
+  /// Return the unique reference to the dependently-sized matrix type of the
+  /// specified element type, with dimensions \p RowExpr x \p ColumnExpr.
+  ///
+  /// \pre \p MatrixElementType must be a built-in type.
+  QualType getDependentSizedMatrixType(QualType MatrixElementType,
+                                       Expr *RowExpr, Expr *ColumnExpr,
+                                       SourceLocation AttrLoc) const;
+
   QualType getDependentAddressSpaceType(QualType PointeeType,
                                         Expr *AddrSpaceExpr,
                                         SourceLocation AttrLoc) const;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D72281: [Ma... Florian Hahn via Phabricator via cfe-commits

Reply via email to