fhahn updated this revision to Diff 267627.
fhahn added a comment.

Ping :)

Updated the patch to include the feedback from D76793 <https://reviews.llvm.org/D76793> (adding overloads, conversions, more targeted tests).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D76794/new/

https://reviews.llvm.org/D76794

Files:
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGExprScalar.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaOverload.cpp
  clang/test/CodeGen/matrix-type-operators.c
  clang/test/CodeGenCXX/matrix-type-operators.cpp
  clang/test/Sema/matrix-type-operators.c
  clang/test/SemaCXX/matrix-type-operators.cpp
  llvm/include/llvm/IR/MatrixBuilder.h
Index: llvm/include/llvm/IR/MatrixBuilder.h =================================================================== --- llvm/include/llvm/IR/MatrixBuilder.h +++ llvm/include/llvm/IR/MatrixBuilder.h @@ -33,6 +33,21 @@ IRBuilderTy &B; Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } + std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, + Value *RHS) { + assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && + "One of the operands must be a matrix (embedded in a vector)"); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast<VectorType>(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast<VectorType>(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return {LHS, RHS}; + } + public: MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} @@ -164,15 +179,13 @@ : B.CreateSub(LHS, RHS); } - /// Multiply matrix \p LHS with scalar \p RHS. + /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p + /// RHS. Value *CreateScalarMultiply(Value *LHS, Value *RHS) { - Value *ScalarVector = - B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(), - RHS, "scalar.splat"); - if (RHS->getType()->isFloatingPointTy()) - return B.CreateFMul(LHS, ScalarVector); - - return B.CreateMul(LHS, ScalarVector); + std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); + if (LHS->getType()->getScalarType()->isFloatingPointTy()) + return B.CreateFMul(LHS, RHS); + return B.CreateMul(LHS, RHS); } /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. 
Index: clang/test/SemaCXX/matrix-type-operators.cpp =================================================================== --- clang/test/SemaCXX/matrix-type-operators.cpp +++ clang/test/SemaCXX/matrix-type-operators.cpp @@ -204,3 +204,26 @@ a[2] = f; // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} } + +template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2> +typename MyMatrix<EltTy2, R2, C2>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) { + char *v1 = A.value * B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + + return A.value * B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} +} + +void test_multiply_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix<unsigned, 2, 2> Mat1; + MyMatrix<unsigned, 3, 3> Mat2; + MyMatrix<float, 2, 2> Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = multiply<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}} + // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 
2)))')}} + + Mat1.value = multiply<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}} +} Index: clang/test/Sema/matrix-type-operators.c =================================================================== --- clang/test/Sema/matrix-type-operators.c +++ clang/test/Sema/matrix-type-operators.c @@ -132,3 +132,9 @@ return &(*a)[0][1]; // expected-error@-1 {{address of matrix element requested}} } + +void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float scalar) { + // Shape of multiplication result does not match the type of b. + b = a * scalar; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} +} Index: clang/test/CodeGenCXX/matrix-type-operators.cpp =================================================================== --- clang/test/CodeGenCXX/matrix-type-operators.cpp +++ clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -157,6 +157,45 @@ m.value = w3 - m.value; } +template <typename EltTy0, unsigned R0, unsigned C0, unsigned C1> +typename MyMatrix<EltTy0, R0, C1>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, C0, C1> &B) { + return A.value * B.value; +} + +MyMatrix<float, 2, 2> test_multiply_template(MyMatrix<float, 2, 5> Mat1, + MyMatrix<float, 5, 2> Mat2) { + // CHECK-LABEL: define void @_Z22test_multiply_template8MyMatrixIfLj2ELj5EES_IfLj5ELj2EE( + // CHECK-NEXT: entry: + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.2* nonnull align 4 dereferenceable(40) %Mat2) + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %agg.result, 
i32 0, i32 0 + // CHECK-NEXT: [[VALUE_ADDR:%.*]] = bitcast [4 x float]* %value to <4 x float>* + // CHECK-NEXT: store <4 x float> [[RES]], <4 x float>* [[VALUE_ADDR]], align 4 + // CHECK-NEXT: ret void + // + // CHECK-LABEL: define linkonce_odr <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE( + // CHECK: [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK: [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @llvm.matrix.multiply.v4f32.v10f32.v10f32(<10 x float> [[MAT1]], <10 x float> [[MAT2]], i32 2, i32 5, i32 2) + // CHECK-NEXT: ret <4 x float> [[RES]] + + MyMatrix<float, 2, 2> Res; + Res.value = multiply(Mat1, Mat2); + return Res; +} + +void test_IntWrapper_Multiply(MyMatrix<double, 10, 9> &m, IntWrapper &w3) { + // CHECK-LABEL: define void @_Z24test_IntWrapper_MultiplyR8MyMatrixIdLj10ELj9EER10IntWrapper( + // CHECK: [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* {{.*}}) + // CHECK-NEXT: [[SCALAR_FP:%.*]] = sitofp i32 %call to double + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fmul <90 x double> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + // CHECK: ret void + m.value = w3 * m.value; +} + template <typename EltTy, unsigned Rows, unsigned Columns> void insert(MyMatrix<EltTy, Rows, Columns> &Mat, EltTy e, unsigned i, unsigned j) { Mat.value[i][j] = e; @@ -164,11 +203,11 @@ void test_insert_template1(MyMatrix<unsigned, 2, 2> &Mat, unsigned e, unsigned i, unsigned j) { // CHECK-LABEL: @_Z21test_insert_template1R8MyMatrixIjLj2ELj2EEjjj( - // 
CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load i32, i32* %e.addr, align 4 // CHECK-NEXT: [[I:%.*]] = load i32, i32* %i.addr, align 4 // CHECK-NEXT: [[J:%.*]] = load i32, i32* %j.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.1* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) + // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -190,9 +229,9 @@ void test_insert_template2(MyMatrix<float, 3, 8> &Mat, float e) { // CHECK-LABEL: @_Z21test_insert_template2R8MyMatrixIfLj3ELj8EEf( - // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.4*, %struct.MyMatrix.4** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load float, float* %e.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.2* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) + // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.4* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -220,7 +259,7 @@ int test_extract_template(MyMatrix<int, 2, 2> Mat1) { // CHECK-LABEL: @_Z21test_extract_template8MyMatrixIiLj2ELj2EE( // CHECK-NEXT: entry: - // CHECK-NEXT: [[CALL:%.*]] = call i32 
@_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) + // CHECK-NEXT: [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.5* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) // CHECK-NEXT: ret i32 [[CALL]] // // CHECK-LABEL: define linkonce_odr i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE( @@ -301,7 +340,7 @@ constexpr identmatrix_t identmatrix; void test_constexpr1(matrix_type<float, 4, 4> &m) { - // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef([16 x float]* nonnull align 4 dereferenceable(64) %m) #3 { + // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef( // CHECK: [[MAT:%.*]] = load <16 x float>, <16 x float>* {{.*}}, align 4 // CHECK-NEXT: [[IM:%.*]] = call <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK-NEXT: [[ADD:%.*]] = fadd <16 x float> [[MAT]], [[IM]] @@ -327,7 +366,7 @@ } void test_constexpr2(matrix_type<int, 5, 5> &m) { - // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei([25 x i32]* nonnull align 4 dereferenceable(100) %m) #4 { + // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei( // CHECK: [[IM:%.*]] = call <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK: [[MAT:%.*]] = load <25 x i32>, <25 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[SUB:%.*]] = sub <25 x i32> [[IM]], [[MAT]] Index: clang/test/CodeGen/matrix-type-operators.c =================================================================== --- clang/test/CodeGen/matrix-type-operators.c +++ clang/test/CodeGen/matrix-type-operators.c @@ -173,6 +173,134 @@ b = vulli + b; } +// Tests for matrix multiplication. 
+ +void multiply_matrix_matrix_double(dx5x5_t b, dx5x5_t c) { + // CHECK-LABEL: @multiply_matrix_matrix_double( + // CHECK: [[B:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[C:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[RES:%.*]] = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> [[B]], <25 x double> [[C]], i32 5, i32 5, i32 5) + // CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [25 x double]* %a to <25 x double>* + // CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* [[A_ADDR]], align 8 + // CHECK-NEXT: ret void + // + + dx5x5_t a; + a = b * c; +} + +typedef int ix3x9_t __attribute__((matrix_type(3, 9))); +typedef int ix9x9_t __attribute__((matrix_type(9, 9))); +// CHECK-LABEL: @multiply_matrix_matrix_int( +// CHECK: [[B:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[C:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[RES:%.*]] = call <81 x i32> @llvm.matrix.multiply.v81i32.v27i32.v27i32(<27 x i32> [[B]], <27 x i32> [[C]], i32 9, i32 3, i32 9) +// CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [81 x i32]* %a to <81 x i32>* +// CHECK-NEXT: store <81 x i32> [[RES]], <81 x i32>* [[A_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_matrix_matrix_int(ix9x3_t b, ix3x9_t c) { + ix9x9_t a; + a = b * c; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_float( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load float, float* %s.addr, align 4 +// CHECK-NEXT: [[S_EXT:%.*]] = fpext float [[S]] to double +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void 
+// +void multiply_double_matrix_scalar_float(dx5x5_t a, float s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_double( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void +// +void multiply_double_matrix_scalar_double(dx5x5_t a, double s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_float_matrix_scalar_double( +// CHECK: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = fptrunc double [[S]] to float +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <6 x float> undef, float [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <6 x float> [[VECINSERT]], <6 x float> undef, <6 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_scalar_double(fx2x3_t b, double s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_short( +// CHECK: [[S:%.*]] = load i16, i16* %s.addr, align 2 +// CHECK-NEXT: [[S_EXT:%.*]] = sext i16 [[S]] to i32 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> 
[[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_short(ix9x3_t b, short s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_ull( +// CHECK: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[S:%.*]] = load i64, i64* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = trunc i64 [[S]] to i32 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> [[MAT]], [[VECSPLAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_ull(ix9x3_t b, unsigned long long s) { + b = b * s; +} + +// CHECK-LABEL: @multiply_float_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [6 x float], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [6 x float]* [[A_ADDR]] to <6 x float>* +// CHECK-NEXT: store <6 x float> [[A:%.*]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[MAT]], <float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00> +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_constant(fx2x3_t a) { + a = a * 2.5; +} + +// CHECK-LABEL: @multiply_int_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [27 x i32], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [27 x i32]* [[A_ADDR]] to <27 x i32>* +// CHECK-NEXT: store <27 x i32> [[A:%.*]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, 
<27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> <i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5>, [[MAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_constant(ix9x3_t a) { + a = 5 * a; +} + // Tests for the matrix type operators. typedef double dx5x5_t __attribute__((matrix_type(5, 5))); Index: clang/lib/Sema/SemaOverload.cpp =================================================================== --- clang/lib/Sema/SemaOverload.cpp +++ clang/lib/Sema/SemaOverload.cpp @@ -9184,8 +9184,10 @@ case OO_Star: // '*' is either unary or binary if (Args.size() == 1) OpBuilder.addUnaryStarPointerOverloads(); - else + else { OpBuilder.addGenericBinaryArithmeticOverloads(); + OpBuilder.addMatrixBinaryArithmeticOverloads(); + } break; case OO_Slash: Index: clang/lib/Sema/SemaExpr.cpp =================================================================== --- clang/lib/Sema/SemaExpr.cpp +++ clang/lib/Sema/SemaExpr.cpp @@ -10047,6 +10047,9 @@ return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign, /*AllowBothBool*/getLangOpts().AltiVec, /*AllowBoolConversions*/false); + if (!IsDiv && (LHS.get()->getType()->isConstantMatrixType() || + RHS.get()->getType()->isConstantMatrixType())) + return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign); QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, IsCompAssign ? 
ACK_CompAssign : ACK_Arithmetic); @@ -12109,6 +12112,42 @@ return InvalidOperands(Loc, LHS, RHS); } +QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign) { + if (!IsCompAssign) { + LHS = DefaultFunctionArrayLvalueConversion(LHS.get()); + if (LHS.isInvalid()) + return QualType(); + } + RHS = DefaultFunctionArrayLvalueConversion(RHS.get()); + if (RHS.isInvalid()) + return QualType(); + + // For conversion purposes, we ignore any qualifiers. + // For example, "const float" and "float" are equivalent. + QualType LHSType = LHS.get()->getType().getUnqualifiedType(); + QualType RHSType = RHS.get()->getType().getUnqualifiedType(); + + auto *LHSMatType = LHSType->getAs<ConstantMatrixType>(); + auto *RHSMatType = RHSType->getAs<ConstantMatrixType>(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (LHSMatType && RHSMatType) { + if (LHSMatType->getNumColumns() != RHSMatType->getNumRows()) + return InvalidOperands(Loc, LHS, RHS); + + if (!Context.hasSameType(LHSMatType->getElementType(), + RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + + return Context.getConstantMatrixType(LHSMatType->getElementType(), + LHSMatType->getNumRows(), + RHSMatType->getNumColumns()); + } + return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign); +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { Index: clang/lib/CodeGen/CGExprScalar.cpp =================================================================== --- clang/lib/CodeGen/CGExprScalar.cpp +++ clang/lib/CodeGen/CGExprScalar.cpp @@ -765,6 +765,22 @@ } } + if (Ops.Ty->isConstantMatrixType()) { + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + // We need to check the types of the operands of the operator to get the + // correct matrix dimensions. 
+ auto *BO = cast<BinaryOperator>(Ops.E); + auto *LHSMatTy = dyn_cast<ConstantMatrixType>( + BO->getLHS()->getType().getCanonicalType()); + auto *RHSMatTy = dyn_cast<ConstantMatrixType>( + BO->getRHS()->getType().getCanonicalType()); + if (LHSMatTy && RHSMatTy) + return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(), + LHSMatTy->getNumColumns(), + RHSMatTy->getNumColumns()); + return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS); + } + if (Ops.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), Ops)) Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -11214,6 +11214,8 @@ QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, bool IsCompAssign); + QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, bool IsCompAssign); bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits