From: Nicolai Hähnle <nicolai.haeh...@amd.com>

Order-aware scan/reduce can trade off LDS traffic against external atomics
memory traffic in producer/consumer compute shaders.
---
 src/amd/common/ac_llvm_build.c | 195 ++++++++++++++++++++++++++++++++-
 src/amd/common/ac_llvm_build.h |  36 ++++++
 2 files changed, 227 insertions(+), 4 deletions(-)

diff --git a/src/amd/common/ac_llvm_build.c b/src/amd/common/ac_llvm_build.c
index 68c8bad9e83..932f4bbdeef 100644
--- a/src/amd/common/ac_llvm_build.c
+++ b/src/amd/common/ac_llvm_build.c
@@ -3345,68 +3345,88 @@ ac_build_alu_op(struct ac_llvm_context *ctx, 
LLVMValueRef lhs, LLVMValueRef rhs,
                                        _64bit ? ctx->f64 : ctx->f32,
                                        (LLVMValueRef[]){lhs, rhs}, 2, 
AC_FUNC_ATTR_READNONE);
        case nir_op_iand: return LLVMBuildAnd(ctx->builder, lhs, rhs, "");
        case nir_op_ior: return LLVMBuildOr(ctx->builder, lhs, rhs, "");
        case nir_op_ixor: return LLVMBuildXor(ctx->builder, lhs, rhs, "");
        default:
                unreachable("bad reduction intrinsic");
        }
 }
 
-/* TODO: add inclusive and excluse scan functions for SI chip class.  */
+/**
+ * \param maxprefix specifies that the result only needs to be correct for a
+ *     prefix of this many threads; DPP steps beyond that prefix are skipped
+ *
+ * TODO: add inclusive and exclusive scan functions for SI chip class.
+ */
 static LLVMValueRef
-ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, 
LLVMValueRef identity)
+ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, 
LLVMValueRef identity,
+             unsigned maxprefix)
 {
        LLVMValueRef result, tmp;
        result = src;
+       if (maxprefix <= 1)
+               return result;
        tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
        result = ac_build_alu_op(ctx, result, tmp, op);
+       if (maxprefix <= 2)
+               return result;
        tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(2), 0xf, 0xf, false);
        result = ac_build_alu_op(ctx, result, tmp, op);
+       if (maxprefix <= 3)
+               return result;
        tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(3), 0xf, 0xf, false);
        result = ac_build_alu_op(ctx, result, tmp, op);
+       if (maxprefix <= 4)
+               return result;
        tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(4), 0xf, 0xe, 
false);
        result = ac_build_alu_op(ctx, result, tmp, op);
+       if (maxprefix <= 8)
+               return result;
        tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(8), 0xf, 0xc, 
false);
        result = ac_build_alu_op(ctx, result, tmp, op);
+       if (maxprefix <= 16)
+               return result;
        tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, 
false);
        result = ac_build_alu_op(ctx, result, tmp, op);
+       if (maxprefix <= 32)
+               return result;
        tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, 
false);
        result = ac_build_alu_op(ctx, result, tmp, op);
        return result;
 }
 
 LLVMValueRef
 ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op 
op)
 {
        ac_build_optimization_barrier(ctx, &src);
        LLVMValueRef result;
        LLVMValueRef identity =
                get_reduction_identity(ctx, op, 
ac_get_type_size(LLVMTypeOf(src)));
        result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, 
identity),
                                  LLVMTypeOf(identity), "");
-       result = ac_build_scan(ctx, op, result, identity);
+       result = ac_build_scan(ctx, op, result, identity, 64); /* 64: whole wave, no early exit */
 
        return ac_build_wwm(ctx, result);
 }
 
 LLVMValueRef
 ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op 
op)
 {
        ac_build_optimization_barrier(ctx, &src);
        LLVMValueRef result;
        LLVMValueRef identity =
                get_reduction_identity(ctx, op, 
ac_get_type_size(LLVMTypeOf(src)));
        result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, 
identity),
                                  LLVMTypeOf(identity), "");
        result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 0xf, 
false);
-       result = ac_build_scan(ctx, op, result, identity);
+       result = ac_build_scan(ctx, op, result, identity, 64); /* 64: whole wave, no early exit */
 
        return ac_build_wwm(ctx, result);
 }
 
 LLVMValueRef
 ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, 
unsigned cluster_size)
 {
        if (cluster_size == 1) return src;
        ac_build_optimization_barrier(ctx, &src);
        LLVMValueRef result, swap;
@@ -3450,20 +3470,187 @@ ac_build_reduce(struct ac_llvm_context *ctx, 
LLVMValueRef src, nir_op op, unsign
                result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 
63, 0));
                return ac_build_wwm(ctx, result);
        } else {
                swap = ac_build_readlane(ctx, result, ctx->i32_0);
                result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 
32, 0));
                result = ac_build_alu_op(ctx, result, swap, op);
                return ac_build_wwm(ctx, result);
        }
 }
 
