---
 src/compiler/nir/nir_algebraic.py | 51 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 9b89828..0d1ed3a 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -131,6 +131,16 @@ class Constant(Value):
          assert self.bit_size == 0 or self.bit_size == 32
          self.bit_size = 32
 
+   def __hash__(self):
+      return hash(self.value) ^ self.bit_size * 37
+
+   def __eq__(self, other):
+      if not isinstance(other, Constant):
+         return False
+
+      return type(self.value) is type(other.value) and \
+             self.value == other.value and self.bit_size == other.bit_size
+
    def __hex__(self):
       if isinstance(self.value, (bool)):
          return 'NIR_TRUE' if self.value else 'NIR_FALSE'
@@ -173,6 +183,22 @@ class Variable(Value):
 
       self.index = varset[self.var_name]
 
+   def __hash__(self):
+      h = 1893 + self.index
+      if self.is_constant:
+         h = ~h
+      h = h * 37 ^ hash(self.required_type)
+      return h * 37 ^ self.bit_size
+
+   def __eq__(self, other):
+      if not isinstance(other, Variable):
+         return False
+
+      return self.index == other.index \
+             and self.is_constant == other.is_constant \
+             and self.required_type == other.required_type \
+             and self.bit_size == other.bit_size
+
    def type(self):
       if self.required_type == 'bool':
          return "nir_type_bool32"
@@ -197,6 +223,31 @@ class Expression(Value):
       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:]) ]
 
+   def __hash__(self):
+      h = hash(self.opcode)
+      h = h * 37 ^ self.bit_size
+      if self.inexact:
+         h = ~h
+      for src in self.sources:
+         h = h * 37 ^ hash(src)
+      return h
+
+   def __eq__(self, other):
+      if not isinstance(other, Expression):
+         return False
+
+      if self.opcode != other.opcode:
+         return False
+
+      assert len(self.sources) == len(other.sources)
+      for a, b in zip(self.sources, other.sources):
+         # Python 2 does not derive != from __eq__ (no __ne__ here), so use "not ==".
+         if not a == b:
+            return False
+
+      return self.inexact == other.inexact \
+             and self.bit_size == other.bit_size
+
    def render(self):
       srcs = "\n".join(src.render() for src in self.sources)
       return srcs + super(Expression, self).render()
-- 
2.5.0.400.gff86faf

_______________________________________________
mesa-dev mailing list
mesa-dev@lists.freedesktop.org
https://lists.freedesktop.org/mailman/listinfo/mesa-dev