diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 5c04efb..cf08e50 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -159,6 +159,28 @@ class Shape: return (isinstance(other, Shape) and self.width == other.width and self.signed == other.signed) + @staticmethod + def _unify(shapes): + """Returns the minimal shape that contains all shapes from the list. + + If no shapes passed in, returns unsigned(0). + """ + unsigned_width = signed_width = 0 + has_signed = False + for shape in shapes: + assert isinstance(shape, Shape) + if shape.signed: + has_signed = True + signed_width = max(signed_width, shape.width) + else: + unsigned_width = max(unsigned_width, shape.width) + # If all shapes unsigned, simply take max. + if not has_signed: + return unsigned(unsigned_width) + # Otherwise, result is signed. All unsigned inputs, if any, + # need to be converted to signed by adding a zero bit. + return signed(max(signed_width, unsigned_width + 1)) + def unsigned(width): """Returns :py:`Shape(width, signed=False)`.""" @@ -1524,20 +1546,6 @@ class Operator(Value): return self._operands def shape(self): - def _bitwise_binary_shape(a_shape, b_shape): - if not a_shape.signed and not b_shape.signed: - # both operands unsigned - return unsigned(max(a_shape.width, b_shape.width)) - elif a_shape.signed and b_shape.signed: - # both operands signed - return signed(max(a_shape.width, b_shape.width)) - elif not a_shape.signed and b_shape.signed: - # first operand unsigned (add sign bit), second operand signed - return signed(max(a_shape.width + 1, b_shape.width)) - else: - # first signed, second operand unsigned (add sign bit) - return signed(max(a_shape.width, b_shape.width + 1)) - op_shapes = list(map(lambda x: x.shape(), self.operands)) if len(op_shapes) == 1: a_shape, = op_shapes @@ -1554,10 +1562,10 @@ class Operator(Value): elif len(op_shapes) == 2: a_shape, b_shape = op_shapes if self.operator == "+": - o_shape = _bitwise_binary_shape(*op_shapes) + o_shape = Shape._unify(op_shapes) return Shape(o_shape.width + 1, o_shape.signed) if self.operator == "-": - o_shape = _bitwise_binary_shape(*op_shapes) + o_shape = Shape._unify(op_shapes) return Shape(o_shape.width + 1, True) if self.operator == "*": return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed) @@ -1568,7 +1576,7 @@ class Operator(Value): if self.operator in ("<", "<=", "==", "!=", ">", ">="): return Shape(1, False) if self.operator in ("&", "|", "^"): - return _bitwise_binary_shape(*op_shapes) + return Shape._unify(op_shapes) if self.operator == "<<": assert not b_shape.signed return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed) @@ -1578,7 +1586,7 @@ class Operator(Value): elif len(op_shapes) == 3: if self.operator == "m": s_shape, a_shape, b_shape = op_shapes - return _bitwise_binary_shape(a_shape, b_shape) + return Shape._unify((a_shape, b_shape)) raise NotImplementedError # :nocov: def _lhs_signals(self): @@ -2254,27 +2262,9 @@ class ArrayProxy(Value): return (Value.cast(elem) for elem in self.elems) def shape(self): - unsigned_width = signed_width = 0 - has_unsigned = has_signed = False - for elem_shape in (elem.shape() for elem in self._iter_as_values()): - if elem_shape.signed: - has_signed = True - signed_width = max(signed_width, elem_shape.width) - else: - has_unsigned = True - unsigned_width = max(unsigned_width, elem_shape.width) # The shape of the proxy must be such that it preserves the mathematical value of the array # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree. - # To ensure this holds, if the array contains both signed and unsigned values, make sure - # that every unsigned value is zero-extended by at least one bit. - if has_signed and has_unsigned and unsigned_width >= signed_width: - # Array contains both signed and unsigned values, and at least one of the unsigned - # values won't be zero-extended otherwise. - return signed(unsigned_width + 1) - else: - # Array contains values of the same signedness, or else all of the unsigned values - # are zero-extended. - return Shape(max(unsigned_width, signed_width), has_signed) + return Shape._unify(elem.shape() for elem in self._iter_as_values()) def _lhs_signals(self): signals = union((elem._lhs_signals() for elem in self._iter_as_values()), diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index ef883bf..6f2da52 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -677,16 +677,13 @@ class NetlistEmitter: def unify_shapes_bitwise(self, operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool): - if signed_a == signed_b: - width = max(len(operand_a), len(operand_b)) - elif signed_a: - width = max(len(operand_a), len(operand_b) + 1) - else: # signed_b - width = max(len(operand_a) + 1, len(operand_b)) - operand_a = self.extend(operand_a, signed_a, width) - operand_b = self.extend(operand_b, signed_b, width) - signed = signed_a or signed_b - return (operand_a, operand_b, signed) + shape = _ast.Shape._unify(( + _ast.Shape(len(operand_a), signed_a), + _ast.Shape(len(operand_b), signed_b), + )) + operand_a = self.extend(operand_a, signed_a, shape.width) + operand_b = self.extend(operand_b, signed_b, shape.width) + return (operand_a, operand_b, shape.signed) def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]: """Emits a RHS value, returns a tuple of (value, is_signed)""" @@ -825,19 +822,11 @@ class NetlistEmitter: signed = False elif isinstance(value, _ast.ArrayProxy): elems = [self.emit_rhs(module_idx, elem) for elem in value.elems] - width = 0 - signed = False - for elem, elem_signed in elems: - if elem_signed: - if not signed: - width += 1 - signed = True - width = max(width, len(elem)) - elif signed: - width = max(width, len(elem) + 1) - else: - width = max(width, len(elem)) - elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems) + shape = _ast.Shape._unify( + _ast.Shape(len(value), signed) + for value, signed in elems + ) + elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems) index, _signed = self.emit_rhs(module_idx, value.index) conds = [] for case_index in range(len(elems)): @@ -855,7 +844,8 @@ class NetlistEmitter: ] cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments, src_loc=value.src_loc) - result = self.netlist.add_value_cell(width, cell) + result = self.netlist.add_value_cell(shape.width, cell) + signed = shape.signed elif isinstance(value, _ast.Cat): nets = [] for val in value.parts: