https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/122839
>From 67d9c51dbd0cf78f4cf622f655adb76d519b774b Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <farzonlo...@microsoft.com> Date: Thu, 9 Jan 2025 19:19:27 -0500 Subject: [PATCH 1/3] [SPIRV] add pre legalization instruction combine - Add the boilerplate to support instcombine in SPIRV - instcombine length(X-Y) to distance(X,Y) - switch HLSL's distance intrinsic to not special case for SPIRV. - fixes #122766 --- clang/include/clang/Basic/BuiltinsSPIRV.td | 6 + clang/lib/CodeGen/CGBuiltin.cpp | 10 + clang/lib/Headers/hlsl/hlsl_detail.h | 8 +- clang/lib/Sema/SemaSPIRV.cpp | 18 ++ clang/test/CodeGenHLSL/builtins/distance.hlsl | 30 ++- clang/test/CodeGenHLSL/builtins/length.hlsl | 95 +++++-- clang/test/CodeGenSPIRV/Builtins/length.c | 31 +++ clang/test/SemaSPIRV/BuiltIns/length-errors.c | 25 ++ llvm/lib/Target/SPIRV/CMakeLists.txt | 3 + llvm/lib/Target/SPIRV/SPIRV.h | 2 + llvm/lib/Target/SPIRV/SPIRV.td | 1 + llvm/lib/Target/SPIRV/SPIRVCombine.td | 26 ++ .../SPIRV/SPIRVPreLegalizerCombiner.cpp | 252 ++++++++++++++++++ llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp | 2 + .../CodeGen/SPIRV/hlsl-intrinsics/distance.ll | 77 +++--- llvm/test/CodeGen/SPIRV/opencl/distance.ll | 11 + 16 files changed, 525 insertions(+), 72 deletions(-) create mode 100644 clang/test/CodeGenSPIRV/Builtins/length.c create mode 100644 clang/test/SemaSPIRV/BuiltIns/length-errors.c create mode 100644 llvm/lib/Target/SPIRV/SPIRVCombine.td create mode 100644 llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td index 1e66939b822ef8..f72c555921dfe6 100644 --- a/clang/include/clang/Basic/BuiltinsSPIRV.td +++ b/clang/include/clang/Basic/BuiltinsSPIRV.td @@ -13,3 +13,9 @@ def SPIRVDistance : Builtin { let Attributes = [NoThrow, Const]; let Prototype = "void(...)"; } + +def SPIRVLength : Builtin { + let Spellings = ["__builtin_spirv_length"]; + let Attributes = [NoThrow, Const]; + let Prototype = "void(...)"; +} diff 
--git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 2385f2a320b625..b80833fd91884d 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -20528,6 +20528,16 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID, /*ReturnType=*/X->getType()->getScalarType(), Intrinsic::spv_distance, ArrayRef<Value *>{X, Y}, nullptr, "spv.distance"); } + case SPIRV::BI__builtin_spirv_length: { + Value *X = EmitScalarExpr(E->getArg(0)); + assert(E->getArg(0)->getType()->hasFloatingRepresentation() && + "length operand must have a float representation"); + assert(E->getArg(0)->getType()->isVectorType() && + "length operand must be a vector"); + return Builder.CreateIntrinsic( + /*ReturnType=*/X->getType()->getScalarType(), Intrinsic::spv_length, + ArrayRef<Value *>{X}, nullptr, "spv.length"); + } } return nullptr; } diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h index 3eb4a3dc861e36..b2c8cc6c5c3dbb 100644 --- a/clang/lib/Headers/hlsl/hlsl_detail.h +++ b/clang/lib/Headers/hlsl/hlsl_detail.h @@ -61,7 +61,11 @@ length_impl(T X) { template <typename T, int N> constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T> length_vec_impl(vector<T, N> X) { +#if (__has_builtin(__builtin_spirv_length)) + return __builtin_spirv_length(X); +#else return __builtin_elementwise_sqrt(__builtin_hlsl_dot(X, X)); +#endif } template <typename T> @@ -73,11 +77,7 @@ distance_impl(T X, T Y) { template <typename T, int N> constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T> distance_vec_impl(vector<T, N> X, vector<T, N> Y) { -#if (__has_builtin(__builtin_spirv_distance)) - return __builtin_spirv_distance(X, Y); -#else return length_vec_impl(X - Y); -#endif } } // namespace __detail } // namespace hlsl diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp index d2de64826c6eb3..dc49fc79073572 100644 --- a/clang/lib/Sema/SemaSPIRV.cpp +++ 
b/clang/lib/Sema/SemaSPIRV.cpp @@ -51,6 +51,24 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, TheCall->setType(RetTy); break; } + case SPIRV::BI__builtin_spirv_length: { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + ExprResult A = TheCall->getArg(0); + QualType ArgTyA = A.get()->getType(); + auto *VTy = ArgTyA->getAs<VectorType>(); + if (VTy == nullptr) { + SemaRef.Diag(A.get()->getBeginLoc(), + diag::err_typecheck_convert_incompatible) + << ArgTyA + << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 + << 0 << 0; + return true; + } + QualType RetTy = VTy->getElementType(); + TheCall->setType(RetTy); + break; + } } return false; } diff --git a/clang/test/CodeGenHLSL/builtins/distance.hlsl b/clang/test/CodeGenHLSL/builtins/distance.hlsl index 6952700a87f1df..e830903261c8cf 100644 --- a/clang/test/CodeGenHLSL/builtins/distance.hlsl +++ b/clang/test/CodeGenHLSL/builtins/distance.hlsl @@ -33,8 +33,9 @@ half test_distance_half(half X, half Y) { return distance(X, Y); } // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z19test_distance_half2Dv2_DhS_( // SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[X:%.*]], <2 x half> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] { // SPVCHECK-NEXT: [[ENTRY:.*:]] -// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.distance.v2f16(<2 x half> [[X]], <2 x half> [[Y]]) -// SPVCHECK-NEXT: ret half [[SPV_DISTANCE_I]] +// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[X]], [[Y]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v2f16(<2 x half> [[SUB_I]]) +// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]] // half test_distance_half2(half2 X, half2 Y) { return distance(X, Y); } @@ -49,8 +50,9 @@ half test_distance_half2(half2 X, half2 Y) { return distance(X, Y); } // SPVCHECK-LABEL: define spir_func 
noundef nofpclass(nan inf) half @_Z19test_distance_half3Dv3_DhS_( // SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[X:%.*]], <3 x half> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] { // SPVCHECK-NEXT: [[ENTRY:.*:]] -// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.distance.v3f16(<3 x half> [[X]], <3 x half> [[Y]]) -// SPVCHECK-NEXT: ret half [[SPV_DISTANCE_I]] +// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[X]], [[Y]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v3f16(<3 x half> [[SUB_I]]) +// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]] // half test_distance_half3(half3 X, half3 Y) { return distance(X, Y); } @@ -65,8 +67,9 @@ half test_distance_half3(half3 X, half3 Y) { return distance(X, Y); } // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z19test_distance_half4Dv4_DhS_( // SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[X:%.*]], <4 x half> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] { // SPVCHECK-NEXT: [[ENTRY:.*:]] -// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.distance.v4f16(<4 x half> [[X]], <4 x half> [[Y]]) -// SPVCHECK-NEXT: ret half [[SPV_DISTANCE_I]] +// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> [[X]], [[Y]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v4f16(<4 x half> [[SUB_I]]) +// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]] // half test_distance_half4(half4 X, half4 Y) { return distance(X, Y); } @@ -97,8 +100,9 @@ float test_distance_float(float X, float Y) { return distance(X, Y); } // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z20test_distance_float2Dv2_fS_( // SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[X:%.*]], <2 
x float> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] { // SPVCHECK-NEXT: [[ENTRY:.*:]] -// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.distance.v2f32(<2 x float> [[X]], <2 x float> [[Y]]) -// SPVCHECK-NEXT: ret float [[SPV_DISTANCE_I]] +// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[X]], [[Y]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v2f32(<2 x float> [[SUB_I]]) +// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]] // float test_distance_float2(float2 X, float2 Y) { return distance(X, Y); } @@ -113,8 +117,9 @@ float test_distance_float2(float2 X, float2 Y) { return distance(X, Y); } // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z20test_distance_float3Dv3_fS_( // SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[X:%.*]], <3 x float> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] { // SPVCHECK-NEXT: [[ENTRY:.*:]] -// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.distance.v3f32(<3 x float> [[X]], <3 x float> [[Y]]) -// SPVCHECK-NEXT: ret float [[SPV_DISTANCE_I]] +// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[X]], [[Y]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v3f32(<3 x float> [[SUB_I]]) +// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]] // float test_distance_float3(float3 X, float3 Y) { return distance(X, Y); } @@ -129,7 +134,8 @@ float test_distance_float3(float3 X, float3 Y) { return distance(X, Y); } // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z20test_distance_float4Dv4_fS_( // SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[X:%.*]], <4 x float> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] { // SPVCHECK-NEXT: 
[[ENTRY:.*:]] -// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.distance.v4f32(<4 x float> [[X]], <4 x float> [[Y]]) -// SPVCHECK-NEXT: ret float [[SPV_DISTANCE_I]] +// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[X]], [[Y]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v4f32(<4 x float> [[SUB_I]]) +// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]] // float test_distance_float4(float4 X, float4 Y) { return distance(X, Y); } diff --git a/clang/test/CodeGenHLSL/builtins/length.hlsl b/clang/test/CodeGenHLSL/builtins/length.hlsl index fcf3ee76ba5bbd..2d4bbd995298f2 100644 --- a/clang/test/CodeGenHLSL/builtins/length.hlsl +++ b/clang/test/CodeGenHLSL/builtins/length.hlsl @@ -1,114 +1,163 @@ -// RUN: %clang_cc1 -finclude-default-header -triple \ -// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ -// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,DXCHECK \ -// RUN: -DTARGET=dx +// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ +// RUN: -emit-llvm -O1 -o - | FileCheck %s // RUN: %clang_cc1 -finclude-default-header -triple \ // RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \ -// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,SPVCHECK \ -// RUN: -DTARGET=spv +// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK -// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z16test_length_halfDh( // DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z16test_length_halfDh( +// + +// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z16test_length_halfDh( // CHECK-SAME: half noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] { // 
CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.fabs.f16(half [[P0]]) // CHECK-NEXT: ret half [[ELT_ABS_I]] // - +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z16test_length_halfDh( +// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.fabs.f16(half [[P0]]) +// SPVCHECK-NEXT: ret half [[ELT_ABS_I]] +// half test_length_half(half p0) { return length(p0); } -// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh( // DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh( +// + + +// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh( // CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] -// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.[[TARGET]].fdot.v2f16(<2 x half> [[P0]], <2 x half> [[P0]]) +// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[P0]], <2 x half> [[P0]]) // CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.sqrt.f16(half [[HLSL_DOT_I]]) // CHECK-NEXT: ret half [[TMP0]] // - - +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh( +// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v2f16(<2 x half> [[P0]]) +// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]] +// half test_length_half2(half2 p0) { return length(p0); } -// SPVCHECK-LABEL: define 
spir_func noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh( // DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh( +// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh( // CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] -// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.[[TARGET]].fdot.v3f16(<3 x half> [[P0]], <3 x half> [[P0]]) +// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[P0]], <3 x half> [[P0]]) // CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.sqrt.f16(half [[HLSL_DOT_I]]) // CHECK-NEXT: ret half [[TMP0]] // +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh( +// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v3f16(<3 x half> [[P0]]) +// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]] +// half test_length_half3(half3 p0) { return length(p0); } -// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh( // DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh( +// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh( // CHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] -// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.[[TARGET]].fdot.v4f16(<4 x half> [[P0]], <4 x half> [[P0]]) +// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> [[P0]], <4 x half> [[P0]]) // CHECK-NEXT: 
[[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.sqrt.f16(half [[HLSL_DOT_I]]) // CHECK-NEXT: ret half [[TMP0]] // +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh( +// SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v4f16(<4 x half> [[P0]]) +// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]] +// half test_length_half4(half4 p0) { return length(p0); } -// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z17test_length_floatf( // DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z17test_length_floatf( +// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z17test_length_floatf( // CHECK-SAME: float noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.fabs.f32(float [[P0]]) // CHECK-NEXT: ret float [[ELT_ABS_I]] // +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z17test_length_floatf( +// SPVCHECK-SAME: float noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.fabs.f32(float [[P0]]) +// SPVCHECK-NEXT: ret float [[ELT_ABS_I]] +// float test_length_float(float p0) { return length(p0); } -// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f( // DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f( +// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f( // CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { // CHECK-NEXT: 
[[ENTRY:.*:]] -// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.[[TARGET]].fdot.v2f32(<2 x float> [[P0]], <2 x float> [[P0]]) +// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> [[P0]], <2 x float> [[P0]]) // CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.sqrt.f32(float [[HLSL_DOT_I]]) // CHECK-NEXT: ret float [[TMP0]] // +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f( +// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v2f32(<2 x float> [[P0]]) +// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]] +// float test_length_float2(float2 p0) { return length(p0); } -// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f( // DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f( +// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f( // CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] -// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.[[TARGET]].fdot.v3f32(<3 x float> [[P0]], <3 x float> [[P0]]) +// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> [[P0]], <3 x float> [[P0]]) // CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.sqrt.f32(float [[HLSL_DOT_I]]) // CHECK-NEXT: ret float [[TMP0]] // +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f( +// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr 
#[[ATTR0]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v3f32(<3 x float> [[P0]]) +// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]] +// float test_length_float3(float3 p0) { return length(p0); } -// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f( // DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f( +// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f( // CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] -// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.[[TARGET]].fdot.v4f32(<4 x float> [[P0]], <4 x float> [[P0]]) +// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> [[P0]], <4 x float> [[P0]]) // CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.sqrt.f32(float [[HLSL_DOT_I]]) // CHECK-NEXT: ret float [[TMP0]] // +// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f( +// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] { +// SPVCHECK-NEXT: [[ENTRY:.*:]] +// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v4f32(<4 x float> [[P0]]) +// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]] +// float test_length_float4(float4 p0) { return length(p0); diff --git a/clang/test/CodeGenSPIRV/Builtins/length.c b/clang/test/CodeGenSPIRV/Builtins/length.c new file mode 100644 index 00000000000000..59e7c298dd8167 --- /dev/null +++ b/clang/test/CodeGenSPIRV/Builtins/length.c @@ -0,0 +1,31 @@ +// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 + +// RUN: %clang_cc1 -O1 
-triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s + +typedef float float2 __attribute__((ext_vector_type(2))); +typedef float float3 __attribute__((ext_vector_type(3))); +typedef float float4 __attribute__((ext_vector_type(4))); + +// CHECK-LABEL: define spir_func float @test_length_float2( +// CHECK-SAME: <2 x float> noundef [[X:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[SPV_LENGTH:%.*]] = tail call float @llvm.spv.length.v2f32(<2 x float> [[X]]) +// CHECK-NEXT: ret float [[SPV_LENGTH]] +// +float test_length_float2(float2 X) { return __builtin_spirv_length(X); } + +// CHECK-LABEL: define spir_func float @test_length_float3( +// CHECK-SAME: <3 x float> noundef [[X:%.*]]) local_unnamed_addr #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[SPV_LENGTH:%.*]] = tail call float @llvm.spv.length.v3f32(<3 x float> [[X]]) +// CHECK-NEXT: ret float [[SPV_LENGTH]] +// +float test_length_float3(float3 X) { return __builtin_spirv_length(X); } + +// CHECK-LABEL: define spir_func float @test_length_float4( +// CHECK-SAME: <4 x float> noundef [[X:%.*]]) local_unnamed_addr #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[SPV_LENGTH:%.*]] = tail call float @llvm.spv.length.v4f32(<4 x float> [[X]]) +// CHECK-NEXT: ret float [[SPV_LENGTH]] +// +float test_length_float4(float4 X) { return __builtin_spirv_length(X); } diff --git a/clang/test/SemaSPIRV/BuiltIns/length-errors.c b/clang/test/SemaSPIRV/BuiltIns/length-errors.c new file mode 100644 index 00000000000000..3244bd6737f116 --- /dev/null +++ b/clang/test/SemaSPIRV/BuiltIns/length-errors.c @@ -0,0 +1,25 @@ +// RUN: %clang_cc1 %s -triple spirv-pc-vulkan-compute -verify + +typedef float float2 __attribute__((ext_vector_type(2))); + +void test_too_few_arg() +{ + return __builtin_spirv_length(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +void test_too_many_arg(float2 p0) +{ + return 
__builtin_spirv_length(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +float test_double_scalar_inputs(double p0) { + return __builtin_spirv_length(p0); + // expected-error@-1 {{passing 'double' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(double)))) double' (vector of 2 'double' values)}} +} + +float test_int_scalar_inputs(int p0) { + return __builtin_spirv_length(p0); + // expected-error@-1 {{passing 'int' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(int)))) int' (vector of 2 'int' values)}} +} diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index a79e19fcd753dc..efdd8c8d24fbd5 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -10,6 +10,8 @@ tablegen(LLVM SPIRVGenRegisterBank.inc -gen-register-bank) tablegen(LLVM SPIRVGenRegisterInfo.inc -gen-register-info) tablegen(LLVM SPIRVGenSubtargetInfo.inc -gen-subtarget) tablegen(LLVM SPIRVGenTables.inc -gen-searchable-tables) +tablegen(LLVM SPIRVGenPreLegalizeGICombiner.inc -gen-global-isel-combiner + -combiners="SPIRVPreLegalizerCombiner") add_public_tablegen_target(SPIRVCommonTableGen) @@ -33,6 +35,7 @@ add_llvm_target(SPIRVCodeGen SPIRVModuleAnalysis.cpp SPIRVStructurizer.cpp SPIRVPreLegalizer.cpp + SPIRVPreLegalizerCombiner.cpp SPIRVPostLegalizer.cpp SPIRVPrepareFunctions.cpp SPIRVRegisterBankInfo.cpp diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index 81b57202644256..6d00a046ff7caa 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -24,6 +24,7 @@ FunctionPass *createSPIRVStructurizerPass(); FunctionPass *createSPIRVMergeRegionExitTargetsPass(); FunctionPass *createSPIRVStripConvergenceIntrinsicsPass(); FunctionPass *createSPIRVRegularizerPass(); +FunctionPass *createSPIRVPreLegalizerCombiner(); FunctionPass *createSPIRVPreLegalizerPass(); FunctionPass 
*createSPIRVPostLegalizerPass(); ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM); @@ -36,6 +37,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM, void initializeSPIRVModuleAnalysisPass(PassRegistry &); void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &); void initializeSPIRVPreLegalizerPass(PassRegistry &); +void initializeSPIRVPreLegalizerCombinerPass(PassRegistry &); void initializeSPIRVPostLegalizerPass(PassRegistry &); void initializeSPIRVStructurizerPass(PassRegistry &); void initializeSPIRVEmitIntrinsicsPass(PassRegistry &); diff --git a/llvm/lib/Target/SPIRV/SPIRV.td b/llvm/lib/Target/SPIRV/SPIRV.td index 108c7e6d3861f0..39a4131c7f1bdf 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.td +++ b/llvm/lib/Target/SPIRV/SPIRV.td @@ -11,6 +11,7 @@ include "llvm/Target/Target.td" include "SPIRVRegisterInfo.td" include "SPIRVRegisterBanks.td" include "SPIRVInstrInfo.td" +include "SPIRVCombine.td" include "SPIRVBuiltins.td" def SPIRVInstrInfo : InstrInfo; diff --git a/llvm/lib/Target/SPIRV/SPIRVCombine.td b/llvm/lib/Target/SPIRV/SPIRVCombine.td new file mode 100644 index 00000000000000..11851894e2f752 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVCombine.td @@ -0,0 +1,26 @@ +//=- SPIRVCombine.td - Define SPIRV Combine Rules -------------*-tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+include "llvm/Target/GlobalISel/Combine.td"
+
+
+def vector_length_sub_to_distance_lowering : GICombineRule <
+  (defs root:$root),
+  (match (wip_match_opcode G_INTRINSIC):$root,
+          [{ return matchLengthToDistance(*${root}, MRI); }]),
+  (apply [{ applySPIRVDistance(*${root}, MRI, B); }])
+>;
+
+def SPIRVPreLegalizerCombiner
+    : GICombiner<"SPIRVPreLegalizerCombinerImpl",
+                       [vector_length_sub_to_distance_lowering]> {
+    let CombineAllMethodName = "tryCombineAllImpl";
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
new file mode 100644
index 00000000000000..54b65e8b04d622
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -0,0 +1,283 @@
+//===-- SPIRVPreLegalizerCombiner.cpp - combine legalization ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass does combining of machine instructions at the generic MI level,
+// before the legalizer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
+#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
+#include "llvm/CodeGen/GlobalISel/Combiner.h"
+#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
+#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
+#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
+#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
+#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
+
+#define GET_GICOMBINER_DEPS
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_DEPS
+
+#define DEBUG_TYPE "spirv-prelegalizer-combiner"
+
+using namespace llvm;
+using namespace MIPatternMatch;
+
+namespace {
+
+#define GET_GICOMBINER_TYPES
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_TYPES
+
+/// Return true if \p UseMI only annotates a value and can be dropped together
+/// with it: a debug instruction or an `spv_assign_name` marker.
+bool isDroppableUser(const MachineInstr &UseMI) {
+  if (UseMI.isDebugInstr())
+    return true;
+  const auto *Intr = dyn_cast<GIntrinsic>(&UseMI);
+  return Intr && Intr->getIntrinsicID() == Intrinsic::spv_assign_name;
+}
+
+/// Return true if every user of \p Reg is either \p Expected or droppable.
+bool isOnlyUsedBy(Register Reg, const MachineInstr &Expected,
+                  MachineRegisterInfo &MRI) {
+  return llvm::all_of(MRI.use_instructions(Reg),
+                      [&](const MachineInstr &UseMI) {
+                        return &UseMI == &Expected || isDroppableUser(UseMI);
+                      });
+}
+
+/// Match `spv_length(ASSIGN_TYPE(G_FSUB X, Y))`, i.e. length(X - Y), so that
+/// it can be rewritten as `spv_distance(X, Y)`.
+///
+/// applySPIRVDistance() deletes the ASSIGN_TYPE and the G_FSUB, so the match
+/// is rejected when the difference has any real user besides this chain;
+/// deleting those users would miscompile the function.
+bool matchLengthToDistance(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC ||
+      cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_length)
+    return false;
+
+  // Operand 0 is the result and operand 1 is the intrinsic ID, so the vector
+  // argument of spv_length is operand 2.
+  Register SubAssignTypeReg = MI.getOperand(2).getReg();
+  MachineInstr *SubAssignTypeInst = MRI.getVRegDef(SubAssignTypeReg);
+  if (!SubAssignTypeInst ||
+      SubAssignTypeInst->getOpcode() != SPIRV::ASSIGN_TYPE ||
+      !isOnlyUsedBy(SubAssignTypeReg, MI, MRI))
+    return false;
+
+  Register SubDestReg = SubAssignTypeInst->getOperand(1).getReg();
+  MachineInstr *SubInstr = MRI.getVRegDef(SubDestReg);
+  return SubInstr && SubInstr->getOpcode() == TargetOpcode::G_FSUB &&
+         isOnlyUsedBy(SubDestReg, *SubAssignTypeInst, MRI);
+}
+
+/// Replace the `spv_length` matched by matchLengthToDistance() with
+/// `spv_distance(X, Y)` and erase the ASSIGN_TYPE and G_FSUB that fed it.
+void applySPIRVDistance(MachineInstr &MI, MachineRegisterInfo &MRI,
+                        MachineIRBuilder &B) {
+  // Walk back to X and Y through the chain validated by the matcher.
+  Register SubAssignTypeReg = MI.getOperand(2).getReg();
+  MachineInstr *SubAssignTypeInst = MRI.getVRegDef(SubAssignTypeReg);
+  Register SubDestReg = SubAssignTypeInst->getOperand(1).getReg();
+  MachineInstr *SubInstr = MRI.getVRegDef(SubDestReg);
+  Register SubOperand1 = SubInstr->getOperand(1).getReg();
+  Register SubOperand2 = SubInstr->getOperand(2).getReg();
+
+  // Build `spv_distance` in front of `spv_length`, reusing its result
+  // register so every user of the length now reads the distance.
+  BuildMI(*MI.getParent(), MI.getIterator(), MI.getDebugLoc(),
+          B.getTII().get(TargetOpcode::G_INTRINSIC))
+      .addDef(MI.getOperand(0).getReg())       // Result register
+      .addIntrinsicID(Intrinsic::spv_distance) // Intrinsic ID
+      .addUse(SubOperand1)                     // Operand X
+      .addUse(SubOperand2);                    // Operand Y
+
+  // Detach the droppable users left on Reg: debug values become undef and
+  // spv_assign_name markers are erased. Users are collected (deduplicated)
+  // first so no use-list iterator is live while instructions are erased.
+  auto DropRemainingUsers = [&MRI](Register Reg) {
+    SmallSetVector<MachineInstr *, 4> Users;
+    for (MachineInstr &UseMI : MRI.use_instructions(Reg))
+      Users.insert(&UseMI);
+    for (MachineInstr *UseMI : Users) {
+      if (UseMI->isDebugValue())
+        UseMI->setDebugValueUndef();
+      else
+        UseMI->eraseFromParent();
+    }
+  };
+
+  // Erase consumers before producers, each exactly once: spv_length reads the
+  // ASSIGN_TYPE result, and ASSIGN_TYPE reads the G_FSUB result.
+  MI.eraseFromParent();
+  DropRemainingUsers(SubAssignTypeReg);
+  SubAssignTypeInst->eraseFromParent();
+  DropRemainingUsers(SubDestReg);
+  SubInstr->eraseFromParent();
+}
+
+class SPIRVPreLegalizerCombinerImpl : public Combiner {
+protected:
+  const CombinerHelper Helper;
+  const SPIRVPreLegalizerCombinerImplRuleConfig &RuleConfig;
+  const SPIRVSubtarget &STI;
+
+public:
+  SPIRVPreLegalizerCombinerImpl(
+      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
+      GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
+      const SPIRVPreLegalizerCombinerImplRuleConfig &RuleConfig,
+      const SPIRVSubtarget &STI, MachineDominatorTree *MDT,
+      const LegalizerInfo *LI);
+
+  static const char *getName() { return "SPIRV00PreLegalizerCombiner"; }
+
+  bool tryCombineAll(MachineInstr &I) const override;
+
+  bool tryCombineAllImpl(MachineInstr &I) const;
+
+private:
+#define GET_GICOMBINER_CLASS_MEMBERS
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_CLASS_MEMBERS
+};
+
+#define GET_GICOMBINER_IMPL
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_IMPL
+
+SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
+    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
+    GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
+    const SPIRVPreLegalizerCombinerImplRuleConfig &RuleConfig,
+    const SPIRVSubtarget &STI, MachineDominatorTree *MDT,
+    const LegalizerInfo *LI)
+    : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
+      Helper(Observer, B, /*IsPreLegalize*/ true, &KB, MDT, LI),
+      RuleConfig(RuleConfig), STI(STI),
+#define GET_GICOMBINER_CONSTRUCTOR_INITS
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_CONSTRUCTOR_INITS
+{
+}
+
+bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
+  return tryCombineAllImpl(MI);
+}
+
+// Pass boilerplate
+// ================
+
+class SPIRVPreLegalizerCombiner : public MachineFunctionPass {
+public:
+  static char ID;
+
+  SPIRVPreLegalizerCombiner();
+
+  StringRef getPassName() const override { return "SPIRVPreLegalizerCombiner"; }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+  SPIRVPreLegalizerCombinerImplRuleConfig RuleConfig;
+};
+
+} // end anonymous namespace
+
+void SPIRVPreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<TargetPassConfig>();
+  AU.setPreservesCFG();
+  getSelectionDAGFallbackAnalysisUsage(AU);
+  AU.addRequired<GISelKnownBitsAnalysis>();
+  AU.addPreserved<GISelKnownBitsAnalysis>();
+  AU.addRequired<MachineDominatorTreeWrapperPass>();
+  AU.addPreserved<MachineDominatorTreeWrapperPass>();
+  AU.addRequired<GISelCSEAnalysisWrapperPass>();
+  AU.addPreserved<GISelCSEAnalysisWrapperPass>();
+  MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+SPIRVPreLegalizerCombiner::SPIRVPreLegalizerCombiner()
+    : MachineFunctionPass(ID) {
+  initializeSPIRVPreLegalizerCombinerPass(*PassRegistry::getPassRegistry());
+
+  if (!RuleConfig.parseCommandLineOption())
+    report_fatal_error("Invalid rule identifier");
+}
+
+bool SPIRVPreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
+  if (MF.getProperties().hasProperty(
+          MachineFunctionProperties::Property::FailedISel))
+    return false;
+  auto &TPC = getAnalysis<TargetPassConfig>();
+
+  // Enable CSE.
+  GISelCSEAnalysisWrapper &Wrapper =
+      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
+  auto *CSEInfo = &Wrapper.get(TPC.getCSEConfig());
+
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
+  const auto *LI = ST.getLegalizerInfo();
+
+  const Function &F = MF.getFunction();
+  bool EnableOpt =
+      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);
+  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
+  MachineDominatorTree *MDT =
+      &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
+                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
+                     F.hasMinSize());
+  // Disable fixed-point iteration to reduce compile-time
+  CInfo.MaxIterations = 1;
+  CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
+  // This is the first Combiner, so the input IR might contain dead
+  // instructions.
+  CInfo.EnableFullDCE = true;
+  SPIRVPreLegalizerCombinerImpl Impl(MF, CInfo, &TPC, *KB, CSEInfo, RuleConfig,
+                                     ST, MDT, LI);
+  return Impl.combineMachineInstrs();
+}
+
+char SPIRVPreLegalizerCombiner::ID = 0;
+INITIALIZE_PASS_BEGIN(SPIRVPreLegalizerCombiner, DEBUG_TYPE,
+                      "Combine SPIRV machine instrs before legalization", false,
+                      false)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
+INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_END(SPIRVPreLegalizerCombiner, DEBUG_TYPE,
+                    "Combine SPIRV machine instrs before legalization", false,
+                    false)
+
+namespace llvm {
+FunctionPass *createSPIRVPreLegalizerCombiner() {
+  return new SPIRVPreLegalizerCombiner();
+}
+} // end namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index dca67cb6c632bd..c9cee09cafca3f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -48,6 +48,7 @@ extern "C"
LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() { initializeSPIRVModuleAnalysisPass(PR); initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PR); initializeSPIRVStructurizerPass(PR); + initializeSPIRVPreLegalizerCombinerPass(PR); } static std::string computeDataLayout(const Triple &TT) { @@ -218,6 +219,7 @@ bool SPIRVPassConfig::addIRTranslator() { void SPIRVPassConfig::addPreLegalizeMachineIR() { addPass(createSPIRVPreLegalizerPass()); + addPass(createSPIRVPreLegalizerCombiner()); } // Use the default legalizer. diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll index 85a24a0127ae04..fac5d5f9fbd0d2 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll @@ -1,33 +1,44 @@ -; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s -; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} - -; Make sure SPIRV operation function calls for distance are lowered correctly. 
- -; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450" -; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16 -; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4 -; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32 -; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4 - -define noundef half @distance_half4(<4 x half> noundef %a, <4 x half> noundef %b) { -entry: - ; CHECK: %[[#]] = OpFunction %[[#float_16]] None %[[#]] - ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]] - ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]] - ; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]] - %spv.distance = call half @llvm.spv.distance.f16(<4 x half> %a, <4 x half> %b) - ret half %spv.distance -} - -define noundef float @distance_float4(<4 x float> noundef %a, <4 x float> noundef %b) { -entry: - ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]] - ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]] - ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]] - ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]] - %spv.distance = call float @llvm.spv.distance.f32(<4 x float> %a, <4 x float> %b) - ret float %spv.distance -} - -declare half @llvm.spv.distance.f16(<4 x half>, <4 x half>) -declare float @llvm.spv.distance.f32(<4 x float>, <4 x float>) +; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; Make sure SPIRV operation function calls for distance are lowered correctly. 
+ +; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450" +; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4 +; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4 + +define noundef half @distance_half4(<4 x half> noundef %a, <4 x half> noundef %b) { +entry: + ; CHECK: %[[#]] = OpFunction %[[#float_16]] None %[[#]] + ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]] + ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]] + ; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]] + %spv.distance = call half @llvm.spv.distance.f16(<4 x half> %a, <4 x half> %b) + ret half %spv.distance +} + +define noundef float @distance_float4(<4 x float> noundef %a, <4 x float> noundef %b) { +entry: + ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]] + ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]] + %spv.distance = call float @llvm.spv.distance.f32(<4 x float> %a, <4 x float> %b) + ret float %spv.distance +} + +define noundef float @distance_instcombine_float4(<4 x float> noundef %a, <4 x float> noundef %b) { +entry: + ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]] + ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]] + %delta = fsub <4 x float> %a, %b + %spv.length = call float @llvm.spv.length.f32(<4 x float> %delta) + ret float %spv.length +} + +declare half @llvm.spv.distance.f16(<4 x half>, <4 x half>) +declare float @llvm.spv.distance.f32(<4 x float>, <4 x float>) diff --git a/llvm/test/CodeGen/SPIRV/opencl/distance.ll 
b/llvm/test/CodeGen/SPIRV/opencl/distance.ll index ac18804c00c9ab..ed329175e9c07f 100644 --- a/llvm/test/CodeGen/SPIRV/opencl/distance.ll +++ b/llvm/test/CodeGen/SPIRV/opencl/distance.ll @@ -30,5 +30,16 @@ entry: ret float %spv.distance } +define noundef float @distance_instcombine_float4(<4 x float> noundef %a, <4 x float> noundef %b) { +entry: + ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]] + ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_cl]] distance %[[#arg0]] %[[#arg1]] + %delta = fsub <4 x float> %a, %b + %spv.length = call float @llvm.spv.length.f32(<4 x float> %delta) + ret float %spv.length +} + declare half @llvm.spv.distance.f16(<4 x half>, <4 x half>) declare float @llvm.spv.distance.f32(<4 x float>, <4 x float>) >From b90c08d2ace878865f8ae60f2e65c8a0da2df8c4 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <farzonlo...@microsoft.com> Date: Wed, 15 Jan 2025 13:12:25 -0500 Subject: [PATCH 2/3] address pr comments --- llvm/lib/Target/SPIRV/SPIRVCombine.td | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVCombine.td b/llvm/lib/Target/SPIRV/SPIRVCombine.td index 11851894e2f752..6f726e024de525 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCombine.td +++ b/llvm/lib/Target/SPIRV/SPIRVCombine.td @@ -4,10 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===----------------------------------------------------------------------===// -// -// -//===----------------------------------------------------------------------===// include "llvm/Target/GlobalISel/Combine.td" >From be94f37f3f81c26bc287103bfe37456c9afd7f35 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <farzonlo...@microsoft.com> Date: Thu, 16 Jan 2025 14:07:44 -0500 Subject: [PATCH 3/3] address pr comments describing transformation --- llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp index 54b65e8b04d622..fbe0c8ad9de5be 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp @@ -50,6 +50,14 @@ namespace { #include "SPIRVGenPreLegalizeGICombiner.inc" #undef GET_GICOMBINER_TYPES +/// This match is part of a combine that +/// rewrites length(X - Y) to distance(X, Y) +/// (f32 (g_intrinsic length +/// (g_fsub (vXf32 X) (vXf32 Y)))) +/// -> +/// (f32 (g_intrinsic distance +/// (vXf32 X) (vXf32 Y))) +/// bool matchLengthToDistance(MachineInstr &MI, MachineRegisterInfo &MRI) { if (MI.getOpcode() != TargetOpcode::G_INTRINSIC || cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_length) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits