diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 55b825b..7d2efef 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -1197,11 +1197,27 @@ class ArrayProxy(Value): return (Value.cast(elem) for elem in self.elems) def shape(self): - width, signed = 0, False + unsigned_width = signed_width = 0 + has_unsigned = has_signed = False for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()): - width = max(width, elem_width + elem_signed) - signed = max(signed, elem_signed) - return Shape(width, signed) + if elem_signed: + has_signed = True + signed_width = max(signed_width, elem_width) + else: + has_unsigned = True + unsigned_width = max(unsigned_width, elem_width) + # The shape of the proxy must be such that it preserves the mathematical value of the array + # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree. + # To ensure this holds, if the array contains both signed and unsigned values, make sure + # that every unsigned value is zero-extended by at least one bit. + if has_signed and has_unsigned and unsigned_width >= signed_width: + # Array contains both signed and unsigned values, and at least one of the unsigned + # values won't be zero-extended otherwise. + return signed(unsigned_width + 1) + else: + # Array contains values of the same signedness, or else all of the unsigned values + # are zero-extended. + return Shape(max(unsigned_width, signed_width), has_signed) def _lhs_signals(self): signals = union((elem._lhs_signals() for elem in self._iter_as_values()), diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index af87415..af96e16 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -806,7 +806,29 @@ class ArrayProxyTestCase(FHDLTestCase): s = Signal(range(len(a))) v = a[s] self.assertEqual(v.p.shape(), unsigned(4)) - self.assertEqual(v.n.shape(), signed(6)) + self.assertEqual(v.n.shape(), signed(5)) + + def test_attr_shape_signed(self): + # [unsigned(1), unsigned(1)] → unsigned(1) + a1 = Array([1, 1]) + v1 = a1[Const(0)] + self.assertEqual(v1.shape(), unsigned(1)) + # [signed(1), signed(1)] → signed(1) + a2 = Array([-1, -1]) + v2 = a2[Const(0)] + self.assertEqual(v2.shape(), signed(1)) + # [unsigned(1), signed(2)] → signed(2) + a3 = Array([1, -2]) + v3 = a3[Const(0)] + self.assertEqual(v3.shape(), signed(2)) + # [unsigned(1), signed(1)] → signed(2); 1st operand padded with sign bit! + a4 = Array([1, -1]) + v4 = a4[Const(0)] + self.assertEqual(v4.shape(), signed(2)) + # [unsigned(2), signed(1)] → signed(3); 1st operand padded with sign bit! + a5 = Array([1, -1]) + v5 = a5[Const(0)] + self.assertEqual(v5.shape(), signed(2)) def test_repr(self): a = Array([1, 2, 3])