+/**
+ * "Top half" of a scan that reduces per-wave values across an entire
+ * workgroup: lane 63 stores ws->src into ws->scratch[ws->waveidx].
+ *
+ * The source value must be present in the highest lane of the wave, and the
+ * highest lane must be live. Emits nothing when ws->maxwaves <= 1.
+ */
+void
+ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+       if (ws->maxwaves <= 1)
+               return;
+
+       const LLVMValueRef i32_63 = LLVMConstInt(ctx->i32, 63, false);
+       LLVMBuilderRef builder = ctx->builder;
+       LLVMValueRef tid = ac_get_thread_id(ctx);
+       LLVMValueRef tmp;
+
+       tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, i32_63, "");
+       ac_build_ifcc(ctx, tmp, 1000);
+       LLVMBuildStore(builder, ws->src, LLVMBuildGEP(builder, ws->scratch, 
&ws->waveidx, 1, ""));
+       ac_build_endif(ctx, 1000);
+}
+
+/**
+ * "Bottom half" of a scan that reduces per-wave values across an entire
+ * workgroup: scans the per-wave values stored in ws->scratch by the top half
+ * and fills the ws->result_* fields selected by the ws->enable_* flags.
+ * The caller must place a barrier between the top and bottom halves.
+ */
+void
+ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+       const LLVMTypeRef type = LLVMTypeOf(ws->src);
+       const LLVMValueRef identity =
+               get_reduction_identity(ctx, ws->op, ac_get_type_size(type));
+
+       if (ws->maxwaves <= 1) {
+               ws->result_reduce = ws->src;
+               ws->result_inclusive = ws->src;
+               ws->result_exclusive = identity;
+               return;
+       }
+       assert(ws->maxwaves <= 32);
+
+       LLVMBuilderRef builder = ctx->builder;
+       LLVMValueRef tid = ac_get_thread_id(ctx);
+       LLVMBasicBlockRef bbs[2];
+       LLVMValueRef phivalues_scan[2];
+       LLVMValueRef tmp, tmp2;
+
+       bbs[0] = LLVMGetInsertBlock(builder);
+       phivalues_scan[0] = LLVMGetUndef(type); /* value for lanes that skip the scan */
+
+       if (ws->enable_reduce)
+               tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->numwaves, "");
+       else if (ws->enable_inclusive)
+               tmp = LLVMBuildICmp(builder, LLVMIntULE, tid, ws->waveidx, "");
+       else
+               tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->waveidx, "");
+       ac_build_ifcc(ctx, tmp, 1001); /* only lanes whose scratch slot is needed */
+       {
+               tmp = LLVMBuildLoad(builder, LLVMBuildGEP(builder, ws->scratch, 
&tid, 1, ""), "");
+
+               ac_build_optimization_barrier(ctx, &tmp);
+
+               bbs[1] = LLVMGetInsertBlock(builder);
+               phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, 
ws->maxwaves);
+       }
+       ac_build_endif(ctx, 1001);
+
+       const LLVMValueRef scan = ac_build_phi(ctx, type, 2, phivalues_scan, 
bbs);
+
+       if (ws->enable_reduce) {
+               tmp = LLVMBuildSub(builder, ws->numwaves, ctx->i32_1, "");
+               ws->result_reduce = ac_build_readlane(ctx, scan, tmp);
+       }
+       if (ws->enable_inclusive)
+               ws->result_inclusive = ac_build_readlane(ctx, scan, 
ws->waveidx);
+       if (ws->enable_exclusive) { /* wave 0 has no predecessor: use identity */
+               tmp = LLVMBuildSub(builder, ws->waveidx, ctx->i32_1, "");
+               tmp = ac_build_readlane(ctx, scan, tmp);
+               tmp2 = LLVMBuildICmp(builder, LLVMIntEQ, ws->waveidx, 
ctx->i32_0, "");
+               ws->result_exclusive = LLVMBuildSelect(builder, tmp2, identity, 
tmp, "");
+       }
+}
+
+/**
+ * Scan/reduce of a per-wave value across an entire workgroup, computing the
+ * ws->result_* values selected by the ws->enable_* flags.
+ *
+ * This implies an s_barrier instruction, so unlike ac_build_inclusive_scan,
+ * the caller \em must ensure that all threads of the workgroup are live.
+ * (This requirement cannot easily be relaxed in a useful manner because of
+ * the barrier in the algorithm.)
+ */
+void
+ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+       ac_build_wg_wavescan_top(ctx, ws);
+       ac_build_s_barrier(ctx);
+       ac_build_wg_wavescan_bottom(ctx, ws);
+}
+
+/**
+ * "Top half" of a scan that reduces per-thread values across an entire
+ * workgroup. ws->src is replaced by its in-wave inclusive scan and, when
+ * ws->enable_exclusive is set, ws->extra receives the in-wave exclusive scan.
+ * All lanes must be active when this code runs.
+ */
+void
+ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+       if (ws->enable_exclusive) {
+               ws->extra = ac_build_exclusive_scan(ctx, ws->src, ws->op);
+               ws->src = ac_build_alu_op(ctx, ws->extra, ws->src, ws->op); /* inclusive = exclusive op src */
+       } else {
+               ws->src = ac_build_inclusive_scan(ctx, ws->src, ws->op);
+       }
+
+       bool enable_inclusive = ws->enable_inclusive;
+       bool enable_exclusive = ws->enable_exclusive;
+       ws->enable_inclusive = false;
+       ws->enable_exclusive = ws->enable_exclusive || enable_inclusive;
+       ac_build_wg_wavescan_top(ctx, ws); /* stores lane 63 of the in-wave inclusive scan */
+       ws->enable_inclusive = enable_inclusive;
+       ws->enable_exclusive = enable_exclusive;
+}
+
+/**
+ * "Bottom half" of a scan that reduces per-thread values across an entire
+ * workgroup; combines the cross-wave exclusive prefix with the in-wave scans
+ * left in ws->src / ws->extra by ac_build_wg_scan_top.
+ * The caller must place a barrier between the top and bottom halves.
+ */
+void
+ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+       bool enable_inclusive = ws->enable_inclusive;
+       bool enable_exclusive = ws->enable_exclusive;
+       ws->enable_inclusive = false;
+       ws->enable_exclusive = ws->enable_exclusive || enable_inclusive; /* both need the cross-wave exclusive prefix */
+       ac_build_wg_wavescan_bottom(ctx, ws);
+       ws->enable_inclusive = enable_inclusive;
+       ws->enable_exclusive = enable_exclusive;
+
+       /* ws->result_reduce is final; ws->result_exclusive is the cross-wave prefix */
+       if (ws->enable_inclusive)
+               ws->result_inclusive = ac_build_alu_op(ctx, 
ws->result_exclusive, ws->src, ws->op);
+       if (ws->enable_exclusive)
+               ws->result_exclusive = ac_build_alu_op(ctx, 
ws->result_exclusive, ws->extra, ws->op);
+}
+
+/**
+ * A scan that reduces per-thread values across an entire workgroup, filling
+ * the ws->result_* fields selected by ws->enable_* (ws->src is clobbered).
+ * The caller must ensure that all lanes are active when this code runs
+ * (WWM is insufficient!), because there is an implied barrier.
+ */
+void
+ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+       ac_build_wg_scan_top(ctx, ws);
+       ac_build_s_barrier(ctx);
+       ac_build_wg_scan_bottom(ctx, ws);
+}
+
 LLVMValueRef
 ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
                unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3)
 {
        unsigned mask = dpp_quad_perm(lane0, lane1, lane2, lane3);
        if (ctx->chip_class >= VI) {
                return ac_build_dpp(ctx, src, src, mask, 0xf, 0xf, false);
        } else {
                return ac_build_ds_swizzle(ctx, src, (1 << 15) | mask);
        }
diff --git a/src/amd/common/ac_llvm_build.h b/src/amd/common/ac_llvm_build.h
index cf3e3cedf65..cad131768d2 100644
--- a/src/amd/common/ac_llvm_build.h
+++ b/src/amd/common/ac_llvm_build.h
@@ -519,20 +519,56 @@ ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef 
mask);
 
 LLVMValueRef
 ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op 
op);
 
 LLVMValueRef
 ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op 
op);
 
 LLVMValueRef
 ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, 
unsigned cluster_size);
 
+/**
+ * Common arguments for scan/reduce operations that accumulate per-wave or
+ * per-thread values across an entire workgroup, respecting the order of waves.
+ */
+struct ac_wg_scan {
+       bool enable_reduce; /* compute result_reduce */
+       bool enable_exclusive; /* compute result_exclusive */
+       bool enable_inclusive; /* compute result_inclusive */
+       nir_op op; /* reduction operator */
+       LLVMValueRef src; /* input value; clobbered! */
+       LLVMValueRef result_reduce; /* output, valid if enable_reduce */
+       LLVMValueRef result_exclusive; /* output, valid if enable_exclusive */
+       LLVMValueRef result_inclusive; /* output, valid if enable_inclusive */
+       LLVMValueRef extra; /* internal: in-wave exclusive scan (ac_build_wg_scan_*) */
+       LLVMValueRef waveidx; /* this wave's position in the scan order */
+       LLVMValueRef numwaves; /* only needed for "reduce" operations */
+
+       /* T addrspace(LDS) pointer to the same type as src, at least 
maxwaves entries */
+       LLVMValueRef scratch;
+       unsigned maxwaves; /* upper bound on the number of waves (at most 32) */
+};
+
+void
+ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+void
+ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan 
*ws);
+void
+ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+
+void
+ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+void
+ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+void
+ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+
 LLVMValueRef
 ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
                unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3);
 
 LLVMValueRef
 ac_build_shuffle(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef 
index);
 
 #ifdef __cplusplus
 }
 #endif
-- 
2.19.1

_______________________________________________
mesa-dev mailing list
mesa-dev@lists.freedesktop.org
https://lists.freedesktop.org/mailman/listinfo/mesa-dev

Reply via email to