This introduces new vec8 and vec16 instructions (the only instructions taking more than 4 sources) in order to construct 8- and 16-component vectors.
To avoid fixing up the non-autogenerated nir_build_alu() call sites and making them pass 16 src args for the benefit of the only two instructions that take more than 4 srcs (i.e. vec8 and vec16), the common tail of nir_build_alu() is split out into nir_build_alu_tail() and re-used by the new nir_build_alu2(), which takes its sources as an array; the autogenerated nir_<op>() builders now call nir_build_alu2() so that they can pass up to 16 sources. Signed-off-by: Rob Clark <robdcl...@gmail.com> Signed-off-by: Karol Herbst <kher...@redhat.com> --- src/compiler/nir/nir.h | 4 +- src/compiler/nir/nir_builder.h | 58 +++++++++++++++----- src/compiler/nir/nir_builder_opcodes_h.py | 5 +- src/compiler/nir/nir_constant_expressions.py | 33 +++++++++-- src/compiler/nir/nir_lower_alu_to_scalar.c | 2 + src/compiler/nir/nir_opcodes.py | 39 ++++++++++++- src/compiler/nir/nir_print.c | 17 ++++-- src/compiler/nir/nir_search.c | 8 ++- src/compiler/spirv/spirv_to_nir.c | 4 +- 9 files changed, 140 insertions(+), 30 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 3855eb0b582..89c28e36618 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -57,8 +57,8 @@ extern "C" { #define NIR_FALSE 0u #define NIR_TRUE (~0u) -#define NIR_MAX_VEC_COMPONENTS 4 -typedef uint8_t nir_component_mask_t; +#define NIR_MAX_VEC_COMPONENTS 16 +typedef uint16_t nir_component_mask_t; /** Defines a cast function * diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h index 3271a480520..57f0a188c46 100644 --- a/src/compiler/nir/nir_builder.h +++ b/src/compiler/nir/nir_builder.h @@ -352,24 +352,12 @@ nir_imm_ivec4(nir_builder *build, int x, int y, int z, int w) } static inline nir_ssa_def * -nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0, - nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3) +nir_build_alu_tail(nir_builder *build, nir_alu_instr *instr) { - const nir_op_info *op_info = &nir_op_infos[op]; - nir_alu_instr *instr = nir_alu_instr_create(build->shader, op); - if (!instr) - return NULL; + const nir_op_info *op_info = &nir_op_infos[instr->op];
instr->exact = build->exact; - instr->src[0].src = nir_src_for_ssa(src0); - if (src1) - instr->src[1].src = nir_src_for_ssa(src1); - if (src2) - instr->src[2].src = nir_src_for_ssa(src2); - if (src3) - instr->src[3].src = nir_src_for_ssa(src3); - /* Guess the number of components the destination temporary should have * based on our input sizes, if it's not fixed for the op. */ @@ -425,12 +413,54 @@ nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0, return &instr->dest.dest.ssa; } +static inline nir_ssa_def * +nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0, + nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3) +{ + nir_alu_instr *instr = nir_alu_instr_create(build->shader, op); + if (!instr) + return NULL; + + instr->src[0].src = nir_src_for_ssa(src0); + if (src1) + instr->src[1].src = nir_src_for_ssa(src1); + if (src2) + instr->src[2].src = nir_src_for_ssa(src2); + if (src3) + instr->src[3].src = nir_src_for_ssa(src3); + + return nir_build_alu_tail(build, instr); +} + +/* for the couple special cases with more than 4 src args: */ +static inline nir_ssa_def * +nir_build_alu2(nir_builder *build, nir_op op, nir_ssa_def **srcs) +{ + const nir_op_info *op_info = &nir_op_infos[op]; + nir_alu_instr *instr = nir_alu_instr_create(build->shader, op); + if (!instr) + return NULL; + + for (unsigned i = 0; i < op_info->num_inputs; i++) + instr->src[i].src = nir_src_for_ssa(srcs[i]); + + return nir_build_alu_tail(build, instr); +} + #include "nir_builder_opcodes.h" static inline nir_ssa_def * nir_vec(nir_builder *build, nir_ssa_def **comp, unsigned num_components) { switch (num_components) { + case 16: + return nir_vec16(build, comp[0], comp[1], comp[2], comp[3], + comp[4], comp[5], comp[6], comp[7], + comp[8], comp[9], comp[10], comp[11], + comp[12], comp[13], comp[14], comp[15]); + case 8: + return nir_vec8(build, comp[0], comp[1], comp[2], comp[3], + comp[4], comp[5], comp[6], comp[7]); case 4: return nir_vec4(build, comp[0], comp[1], 
comp[2], comp[3]); case 3: diff --git a/src/compiler/nir/nir_builder_opcodes_h.py b/src/compiler/nir/nir_builder_opcodes_h.py index 84e5400958e..47edc02896c 100644 --- a/src/compiler/nir/nir_builder_opcodes_h.py +++ b/src/compiler/nir/nir_builder_opcodes_h.py @@ -31,14 +31,15 @@ def src_decl_list(num_srcs): return ', '.join('nir_ssa_def *src' + str(i) for i in range(num_srcs)) def src_list(num_srcs): - return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4)) + return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(16)) %> % for name, opcode in sorted(opcodes.items()): static inline nir_ssa_def * nir_${name}(nir_builder *build, ${src_decl_list(opcode.num_inputs)}) { - return nir_build_alu(build, nir_op_${name}, ${src_list(opcode.num_inputs)}); + nir_ssa_def *srcs[] = {${src_list(opcode.num_inputs)}}; + return nir_build_alu2(build, nir_op_${name}, srcs); } % endfor diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py index 118af9f7818..fe54a8f710d 100644 --- a/src/compiler/nir/nir_constant_expressions.py +++ b/src/compiler/nir/nir_constant_expressions.py @@ -258,6 +258,7 @@ typedef float float16_t; typedef float float32_t; typedef double float64_t; typedef bool bool32_t; + % for type in ["float", "int", "uint"]: % for width in type_sizes(type): struct ${type}${width}_vec { @@ -265,6 +266,18 @@ struct ${type}${width}_vec { ${type}${width}_t y; ${type}${width}_t z; ${type}${width}_t w; + ${type}${width}_t e; + ${type}${width}_t f; + ${type}${width}_t g; + ${type}${width}_t h; + ${type}${width}_t i; + ${type}${width}_t j; + ${type}${width}_t k; + ${type}${width}_t l; + ${type}${width}_t m; + ${type}${width}_t n; + ${type}${width}_t o; + ${type}${width}_t p; }; % endfor % endfor @@ -274,6 +287,18 @@ struct bool32_vec { bool y; bool z; bool w; + bool e; + bool f; + bool g; + bool h; + bool i; + bool j; + bool k; + bool l; + bool m; + bool n; + bool o; + bool p; }; <%def 
name="evaluate_op(op, bit_size)"> @@ -303,7 +328,7 @@ struct bool32_vec { _src[${j}].${get_const_field(input_types[j])}[${k}], % endif % endfor - % for k in range(op.input_sizes[j], 4): + % for k in range(op.input_sizes[j], 16): 0, % endfor }; @@ -377,11 +402,11 @@ struct bool32_vec { % for k in range(op.output_size): % if output_type == "bool32": ## Sanitize the C value to a proper NIR bool - _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE; + _dst_val.u32[${k}] = dst.${"xyzwefghijklmnop"[k]} ? NIR_TRUE : NIR_FALSE; % elif output_type == "float16": - _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]}); + _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzwefghijklmnop"[k]}); % else: - _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]}; + _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzwefghijklmnop"[k]}; % endif % endfor % endif diff --git a/src/compiler/nir/nir_lower_alu_to_scalar.c b/src/compiler/nir/nir_lower_alu_to_scalar.c index 0be3aba9456..5e8b76426fb 100644 --- a/src/compiler/nir/nir_lower_alu_to_scalar.c +++ b/src/compiler/nir/nir_lower_alu_to_scalar.c @@ -93,6 +93,8 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b) return true; switch (instr->op) { + case nir_op_vec16: + case nir_op_vec8: case nir_op_vec4: case nir_op_vec3: case nir_op_vec2: diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index 4ef4ecc6f22..bd212b7c2fb 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -72,7 +72,7 @@ class Opcode(object): assert isinstance(algebraic_properties, str) assert isinstance(const_expr, str) assert len(input_sizes) == len(input_types) - assert 0 <= output_size <= 4 + assert (0 <= output_size <= 4) or (output_size == 8) or (output_size == 16) for size in input_sizes: assert 0 <= size <= 4 if output_size != 0: @@ -804,4 +804,41 @@ dst.z = src2.x; dst.w = src3.x; """) +opcode("vec8", 8, tuint, + [1, 1, 1, 1, 1, 1, 1, 1], + [tuint, tuint, 
tuint, tuint, tuint, tuint, tuint, tuint], + "", """ +dst.x = src0.x; +dst.y = src1.x; +dst.z = src2.x; +dst.w = src3.x; +dst.e = src4.x; +dst.f = src5.x; +dst.g = src6.x; +dst.h = src7.x; +""") + +opcode("vec16", 16, tuint, + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint, + tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint], + "", """ +dst.x = src0.x; +dst.y = src1.x; +dst.z = src2.x; +dst.w = src3.x; +dst.e = src4.x; +dst.f = src5.x; +dst.g = src6.x; +dst.h = src7.x; +dst.i = src8.x; +dst.j = src9.x; +dst.k = src10.x; +dst.l = src11.x; +dst.m = src12.x; +dst.n = src13.x; +dst.o = src14.x; +dst.p = src15.x; +""") + diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c index ab3d5115688..0de82800f8c 100644 --- a/src/compiler/nir/nir_print.c +++ b/src/compiler/nir/nir_print.c @@ -173,6 +173,12 @@ print_dest(nir_dest *dest, print_state *state) print_reg_dest(&dest->reg, state); } +static const char * +wrmask_string(unsigned num_components) +{ + return (num_components > 4) ? 
"abcdefghijklmnop" : "xyzw"; +} + static void print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state) { @@ -208,7 +214,7 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state) if (!nir_alu_instr_channel_used(instr, src, i)) continue; - fprintf(fp, "%c", "xyzw"[instr->src[src].swizzle[i]]); + fprintf(fp, "%c", wrmask_string(live_channels)[instr->src[src].swizzle[i]]); } } @@ -226,10 +232,11 @@ print_alu_dest(nir_alu_dest *dest, print_state *state) if (!dest->dest.is_ssa && dest->write_mask != (1 << dest->dest.reg.reg->num_components) - 1) { + unsigned live_channels = dest->dest.reg.reg->num_components; fprintf(fp, "."); for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) if ((dest->write_mask >> i) & 1) - fprintf(fp, "%c", "xyzw"[i]); + fprintf(fp, "%c", wrmask_string(live_channels)[i]); } } @@ -493,7 +500,7 @@ print_var_decl(nir_variable *var, print_state *state) case nir_var_shader_in: case nir_var_shader_out: if (num_components < 4 && num_components != 0) { - const char *xyzw = "xyzw"; + const char *xyzw = wrmask_string(num_components); for (int i = 0; i < num_components; i++) components_local[i + 1] = xyzw[i + var->data.location_frac]; @@ -700,9 +707,9 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state) /* special case wrmask to show it as a writemask.. 
*/ unsigned wrmask = nir_intrinsic_write_mask(instr); fprintf(fp, " wrmask="); - for (unsigned i = 0; i < 4; i++) + for (unsigned i = 0; i < instr->num_components; i++) if ((wrmask >> i) & 1) - fprintf(fp, "%c", "xyzw"[i]); + fprintf(fp, "%c", wrmask_string(instr->num_components)[i]); } else if (idx == NIR_INTRINSIC_REDUCTION_OP) { nir_op reduction_op = nir_intrinsic_reduction_op(instr); fprintf(fp, " reduction_op=%s", nir_op_infos[reduction_op].name); diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index 0270302fd3d..642755f2a6a 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -42,7 +42,13 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, unsigned num_components, const uint8_t *swizzle, struct match_state *state); -static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 }; +static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = +{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15 +}; /** * Check if a source produces a value of the given type. 
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index e597b2462cb..a350a95e27e 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -2838,6 +2838,8 @@ create_vec(struct vtn_builder *b, unsigned num_components, unsigned bit_size) case 2: op = nir_op_vec2; break; case 3: op = nir_op_vec3; break; case 4: op = nir_op_vec4; break; + case 8: op = nir_op_vec8; break; + case 16: op = nir_op_vec16; break; default: vtn_fail("bad vector size"); } @@ -3422,10 +3424,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, case SpvCapabilityInputAttachment: case SpvCapabilityImageGatherExtended: case SpvCapabilityStorageImageExtendedFormats: + case SpvCapabilityVector16: break; case SpvCapabilityLinkage: - case SpvCapabilityVector16: case SpvCapabilityFloat16Buffer: case SpvCapabilityFloat16: case SpvCapabilityInt64Atomics: -- 2.19.1 _______________________________________________ mesa-dev mailing list mesa-dev@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/mesa-dev