If we just check that we are not dealing with a non-identity swizzle in match_value() before calling match_expression(), we can avoid a bunch of temporary swizzle arrays, along with the craziness of passing them around and resetting them. --- src/compiler/nir/nir_search.c | 89 ++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 51 deletions(-)
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index b34b13f..7a84b18 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -37,8 +37,7 @@ struct match_state { static bool match_expression(const nir_search_expression *expr, nir_alu_instr *instr, - unsigned num_components, const uint8_t *swizzle, - struct match_state *state); + unsigned num_components, struct match_state *state); static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 }; @@ -93,22 +92,15 @@ src_is_type(nir_src src, nir_alu_type type) static bool match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, - unsigned num_components, const uint8_t *swizzle, - struct match_state *state) + unsigned num_components, struct match_state *state) { - uint8_t new_swizzle[4]; - /* If the source is an explicitly sized source, then we need to reset - * both the number of components and the swizzle. + * the number of components. */ if (nir_op_infos[instr->op].input_sizes[src] != 0) { num_components = nir_op_infos[instr->op].input_sizes[src]; - swizzle = identity_swizzle; } - for (unsigned i = 0; i < num_components; ++i) - new_swizzle[i] = instr->src[src].swizzle[swizzle[i]]; - /* If the value has a specific bit size and it doesn't match, bail */ if (value->bit_size && nir_src_bit_size(instr->src[src].src) != value->bit_size) @@ -122,9 +114,23 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) return false; + /* If we have an explicitly sized destination, we can only handle the + * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid + * expression, we don't have the information right now to propagate that + * swizzle through. We can only properly propagate swizzles if the + * instruction is vectorized. 
+ */ + nir_alu_instr *alu_instr = + nir_instr_as_alu(instr->src[src].src.ssa->parent_instr); + if (nir_op_infos[alu_instr->op].output_size != 0) { + for (unsigned i = 0; i < num_components; i++) { + if (instr->src[src].swizzle[i] != i) + return false; + } + } + return match_expression(nir_search_value_as_expression(value), - nir_instr_as_alu(instr->src[src].src.ssa->parent_instr), - num_components, new_swizzle, state); + alu_instr, num_components, state); case nir_search_value_variable: { nir_search_variable *var = nir_search_value_as_variable(value); @@ -138,7 +144,8 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, assert(!instr->src[src].abs && !instr->src[src].negate); for (unsigned i = 0; i < num_components; ++i) { - if (state->variables[var->variable].swizzle[i] != new_swizzle[i]) + if (state->variables[var->variable].swizzle[i] != + instr->src[src].swizzle[i]) return false; } @@ -148,7 +155,8 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) return false; - if (var->cond && !var->cond(instr, src, num_components, new_swizzle)) + if (var->cond && !var->cond(instr, src, num_components, + instr->src[src].swizzle)) return false; if (var->type != nir_type_invalid && @@ -161,9 +169,10 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, state->variables[var->variable].negate = false; for (unsigned i = 0; i < 4; ++i) { - if (i < num_components) - state->variables[var->variable].swizzle[i] = new_swizzle[i]; - else + if (i < num_components) { + state->variables[var->variable].swizzle[i] = + instr->src[src].swizzle[i]; + } else state->variables[var->variable].swizzle[i] = 0; } @@ -189,10 +198,10 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, double val; switch (load->def.bit_size) { case 32: - val = load->value.f32[new_swizzle[i]]; + val = 
load->value.f32[instr->src[src].swizzle[i]]; break; case 64: - val = load->value.f64[new_swizzle[i]]; + val = load->value.f64[instr->src[src].swizzle[i]]; break; default: unreachable("unknown bit size"); @@ -208,10 +217,10 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, int64_t val; switch (load->def.bit_size) { case 32: - val = load->value.i32[new_swizzle[i]]; + val = load->value.i32[instr->src[src].swizzle[i]]; break; case 64: - val = load->value.i64[new_swizzle[i]]; + val = load->value.i64[instr->src[src].swizzle[i]]; break; default: unreachable("unknown bit size"); @@ -228,10 +237,10 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, uint64_t val; switch (load->def.bit_size) { case 32: - val = load->value.u32[new_swizzle[i]]; + val = load->value.u32[instr->src[src].swizzle[i]]; break; case 64: - val = load->value.u64[new_swizzle[i]]; + val = load->value.u64[instr->src[src].swizzle[i]]; break; default: unreachable("unknown bit size"); @@ -254,8 +263,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, static bool match_expression(const nir_search_expression *expr, nir_alu_instr *instr, - unsigned num_components, const uint8_t *swizzle, - struct match_state *state) + unsigned num_components, struct match_state *state) { if (instr->op != expr->opcode) return false; @@ -274,19 +282,6 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, assert(!instr->dest.saturate); assert(nir_op_infos[instr->op].num_inputs > 0); - /* If we have an explicitly sized destination, we can only handle the - * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid - * expression, we don't have the information right now to propagate that - * swizzle through. We can only properly propagate swizzles if the - * instruction is vectorized. 
- */ - if (nir_op_infos[instr->op].output_size != 0) { - for (unsigned i = 0; i < num_components; i++) { - if (swizzle[i] != i) - return false; - } - } - /* Stash off the current variables_seen bitmask. This way we can * restore it prior to matching in the commutative case below. */ @@ -294,8 +289,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, bool matched = true; for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) { - if (!match_value(expr->srcs[i], instr, i, num_components, - swizzle, state)) { + if (!match_value(expr->srcs[i], instr, i, num_components, state)) { matched = false; break; } @@ -313,12 +307,10 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr, */ state->variables_seen = variables_seen_stash; - if (!match_value(expr->srcs[0], instr, 1, num_components, - swizzle, state)) + if (!match_value(expr->srcs[0], instr, 1, num_components, state)) return false; - return match_value(expr->srcs[1], instr, 0, num_components, - swizzle, state); + return match_value(expr->srcs[1], instr, 0, num_components, state); } else { return false; } @@ -578,11 +570,6 @@ nir_alu_instr * nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search, const nir_search_value *replace, void *mem_ctx) { - uint8_t swizzle[4] = { 0, 0, 0, 0 }; - - for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i) - swizzle[i] = i; - assert(instr->dest.dest.is_ssa); struct match_state state; @@ -591,7 +578,7 @@ nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search, state.variables_seen = 0; if (!match_expression(search, instr, instr->dest.dest.ssa.num_components, - swizzle, &state)) + &state)) return NULL; void *bitsize_ctx = ralloc_context(NULL); -- 2.9.3 _______________________________________________ mesa-dev mailing list mesa-dev@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/mesa-dev