hdl._ast, hdl._ir: Deduplicate shape unification logic. NFC

This commit is contained in:
Wanda 2024-03-05 10:56:05 +01:00 committed by Catherine
parent 31a12c03d1
commit 161b01450e
2 changed files with 41 additions and 61 deletions

View file

@ -159,6 +159,28 @@ class Shape:
return (isinstance(other, Shape) and return (isinstance(other, Shape) and
self.width == other.width and self.signed == other.signed) self.width == other.width and self.signed == other.signed)
@staticmethod
def _unify(shapes):
"""Returns the minimal shape that contains all shapes from the list.
If no shapes passed in, returns unsigned(0).
"""
unsigned_width = signed_width = 0
has_signed = False
for shape in shapes:
assert isinstance(shape, Shape)
if shape.signed:
has_signed = True
signed_width = max(signed_width, shape.width)
else:
unsigned_width = max(unsigned_width, shape.width)
# If all shapes unsigned, simply take max.
if not has_signed:
return unsigned(unsigned_width)
# Otherwise, result is signed. All unsigned inputs, if any,
# need to be converted to signed by adding a zero bit.
return signed(max(signed_width, unsigned_width + 1))
def unsigned(width): def unsigned(width):
"""Returns :py:`Shape(width, signed=False)`.""" """Returns :py:`Shape(width, signed=False)`."""
@ -1524,20 +1546,6 @@ class Operator(Value):
return self._operands return self._operands
def shape(self): def shape(self):
def _bitwise_binary_shape(a_shape, b_shape):
if not a_shape.signed and not b_shape.signed:
# both operands unsigned
return unsigned(max(a_shape.width, b_shape.width))
elif a_shape.signed and b_shape.signed:
# both operands signed
return signed(max(a_shape.width, b_shape.width))
elif not a_shape.signed and b_shape.signed:
# first operand unsigned (add sign bit), second operand signed
return signed(max(a_shape.width + 1, b_shape.width))
else:
# first signed, second operand unsigned (add sign bit)
return signed(max(a_shape.width, b_shape.width + 1))
op_shapes = list(map(lambda x: x.shape(), self.operands)) op_shapes = list(map(lambda x: x.shape(), self.operands))
if len(op_shapes) == 1: if len(op_shapes) == 1:
a_shape, = op_shapes a_shape, = op_shapes
@ -1554,10 +1562,10 @@ class Operator(Value):
elif len(op_shapes) == 2: elif len(op_shapes) == 2:
a_shape, b_shape = op_shapes a_shape, b_shape = op_shapes
if self.operator == "+": if self.operator == "+":
o_shape = _bitwise_binary_shape(*op_shapes) o_shape = Shape._unify(op_shapes)
return Shape(o_shape.width + 1, o_shape.signed) return Shape(o_shape.width + 1, o_shape.signed)
if self.operator == "-": if self.operator == "-":
o_shape = _bitwise_binary_shape(*op_shapes) o_shape = Shape._unify(op_shapes)
return Shape(o_shape.width + 1, True) return Shape(o_shape.width + 1, True)
if self.operator == "*": if self.operator == "*":
return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed) return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed)
@ -1568,7 +1576,7 @@ class Operator(Value):
if self.operator in ("<", "<=", "==", "!=", ">", ">="): if self.operator in ("<", "<=", "==", "!=", ">", ">="):
return Shape(1, False) return Shape(1, False)
if self.operator in ("&", "|", "^"): if self.operator in ("&", "|", "^"):
return _bitwise_binary_shape(*op_shapes) return Shape._unify(op_shapes)
if self.operator == "<<": if self.operator == "<<":
assert not b_shape.signed assert not b_shape.signed
return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed) return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed)
@ -1578,7 +1586,7 @@ class Operator(Value):
elif len(op_shapes) == 3: elif len(op_shapes) == 3:
if self.operator == "m": if self.operator == "m":
s_shape, a_shape, b_shape = op_shapes s_shape, a_shape, b_shape = op_shapes
return _bitwise_binary_shape(a_shape, b_shape) return Shape._unify((a_shape, b_shape))
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
def _lhs_signals(self): def _lhs_signals(self):
@ -2254,27 +2262,9 @@ class ArrayProxy(Value):
return (Value.cast(elem) for elem in self.elems) return (Value.cast(elem) for elem in self.elems)
def shape(self): def shape(self):
unsigned_width = signed_width = 0
has_unsigned = has_signed = False
for elem_shape in (elem.shape() for elem in self._iter_as_values()):
if elem_shape.signed:
has_signed = True
signed_width = max(signed_width, elem_shape.width)
else:
has_unsigned = True
unsigned_width = max(unsigned_width, elem_shape.width)
# The shape of the proxy must be such that it preserves the mathematical value of the array # The shape of the proxy must be such that it preserves the mathematical value of the array
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree. # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
# To ensure this holds, if the array contains both signed and unsigned values, make sure return Shape._unify(elem.shape() for elem in self._iter_as_values())
# that every unsigned value is zero-extended by at least one bit.
if has_signed and has_unsigned and unsigned_width >= signed_width:
# Array contains both signed and unsigned values, and at least one of the unsigned
# values won't be zero-extended otherwise.
return signed(unsigned_width + 1)
else:
# Array contains values of the same signedness, or else all of the unsigned values
# are zero-extended.
return Shape(max(unsigned_width, signed_width), has_signed)
def _lhs_signals(self): def _lhs_signals(self):
signals = union((elem._lhs_signals() for elem in self._iter_as_values()), signals = union((elem._lhs_signals() for elem in self._iter_as_values()),

View file

@ -677,16 +677,13 @@ class NetlistEmitter:
def unify_shapes_bitwise(self, def unify_shapes_bitwise(self,
operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool): operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
if signed_a == signed_b: shape = _ast.Shape._unify((
width = max(len(operand_a), len(operand_b)) _ast.Shape(len(operand_a), signed_a),
elif signed_a: _ast.Shape(len(operand_b), signed_b),
width = max(len(operand_a), len(operand_b) + 1) ))
else: # signed_b operand_a = self.extend(operand_a, signed_a, shape.width)
width = max(len(operand_a) + 1, len(operand_b)) operand_b = self.extend(operand_b, signed_b, shape.width)
operand_a = self.extend(operand_a, signed_a, width) return (operand_a, operand_b, shape.signed)
operand_b = self.extend(operand_b, signed_b, width)
signed = signed_a or signed_b
return (operand_a, operand_b, signed)
def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]: def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]:
"""Emits a RHS value, returns a tuple of (value, is_signed)""" """Emits a RHS value, returns a tuple of (value, is_signed)"""
@ -825,19 +822,11 @@ class NetlistEmitter:
signed = False signed = False
elif isinstance(value, _ast.ArrayProxy): elif isinstance(value, _ast.ArrayProxy):
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems] elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
width = 0 shape = _ast.Shape._unify(
signed = False _ast.Shape(len(value), signed)
for elem, elem_signed in elems: for value, signed in elems
if elem_signed: )
if not signed: elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
width += 1
signed = True
width = max(width, len(elem))
elif signed:
width = max(width, len(elem) + 1)
else:
width = max(width, len(elem))
elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems)
index, _signed = self.emit_rhs(module_idx, value.index) index, _signed = self.emit_rhs(module_idx, value.index)
conds = [] conds = []
for case_index in range(len(elems)): for case_index in range(len(elems)):
@ -855,7 +844,8 @@ class NetlistEmitter:
] ]
cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments, cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments,
src_loc=value.src_loc) src_loc=value.src_loc)
result = self.netlist.add_value_cell(width, cell) result = self.netlist.add_value_cell(shape.width, cell)
signed = shape.signed
elif isinstance(value, _ast.Cat): elif isinstance(value, _ast.Cat):
nets = [] nets = []
for val in value.parts: for val in value.parts: