On 26/04/16 06:39, Jason Ekstrand wrote: > This commit adds a validator that ensures that all expressions passed > through nir_algebraic are 100% non-ambiguous as far as bit-sizes are > concerned. This way it's a compile-time error rather than a hard-to-trace > C exception some time later. > --- > src/compiler/nir/nir_algebraic.py | 270 > ++++++++++++++++++++++++++++++++++++++ > 1 file changed, 270 insertions(+) > > diff --git a/src/compiler/nir/nir_algebraic.py > b/src/compiler/nir/nir_algebraic.py > index e9b5832..503371b 100644 > --- a/src/compiler/nir/nir_algebraic.py > +++ b/src/compiler/nir/nir_algebraic.py > @@ -33,6 +33,19 @@ import mako.template > import re > import traceback > > +from nir_opcodes import opcodes > + > +_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?") > + > +def type_bits(type_str): > + m = _type_re.match(type_str) > + assert m.group('type') > + > + if m.group('bits') is None: > + return 0 > + else: > + return int(m.group('bits')) > + > # Represents a set of variables, each with a unique id > class VarSet(object): > def __init__(self): > @@ -188,6 +201,261 @@ class Expression(Value): > srcs = "\n".join(src.render() for src in self.sources) > return srcs + super(Expression, self).render() > > +class IntEquivalenceRelation(object): > + """A class representing an equivalence relation on integers. > + > + Each integer has a cannonical form which is the maximum integer to which > it > + is equivalent. Two integers are equivalent precicely when they have the
precisely > + same cannonical form. > + canonical. This typo is repeated in the rest of the patch. > + The convention of maximum is explicitly chosen to make using it in > + BitSizeValidator easier because it means that an actual bit_size (if any) > + will always be the cannonical form. > + """ > + def __init__(self): > + self._remap = {} > + > + def get_cannonical(self, x): > + """Get the cannonical integer corresponding to x.""" > + if x in self._remap: > + return self.get_cannonical(self._remap[x]) > + else: > + return x > + > + def add_equiv(self, a, b): > + """Add an equivalence and return the cannonical form.""" > + c = max(self.get_cannonical(a), self.get_cannonical(b)) > + if a != c: > + assert a < c > + self._remap[a] = c > + > + if b != c: > + assert b < c > + self._remap[b] = c > + > + return c > + > +class BitSizeValidator(object): > + """A class for validating bit sizes of expressions. > + > + NIR supports multiple bit-sizes on expressions in order to handle things > + such as fp64. The source and destination of every ALU operation is > + assigned a type and that type may or may not specify a bit size. Sources > + and destinations whose type does not specify a bit size are considered > + "unsized" and automatically take on the bit size of the corresponding > + register or SSA value. NIR has two simple rules for bit sizes that are > + validated by nir_validator: > + > + 1) A given SSA def or register has a single bit size that is respected by > + everything that reads from it or writes to it. > + > + 2) The bit sizes of all unsized inputs/outputs on any given ALU > + instruction must match. They need not match the sized inputs or > + outputs but they must match each other. > + > + In order to keep nir_algebraic relatively simple and easy-to-use, > + nir_search supports a type of bit-size inference based on the two rules > + above. This is similar to type inference in many common programming > + languages. If, for instance, you are constructing an add operation and > you > + know the second source is 16-bit, then you know that the other source and > + the destination must also be 16-bit. There are, however, cases where this > + inference can be ambiguous or contradictory. Consider, for instance, the > + following transformation: > + > + (('usub_borrow', a, b), ('b2i', ('ult', a, b))) > + > + This transformation can potentiall cause a problem because usub_borrow is potentially > + well-defined for any bit-size of integer. However, b2i always generates a > + 32-bit result so it could end up replacing a 64-bit expression with one > + that takes two 64-bit values and produces a 32-bit value. As another > + example, consider this expression: > + > + (('bcsel', a, b, 0), ('iand', a, b)) > + > + In this case, in the search expression a must be 32-bit but b can > + potentially have any bit size. If we had a 64-bit b value, we would end > up > + trying to and a 32-bit value with a 64-bit value which would be invalid > + > + This class solves that problem by providing a validation layer that proves > + that a given search-and-replace operation is 100% well-defined before we > + generate any code. This ensures that bugs are caught at compile time > + rather than at run time. > + > + The basic operation of the validator is very similar to the bitsize_tree > in > + nir_search only a little more subtle. Instead of simply tracking bit > + sizes, it tracks "bit classes" where each class is represented by an > + integer. A value of 0 means we don't know anything yet, positive values > + are actual bit-sizes, and negative values are used to track equivalence > + classes of sizes that must be the same but have yet to recieve an actual receive > + size. The first stage uses the bitsize_tree algorithm to assign bit > + classes to each variable. If it ever comes across an inconsistency, it > + assert-fails. Then the second stage uses that information to prove that > + the resulting expression can always validly be constructed. > + """ > + > + def __init__(self, varset): > + self._num_classes = 0 > + self._var_classes = [0] * len(varset.names) > + self._class_relation = IntEquivalenceRelation() > + > + def validate(self, search, replace): > + dst_class = self._propagate_bit_size_up(search) > + if dst_class == 0: > + dst_class = self._new_class() > + self._propagate_bit_class_down(search, dst_class) > + > + validate_dst_class = self._validate_bit_class_up(replace) > + assert validate_dst_class == 0 or validate_dst_class == dst_class > + self._validate_bit_class_down(replace, dst_class) > + > + def _new_class(self): > + self._num_classes += 1 > + return -self._num_classes > + > + def _set_var_bit_class(self, var_id, bit_class): > + assert bit_class != 0 > + var_class = self._var_classes[var_id] > + if var_class == 0: > + self._var_classes[var_id] = bit_class > + else: > + cannon_class = self._class_relation.get_cannonical(var_class) canon_class?? > + assert cannon_class < 0 or cannon_class == bit_class > + var_class = self._class_relation.add_equiv(var_class, bit_class) > + self._var_classes[var_id] = var_class > + > + def _get_var_bit_class(self, var_id): > + return self._class_relation.get_cannonical(self._var_classes[var_id]) > + > + def _propagate_bit_size_up(self, val): > + if isinstance(val, (Constant, Variable)): > + return val.bit_size > + > + elif isinstance(val, Expression): > + nir_op = opcodes[val.opcode] > + val.common_size = 0 > + for i in range(nir_op.num_inputs): > + src_bits = self._propagate_bit_size_up(val.sources[i]) > + if src_bits == 0: > + continue > + > + src_type_bits = type_bits(nir_op.input_types[i]) > + if src_type_bits != 0: > + assert src_bits == src_type_bits > + else: > + assert val.common_size == 0 or src_bits == val.common_size > + val.common_size = src_bits > + > + dst_type_bits = type_bits(nir_op.output_type) > + if dst_type_bits != 0: > + assert val.bit_size == 0 or val.bit_size == dst_type_bits > + return dst_type_bits > + else: > + if val.common_size != 0: > + assert val.bit_size == 0 or val.bit_size == val.common_size > + else: > + val.common_size = val.bit_size > + return val.common_size > + > + def _propagate_bit_class_down(self, val, bit_class): > + if isinstance(val, Constant): > + assert val.bit_size == 0 or val.bit_size == bit_class > + > + elif isinstance(val, Variable): > + assert val.bit_size == 0 or val.bit_size == bit_class > + self._set_var_bit_class(val.index, bit_class) > + > + elif isinstance(val, Expression): > + nir_op = opcodes[val.opcode] > + dst_type_bits = type_bits(nir_op.output_type) > + if dst_type_bits != 0: > + assert bit_class == 0 or bit_class == dst_type_bits > + else: > + assert val.common_size == 0 or val.common_size == bit_class > + val.common_size = bit_class > + > + if val.common_size: > + common_class = val.common_size > + elif nir_op.num_inputs: > + # If we got here then we have no idea what the actual size is. > + # Instead, we use a generic class > + common_class = self._new_class() > + > + for i in range(nir_op.num_inputs): > + src_type_bits = type_bits(nir_op.input_types[i]) > + if src_type_bits != 0: > + self._propagate_bit_class_down(val.sources[i], src_type_bits) > + else: > + self._propagate_bit_class_down(val.sources[i], common_class) > + > + def _validate_bit_class_up(self, val): > + if isinstance(val, Constant): > + return val.bit_size > + > + elif isinstance(val, Variable): > + var_class = self._get_var_bit_class(val.index) > + # By the time we get to validation, every variable should have a > class > + assert var_class != 0 > + > + # If we have an explicit size provided by the user, the variable > + # *must* exactly match the search. It cannot be implicitly sized > + # because otherwise we could end up with a conflict at runtime. > + assert val.bit_size == 0 or val.bit_size == var_class > + > + return var_class > + > + elif isinstance(val, Expression): > + nir_op = opcodes[val.opcode] > + val.common_class = 0 > + for i in range(nir_op.num_inputs): > + src_class = self._validate_bit_class_up(val.sources[i]) > + if src_class == 0: > + continue > + > + src_type_bits = type_bits(nir_op.input_types[i]) > + if src_type_bits != 0: > + assert src_class == src_type_bits > + else: > + assert val.common_class == 0 or src_class == val.common_class > + val.common_class = src_class > + > + dst_type_bits = type_bits(nir_op.output_type) > + if dst_type_bits != 0: > + assert val.bit_size == 0 or val.bit_size == dst_type_bits > + return dst_type_bits > + else: > + if val.common_class != 0: > + assert val.bit_size == 0 or val.bit_size == val.common_class > + else: > + val.common_class = val.bit_size > + return val.common_class > + > + def _validate_bit_class_down(self, val, bit_class): > + # At this point, everthing *must* have a bit class. Otherwise, we have everything Other than that, Reviewed-by: Samuel Iglesias Gonsálvez <sigles...@igalia.com> Sam > + # a value we don't know how to define. > + assert bit_class != 0 > + > + if isinstance(val, Constant): > + assert val.bit_size == 0 or val.bit_size == bit_class > + > + elif isinstance(val, Variable): > + assert val.bit_size == 0 or val.bit_size == bit_class > + > + elif isinstance(val, Expression): > + nir_op = opcodes[val.opcode] > + dst_type_bits = type_bits(nir_op.output_type) > + if dst_type_bits != 0: > + assert bit_class == dst_type_bits > + else: > + assert val.common_class == 0 or val.common_class == bit_class > + val.common_class = bit_class > + > + for i in range(nir_op.num_inputs): > + src_type_bits = type_bits(nir_op.input_types[i]) > + if src_type_bits != 0: > + self._validate_bit_class_down(val.sources[i], src_type_bits) > + else: > + self._validate_bit_class_down(val.sources[i], > val.common_class) > + > _optimization_ids = itertools.count() > > condition_list = ['true'] > @@ -220,6 +488,8 @@ class SearchAndReplace(object): > else: > self.replace = Value.create(replace, "replace{0}".format(self.id), > varset) > > + BitSizeValidator(varset).validate(self.search, self.replace) > + > _algebraic_pass_template = mako.template.Template(""" > #include "nir.h" > #include "nir_search.h" > _______________________________________________ mesa-dev mailing list mesa-dev@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/mesa-dev