fhahn updated this revision to Diff 269027. fhahn marked 2 inline comments as done. fhahn added a comment.
In D76794#2078232 <https://reviews.llvm.org/D76794#2078232>, @rjmccall wrote:

> Your IRGen test cases cover a lot of ground, but please add more Sema test
> cases that go over the basics: element types matching, column/row counts
> matching, multiplication by inappropriate scalars, etc. Otherwise LGTM!

Oh, of course! I think some of those might have been accidentally dropped when moving patches around. I've added the additional Sema tests.

Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D76794/new/ https://reviews.llvm.org/D76794 Files: clang/include/clang/Sema/Sema.h clang/lib/CodeGen/CGExprScalar.cpp clang/lib/Sema/SemaExpr.cpp clang/lib/Sema/SemaOverload.cpp clang/test/CodeGen/matrix-type-operators.c clang/test/CodeGenCXX/matrix-type-operators.cpp clang/test/Sema/matrix-type-operators.c clang/test/SemaCXX/matrix-type-operators.cpp llvm/include/llvm/IR/MatrixBuilder.h
Index: llvm/include/llvm/IR/MatrixBuilder.h =================================================================== --- llvm/include/llvm/IR/MatrixBuilder.h +++ llvm/include/llvm/IR/MatrixBuilder.h @@ -33,6 +33,21 @@ IRBuilderTy &B; Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } + std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, + Value *RHS) { + assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && + "One of the operands must be a matrix (embedded in a vector)"); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast<VectorType>(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast<VectorType>(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return {LHS, RHS}; + } + public: MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} @@ -164,15 +179,13 @@ : B.CreateSub(LHS, RHS); } - /// Multiply matrix \p LHS with scalar \p RHS. + /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p + /// RHS. Value *CreateScalarMultiply(Value *LHS, Value *RHS) { - Value *ScalarVector = - B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(), - RHS, "scalar.splat"); - if (RHS->getType()->isFloatingPointTy()) - return B.CreateFMul(LHS, ScalarVector); - - return B.CreateMul(LHS, ScalarVector); + std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); + if (LHS->getType()->getScalarType()->isFloatingPointTy()) + return B.CreateFMul(LHS, RHS); + return B.CreateMul(LHS, RHS); } /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. 
Index: clang/test/SemaCXX/matrix-type-operators.cpp =================================================================== --- clang/test/SemaCXX/matrix-type-operators.cpp +++ clang/test/SemaCXX/matrix-type-operators.cpp @@ -65,6 +65,45 @@ // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}} } +template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2> +typename MyMatrix<EltTy2, R2, C2>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) { + char *v1 = A.value * B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} + + MyMatrix<int, 5, 6> m; + B.value = m.value * A.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 
'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}} + + return A.value * B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} +} + +void test_multiply_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix<unsigned, 2, 2> Mat1; + MyMatrix<unsigned, 3, 3> Mat2; + MyMatrix<float, 2, 2> Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = multiply<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}} + // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + + MyMatrix<unsigned, 3, 2> Mat4; + Mat1.value = multiply<unsigned, 3, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat4, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 3, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}} + + Mat1.value = multiply<float, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat3, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<float, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}} + + Mat4.value = Mat4.value * Mat1; + // expected-error@-1 {{no viable conversion from 'MyMatrix<unsigned int, 2, 2>' to 
'unsigned int'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 2, 2>')}} +} + struct UserT {}; struct StructWithC { Index: clang/test/Sema/matrix-type-operators.c =================================================================== --- clang/test/Sema/matrix-type-operators.c +++ clang/test/Sema/matrix-type-operators.c @@ -32,6 +32,44 @@ // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} } +typedef int ix10x5_t __attribute__((matrix_type(10, 5))); +typedef int ix10x10_t __attribute__((matrix_type(10, 10))); + +void matrix_matrix_multiply(sx10x10_t a, sx5x10_t b, ix10x5_t c, ix10x10_t d, float sf, char *p) { + // Check dimension mismatches. + a = a * b; + // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}} + b = a * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + // Check element type mismatches. + a = b * c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'ix10x5_t' (aka 'int __attribute__((matrix_type(10, 5)))'))}} + d = a * a; + // expected-error@-1 {{assigning to 'ix10x10_t' (aka 'int __attribute__((matrix_type(10, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + p = a * a; + // expected-error@-1 {{assigning to 'char *' from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} +} + +void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float sf, char *p) { + // Shape of multiplication result does not match the type of b. 
+ b = a * sf; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} + b = sf * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} + + a = a * p; + // expected-error@-1 {{casting 'char *' to incompatible type 'float'}} + // expected-error@-2 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'char *')}} + a = p * a; + // expected-error@-1 {{casting 'char *' to incompatible type 'float'}} + // expected-error@-2 {{invalid operands to binary expression ('char *' and 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))'))}} + + sf = a * sf; + // expected-error@-1 {{assigning to 'float' from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} +} + sx5x10_t get_matrix(); void insert(sx5x10_t a, float f) { Index: clang/test/CodeGenCXX/matrix-type-operators.cpp =================================================================== --- clang/test/CodeGenCXX/matrix-type-operators.cpp +++ clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -157,6 +157,45 @@ m.value = w3 - m.value; } +template <typename EltTy0, unsigned R0, unsigned C0, unsigned C1> +typename MyMatrix<EltTy0, R0, C1>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, C0, C1> &B) { + return A.value * B.value; +} + +MyMatrix<float, 2, 2> test_multiply_template(MyMatrix<float, 2, 5> Mat1, + MyMatrix<float, 5, 2> Mat2) { + // CHECK-LABEL: define void @_Z22test_multiply_template8MyMatrixIfLj2ELj5EES_IfLj5ELj2EE( + // CHECK-NEXT: entry: + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, 
%struct.MyMatrix.2* nonnull align 4 dereferenceable(40) %Mat2) + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %agg.result, i32 0, i32 0 + // CHECK-NEXT: [[VALUE_ADDR:%.*]] = bitcast [4 x float]* %value to <4 x float>* + // CHECK-NEXT: store <4 x float> [[RES]], <4 x float>* [[VALUE_ADDR]], align 4 + // CHECK-NEXT: ret void + // + // CHECK-LABEL: define linkonce_odr <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE( + // CHECK: [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK: [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @llvm.matrix.multiply.v4f32.v10f32.v10f32(<10 x float> [[MAT1]], <10 x float> [[MAT2]], i32 2, i32 5, i32 2) + // CHECK-NEXT: ret <4 x float> [[RES]] + + MyMatrix<float, 2, 2> Res; + Res.value = multiply(Mat1, Mat2); + return Res; +} + +void test_IntWrapper_Multiply(MyMatrix<double, 10, 9> &m, IntWrapper &w3) { + // CHECK-LABEL: define void @_Z24test_IntWrapper_MultiplyR8MyMatrixIdLj10ELj9EER10IntWrapper( + // CHECK: [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* {{.*}}) + // CHECK-NEXT: [[SCALAR_FP:%.*]] = sitofp i32 %call to double + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fmul <90 x double> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + // CHECK: ret void + m.value = w3 * m.value; +} + template <typename EltTy, unsigned Rows, unsigned Columns> void insert(MyMatrix<EltTy, Rows, Columns> &Mat, EltTy e, unsigned i, unsigned j) { Mat.value[i][j] = e; @@ -164,11 +203,11 @@ void 
test_insert_template1(MyMatrix<unsigned, 2, 2> &Mat, unsigned e, unsigned i, unsigned j) { // CHECK-LABEL: @_Z21test_insert_template1R8MyMatrixIjLj2ELj2EEjjj( - // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load i32, i32* %e.addr, align 4 // CHECK-NEXT: [[I:%.*]] = load i32, i32* %i.addr, align 4 // CHECK-NEXT: [[J:%.*]] = load i32, i32* %j.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.1* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) + // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -190,9 +229,9 @@ void test_insert_template2(MyMatrix<float, 3, 8> &Mat, float e) { // CHECK-LABEL: @_Z21test_insert_template2R8MyMatrixIfLj3ELj8EEf( - // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.4*, %struct.MyMatrix.4** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load float, float* %e.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.2* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) + // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.4* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -220,7 +259,7 @@ int test_extract_template(MyMatrix<int, 2, 2> Mat1) { // CHECK-LABEL: 
@_Z21test_extract_template8MyMatrixIiLj2ELj2EE( // CHECK-NEXT: entry: - // CHECK-NEXT: [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) + // CHECK-NEXT: [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.5* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) // CHECK-NEXT: ret i32 [[CALL]] // // CHECK-LABEL: define linkonce_odr i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE( @@ -301,7 +340,7 @@ constexpr identmatrix_t identmatrix; void test_constexpr1(matrix_type<float, 4, 4> &m) { - // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef([16 x float]* nonnull align 4 dereferenceable(64) %m) #3 { + // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef( // CHECK: [[MAT:%.*]] = load <16 x float>, <16 x float>* {{.*}}, align 4 // CHECK-NEXT: [[IM:%.*]] = call <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK-NEXT: [[ADD:%.*]] = fadd <16 x float> [[MAT]], [[IM]] @@ -327,7 +366,7 @@ } void test_constexpr2(matrix_type<int, 5, 5> &m) { - // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei([25 x i32]* nonnull align 4 dereferenceable(100) %m) #4 { + // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei( // CHECK: [[IM:%.*]] = call <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK: [[MAT:%.*]] = load <25 x i32>, <25 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[SUB:%.*]] = sub <25 x i32> [[IM]], [[MAT]] Index: clang/test/CodeGen/matrix-type-operators.c =================================================================== --- clang/test/CodeGen/matrix-type-operators.c +++ clang/test/CodeGen/matrix-type-operators.c @@ -173,6 +173,134 @@ b = vulli + b; } +// Tests for matrix multiplication. 
+ +void multiply_matrix_matrix_double(dx5x5_t b, dx5x5_t c) { + // CHECK-LABEL: @multiply_matrix_matrix_double( + // CHECK: [[B:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[C:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[RES:%.*]] = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> [[B]], <25 x double> [[C]], i32 5, i32 5, i32 5) + // CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [25 x double]* %a to <25 x double>* + // CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* [[A_ADDR]], align 8 + // CHECK-NEXT: ret void + // + + dx5x5_t a; + a = b * c; +} + +typedef int ix3x9_t __attribute__((matrix_type(3, 9))); +typedef int ix9x9_t __attribute__((matrix_type(9, 9))); +// CHECK-LABEL: @multiply_matrix_matrix_int( +// CHECK: [[B:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[C:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[RES:%.*]] = call <81 x i32> @llvm.matrix.multiply.v81i32.v27i32.v27i32(<27 x i32> [[B]], <27 x i32> [[C]], i32 9, i32 3, i32 9) +// CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [81 x i32]* %a to <81 x i32>* +// CHECK-NEXT: store <81 x i32> [[RES]], <81 x i32>* [[A_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_matrix_matrix_int(ix9x3_t b, ix3x9_t c) { + ix9x9_t a; + a = b * c; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_float( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load float, float* %s.addr, align 4 +// CHECK-NEXT: [[S_EXT:%.*]] = fpext float [[S]] to double +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void 
+// +void multiply_double_matrix_scalar_float(dx5x5_t a, float s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_double( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void +// +void multiply_double_matrix_scalar_double(dx5x5_t a, double s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_float_matrix_scalar_double( +// CHECK: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = fptrunc double [[S]] to float +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <6 x float> undef, float [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <6 x float> [[VECINSERT]], <6 x float> undef, <6 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_scalar_double(fx2x3_t b, double s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_short( +// CHECK: [[S:%.*]] = load i16, i16* %s.addr, align 2 +// CHECK-NEXT: [[S_EXT:%.*]] = sext i16 [[S]] to i32 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> 
[[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_short(ix9x3_t b, short s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_ull( +// CHECK: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[S:%.*]] = load i64, i64* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = trunc i64 [[S]] to i32 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> [[MAT]], [[VECSPLAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_ull(ix9x3_t b, unsigned long long s) { + b = b * s; +} + +// CHECK-LABEL: @multiply_float_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [6 x float], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [6 x float]* [[A_ADDR]] to <6 x float>* +// CHECK-NEXT: store <6 x float> [[A:%.*]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[MAT]], <float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00> +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_constant(fx2x3_t a) { + a = a * 2.5; +} + +// CHECK-LABEL: @multiply_int_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [27 x i32], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [27 x i32]* [[A_ADDR]] to <27 x i32>* +// CHECK-NEXT: store <27 x i32> [[A:%.*]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, 
<27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> <i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5>, [[MAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_constant(ix9x3_t a) { + a = 5 * a; +} + // Tests for the matrix type operators. typedef double dx5x5_t __attribute__((matrix_type(5, 5))); Index: clang/lib/Sema/SemaOverload.cpp =================================================================== --- clang/lib/Sema/SemaOverload.cpp +++ clang/lib/Sema/SemaOverload.cpp @@ -9167,8 +9167,10 @@ case OO_Star: // '*' is either unary or binary if (Args.size() == 1) OpBuilder.addUnaryStarPointerOverloads(); - else + else { OpBuilder.addGenericBinaryArithmeticOverloads(); + OpBuilder.addMatrixBinaryArithmeticOverloads(); + } break; case OO_Slash: Index: clang/lib/Sema/SemaExpr.cpp =================================================================== --- clang/lib/Sema/SemaExpr.cpp +++ clang/lib/Sema/SemaExpr.cpp @@ -10058,6 +10058,9 @@ return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign, /*AllowBothBool*/getLangOpts().AltiVec, /*AllowBoolConversions*/false); + if (!IsDiv && (LHS.get()->getType()->isConstantMatrixType() || + RHS.get()->getType()->isConstantMatrixType())) + return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign); QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, IsCompAssign ? 
ACK_CompAssign : ACK_Arithmetic); @@ -12120,6 +12123,37 @@ return InvalidOperands(Loc, LHS, RHS); } +QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign) { + if (!IsCompAssign) { + LHS = DefaultFunctionArrayLvalueConversion(LHS.get()); + if (LHS.isInvalid()) + return QualType(); + } + RHS = DefaultFunctionArrayLvalueConversion(RHS.get()); + if (RHS.isInvalid()) + return QualType(); + + auto *LHSMatType = LHS.get()->getType()->getAs<ConstantMatrixType>(); + auto *RHSMatType = RHS.get()->getType()->getAs<ConstantMatrixType>(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (LHSMatType && RHSMatType) { + if (LHSMatType->getNumColumns() != RHSMatType->getNumRows()) + return InvalidOperands(Loc, LHS, RHS); + + if (!Context.hasSameType(LHSMatType->getElementType(), + RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + + return Context.getConstantMatrixType(LHSMatType->getElementType(), + LHSMatType->getNumRows(), + RHSMatType->getNumColumns()); + } + return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign); +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { Index: clang/lib/CodeGen/CGExprScalar.cpp =================================================================== --- clang/lib/CodeGen/CGExprScalar.cpp +++ clang/lib/CodeGen/CGExprScalar.cpp @@ -741,6 +741,22 @@ } } + if (Ops.Ty->isConstantMatrixType()) { + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + // We need to check the types of the operands of the operator to get the + // correct matrix dimensions. 
+ auto *BO = cast<BinaryOperator>(Ops.E); + auto *LHSMatTy = dyn_cast<ConstantMatrixType>( + BO->getLHS()->getType().getCanonicalType()); + auto *RHSMatTy = dyn_cast<ConstantMatrixType>( + BO->getRHS()->getType().getCanonicalType()); + if (LHSMatTy && RHSMatTy) + return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(), + LHSMatTy->getNumColumns(), + RHSMatTy->getNumColumns()); + return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS); + } + if (Ops.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), Ops)) Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -11218,6 +11218,8 @@ QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, bool IsCompAssign); + QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, bool IsCompAssign); bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType);
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits