https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/118020
Backport 12cefcc7ecd2615069206b35b0ea81b9e78bb1ea Requested by: @fhahn >From c186e5d3e0e88ef5f285c0fd5f33ae826f7d9221 Mon Sep 17 00:00:00 2001 From: Florian Hahn <f...@fhahn.com> Date: Thu, 28 Nov 2024 16:11:39 +0000 Subject: [PATCH] [Matrix] Skip already fused instructions before trying to fuse multiply. lowerDotProduct called above may already lower a matrix multiply and mark it as processed by adding it to FusedInsts. Don't try to process it again in LowerMatrixMultiplyFused by checking if FusedInsts already contains it. Without this change, we trigger an assertion when trying to erase the same original matrix multiply twice. (cherry picked from commit 12cefcc7ecd2615069206b35b0ea81b9e78bb1ea) --- .../Scalar/LowerMatrixIntrinsics.cpp | 3 +- .../dot-product-int-also-fusable-multiply.ll | 49 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 6a681fd9339717..a44a123fdf8cda 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1014,7 +1014,8 @@ class LowerMatrixIntrinsics { // Third, try to fuse candidates. 
for (CallInst *CI : MaybeFusableInsts) - LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); + if (!FusedInsts.contains(CI)) + LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); Changed = !FusedInsts.empty(); diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll new file mode 100644 index 00000000000000..b78d56646d9e4d --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll @@ -0,0 +1,49 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s + +define void @test(ptr %p, <8 x i32> %x) { +; CHECK-LABEL: define void @test( +; CHECK-SAME: ptr [[P:%.*]], <8 x i32> [[X:%.*]]) { +; CHECK-NEXT: [[L:%.*]] = load <8 x i32>, ptr [[P]], align 4 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 1> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 2> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 3> +; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 4> +; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 5> +; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 6> +; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 7> +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <1 x i32> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x i32> poison, i32 [[TMP1]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <1 x i32> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[TMP4:%.*]] = 
insertelement <8 x i32> [[TMP2]], i32 [[TMP3]], i64 1 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <1 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = insertelement <8 x i32> [[TMP4]], i32 [[TMP5]], i64 2 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <1 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = insertelement <8 x i32> [[TMP6]], i32 [[TMP7]], i64 3 +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <1 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <8 x i32> [[TMP8]], i32 [[TMP9]], i64 4 +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <1 x i32> [[SPLIT5]], i64 0 +; CHECK-NEXT: [[TMP12:%.*]] = insertelement <8 x i32> [[TMP10]], i32 [[TMP11]], i64 5 +; CHECK-NEXT: [[TMP13:%.*]] = extractelement <1 x i32> [[SPLIT6]], i64 0 +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <8 x i32> [[TMP12]], i32 [[TMP13]], i64 6 +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <1 x i32> [[SPLIT7]], i64 0 +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x i32> [[TMP14]], i32 [[TMP15]], i64 7 +; CHECK-NEXT: [[TMP17:%.*]] = mul <8 x i32> [[L]], [[TMP16]] +; CHECK-NEXT: [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP17]]) +; CHECK-NEXT: [[TMP19:%.*]] = insertelement <1 x i32> poison, i32 [[TMP18]], i64 0 +; CHECK-NEXT: [[E:%.*]] = extractelement <1 x i32> [[TMP19]], i64 0 +; CHECK-NEXT: store i32 [[E]], ptr [[P]], align 4 +; CHECK-NEXT: ret void +; + %l = load <8 x i32>, ptr %p, align 4 + %t = tail call <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32> %x, i32 1, i32 8) + %m = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %l, <8 x i32> %t, i32 1, i32 8, i32 1) + %e = extractelement <1 x i32> %m, i64 0 + store i32 %e, ptr %p, align 4 + ret void +} + +declare <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32>, i32 immarg, i32 immarg) + +declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32 immarg, i32 immarg, i32 immarg) _______________________________________________ llvm-branch-commits mailing 
list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits