From 29502442fba903f3476b9fcecc77153e6b54dfd4 Mon Sep 17 00:00:00 2001 From: Catherine Date: Tue, 31 Jan 2023 12:24:56 +0000 Subject: [PATCH] hdl.ast: remove Shape<>tuple casts. Closes #691. --- amaranth/back/rtlil.py | 126 ++++++++++++++---------------- amaranth/compat/fhdl/structure.py | 5 +- amaranth/hdl/ast.py | 99 ++++++++++------------- amaranth/hdl/dsl.py | 3 +- docs/changes.rst | 1 + tests/test_hdl_ast.py | 26 +++--- 6 files changed, 116 insertions(+), 144 deletions(-) diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 8150213..c052c7b 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -430,7 +430,7 @@ class _ValueCompiler(xfrm.ValueVisitor): elem = value.elems[index.value] else: elem = value.elems[-1] - return self.match_shape(elem, *value.shape()) + return self.match_shape(elem, value.shape()) else: max_index = 1 << len(value.index) max_elem = len(value.elems) @@ -475,12 +475,12 @@ class _RHSValueCompiler(_ValueCompiler): if value in self.s.anys: return self.s.anys[value] - res_bits, res_sign = value.shape() - res = self.s.rtlil.wire(width=res_bits, src=_src(value.src_loc)) + res_shape = value.shape() + res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) self.s.rtlil.cell("$anyconst", ports={ "\\Y": res, }, params={ - "WIDTH": res_bits, + "WIDTH": res_shape.width, }, src=_src(value.src_loc)) self.s.anys[value] = res return res @@ -489,12 +489,12 @@ class _RHSValueCompiler(_ValueCompiler): if value in self.s.anys: return self.s.anys[value] - res_bits, res_sign = value.shape() - res = self.s.rtlil.wire(width=res_bits, src=_src(value.src_loc)) + res_shape = value.shape() + res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) self.s.rtlil.cell("$anyseq", ports={ "\\Y": res, }, params={ - "WIDTH": res_bits, + "WIDTH": res_shape.width, }, src=_src(value.src_loc)) self.s.anys[value] = res return res @@ -509,74 +509,71 @@ class _RHSValueCompiler(_ValueCompiler): # These operators don't change the bit pattern, only its interpretation. return self(arg) - arg_bits, arg_sign = arg.shape() - res_bits, res_sign = value.shape() - res = self.s.rtlil.wire(width=res_bits, src=_src(value.src_loc)) + arg_shape, res_shape = arg.shape(), value.shape() + res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) self.s.rtlil.cell(self.operator_map[(1, value.operator)], ports={ "\\A": self(arg), "\\Y": res, }, params={ - "A_SIGNED": arg_sign, - "A_WIDTH": arg_bits, - "Y_WIDTH": res_bits, + "A_SIGNED": arg_shape.signed, + "A_WIDTH": arg_shape.width, + "Y_WIDTH": res_shape.width, }, src=_src(value.src_loc)) return res - def match_shape(self, value, new_bits, new_sign): + def match_shape(self, value, new_shape): if isinstance(value, ast.Const): - return self(ast.Const(value.value, ast.Shape(new_bits, new_sign))) + return self(ast.Const(value.value, new_shape)) - value_bits, value_sign = value.shape() - if new_bits <= value_bits: - return self(ast.Slice(value, 0, new_bits)) + value_shape = value.shape() + if new_shape.width <= value_shape.width: + return self(ast.Slice(value, 0, new_shape.width)) - res = self.s.rtlil.wire(width=new_bits, src=_src(value.src_loc)) + res = self.s.rtlil.wire(width=new_shape.width, src=_src(value.src_loc)) self.s.rtlil.cell("$pos", ports={ "\\A": self(value), "\\Y": res, }, params={ - "A_SIGNED": value_sign, - "A_WIDTH": value_bits, - "Y_WIDTH": new_bits, + "A_SIGNED": value_shape.signed, + "A_WIDTH": value_shape.width, + "Y_WIDTH": new_shape.width, }, src=_src(value.src_loc)) return res def on_Operator_binary(self, value): lhs, rhs = value.operands - lhs_bits, lhs_sign = lhs.shape() - rhs_bits, rhs_sign = rhs.shape() - if lhs_sign == rhs_sign or value.operator in ("<<", ">>", "**"): + lhs_shape, rhs_shape, res_shape = lhs.shape(), rhs.shape(), value.shape() + if lhs_shape.signed == rhs_shape.signed or value.operator in ("<<", ">>", "**"): lhs_wire = self(lhs) rhs_wire = self(rhs) else: - lhs_bits = rhs_bits = max(lhs_bits + rhs_sign, rhs_bits + lhs_sign) - lhs_sign = rhs_sign = True - lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign) - rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign) - res_bits, res_sign = value.shape() - res = self.s.rtlil.wire(width=res_bits, src=_src(value.src_loc)) + lhs_shape = rhs_shape = ast.signed(max(lhs_shape.width + rhs_shape.signed, + rhs_shape.width + lhs_shape.signed)) + lhs_wire = self.match_shape(lhs, lhs_shape) + rhs_wire = self.match_shape(rhs, rhs_shape) + res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) self.s.rtlil.cell(self.operator_map[(2, value.operator)], ports={ "\\A": lhs_wire, "\\B": rhs_wire, "\\Y": res, }, params={ - "A_SIGNED": lhs_sign, - "A_WIDTH": lhs_bits, - "B_SIGNED": rhs_sign, - "B_WIDTH": rhs_bits, - "Y_WIDTH": res_bits, + "A_SIGNED": lhs_shape.signed, + "A_WIDTH": lhs_shape.width, + "B_SIGNED": rhs_shape.signed, + "B_WIDTH": rhs_shape.width, + "Y_WIDTH": res_shape.width, }, src=_src(value.src_loc)) if value.operator in ("//", "%"): # RTLIL leaves division by zero undefined, but we require it to return zero. divmod_res = res - res = self.s.rtlil.wire(width=res_bits, src=_src(value.src_loc)) + res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) self.s.rtlil.cell("$mux", ports={ "\\A": divmod_res, - "\\B": self(ast.Const(0, ast.Shape(res_bits, res_sign))), + "\\B": self(ast.Const(0, res_shape)), "\\S": self(rhs == 0), "\\Y": res, }, params={ - "WIDTH": res_bits + "WIDTH": res_shape.width }, src=_src(value.src_loc)) return res @@ -584,20 +581,17 @@ class _RHSValueCompiler(_ValueCompiler): sel, val1, val0 = value.operands if len(sel) != 1: sel = sel.bool() - val1_bits, val1_sign = val1.shape() - val0_bits, val0_sign = val0.shape() - res_bits, res_sign = value.shape() - val1_bits = val0_bits = res_bits = max(val1_bits, val0_bits, res_bits) - val1_wire = self.match_shape(val1, val1_bits, val1_sign) - val0_wire = self.match_shape(val0, val0_bits, val0_sign) - res = self.s.rtlil.wire(width=res_bits, src=_src(value.src_loc)) + res_shape = value.shape() + val1_wire = self.match_shape(val1, res_shape) + val0_wire = self.match_shape(val0, res_shape) + res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) self.s.rtlil.cell("$mux", ports={ "\\A": val0_wire, "\\B": val1_wire, "\\S": self(sel), "\\Y": res, }, params={ - "WIDTH": res_bits + "WIDTH": res_shape.width }, src=_src(value.src_loc)) return res @@ -624,10 +618,8 @@ class _RHSValueCompiler(_ValueCompiler): lhs, rhs = value.value, value.offset if value.stride != 1: rhs *= value.stride - lhs_bits, lhs_sign = lhs.shape() - rhs_bits, rhs_sign = rhs.shape() - res_bits, res_sign = value.shape() - res = self.s.rtlil.wire(width=res_bits, src=_src(value.src_loc)) + lhs_shape, rhs_shape, res_shape = lhs.shape(), rhs.shape(), value.shape() + res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) # Note: Verilog's x[o+:w] construct produces a $shiftx cell, not a $shift cell. # However, Amaranth's semantics defines the out-of-range bits to be zero, so it is correct # to use a $shift cell here instead, even though it produces less idiomatic Verilog. @@ -636,11 +628,11 @@ class _RHSValueCompiler(_ValueCompiler): "\\B": self(rhs), "\\Y": res, }, params={ - "A_SIGNED": lhs_sign, - "A_WIDTH": lhs_bits, - "B_SIGNED": rhs_sign, - "B_WIDTH": rhs_bits, - "Y_WIDTH": res_bits, + "A_SIGNED": lhs_shape.signed, + "A_WIDTH": lhs_shape.width, + "B_SIGNED": rhs_shape.signed, + "B_WIDTH": rhs_shape.width, + "Y_WIDTH": res_shape.width, }, src=_src(value.src_loc)) return res @@ -666,14 +658,14 @@ class _LHSValueCompiler(_ValueCompiler): raise TypeError # :nocov: - def match_shape(self, value, new_bits, new_sign): - value_bits, value_sign = value.shape() - if new_bits == value_bits: + def match_shape(self, value, new_shape): + value_shape = value.shape() + if new_shape.width == value_shape.width: return self(value) - elif new_bits < value_bits: - return self(ast.Slice(value, 0, new_bits)) - else: # new_bits > value_bits - dummy_bits = new_bits - value_bits + elif new_shape.width < value_shape.width: + return self(ast.Slice(value, 0, new_shape.width)) + else: # new_shape.width > value_shape.width + dummy_bits = new_shape.width - value_shape.width dummy_wire = self.s.rtlil.wire(dummy_bits) return "{{ {} {} }}".format(dummy_wire, self(value)) @@ -738,14 +730,12 @@ class _StatementCompiler(xfrm.StatementVisitor): def on_Assign(self, stmt): self._check_rhs(stmt.rhs) - lhs_bits, lhs_sign = stmt.lhs.shape() - rhs_bits, rhs_sign = stmt.rhs.shape() - if lhs_bits == rhs_bits: + lhs_shape, rhs_shape = stmt.lhs.shape(), stmt.rhs.shape() + if lhs_shape.width == rhs_shape.width: rhs_sigspec = self.rhs_compiler(stmt.rhs) else: # In RTLIL, LHS and RHS of assignment must have exactly same width. - rhs_sigspec = self.rhs_compiler.match_shape( - stmt.rhs, lhs_bits, lhs_sign) + rhs_sigspec = self.rhs_compiler.match_shape(stmt.rhs, lhs_shape) if self._wrap_assign: # In RTLIL, all assigns are logically sequenced before all switches, even if they are # interleaved in the source. In Amaranth, the source ordering is used. To handle this diff --git a/amaranth/compat/fhdl/structure.py b/amaranth/compat/fhdl/structure.py index d450e45..c0d4434 100644 --- a/amaranth/compat/fhdl/structure.py +++ b/amaranth/compat/fhdl/structure.py @@ -51,7 +51,10 @@ class CompatSignal(NativeSignal): else: if not (min is None and max is None): raise ValueError("Only one of bits/signedness or bounds may be specified") - shape = bits_sign + if isinstance(bits_sign, tuple): + shape = Shape(*bits_sign) + else: + shape = Shape.cast(bits_sign) super().__init__(shape=shape, name=name_override or name, reset=reset, reset_less=reset_less, diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index 586fbb3..dc243ab 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -78,10 +78,6 @@ class Shape: self.width = width self.signed = signed - # TODO(nmigen-0.4): remove - def __iter__(self): - return iter((self.width, self.signed)) - @staticmethod def cast(obj, *, src_loc_at=0): while True: @@ -89,14 +85,6 @@ class Shape: return obj elif isinstance(obj, int): return Shape(obj) - # TODO(nmigen-0.4): remove - elif isinstance(obj, tuple): - width, signed = obj - warnings.warn("instead of `{tuple}`, use `{constructor}({width})`" - .format(constructor="signed" if signed else "unsigned", width=width, - tuple=obj), - DeprecationWarning, stacklevel=2 + src_loc_at) - return Shape(width, signed) elif isinstance(obj, range): if len(obj) == 0: return Shape(0, obj.start < 0) @@ -216,8 +204,7 @@ class Value(metaclass=ABCMeta): return Operator("//", [other, self]) def __check_shamt(self): - width, signed = self.shape() - if signed: + if self.shape().signed: # Neither Python nor HDLs implement shifts by negative values; prohibit any shifts # by a signed value to make sure the shift amount can always be interpreted as # an unsigned value. @@ -264,8 +251,7 @@ class Value(metaclass=ABCMeta): return Operator(">=", [self, other]) def __abs__(self): - width, signed = self.shape() - if signed: + if self.shape().signed: return Mux(self >= 0, self, -self) else: return self @@ -607,10 +593,9 @@ class Const(Value): @staticmethod def normalize(value, shape): - width, signed = shape - mask = (1 << width) - 1 + mask = (1 << shape.width) - 1 value &= mask - if signed and value >> (width - 1): + if shape.signed and value >> (shape.width - 1): value |= ~mask return value @@ -623,8 +608,9 @@ class Const(Value): shape = Shape(shape, signed=self.value < 0) else: shape = Shape.cast(shape, src_loc_at=1 + src_loc_at) - self.width, self.signed = shape - self.value = self.normalize(self.value, shape) + self.width = shape.width + self.signed = shape.signed + self.value = self.normalize(self.value, shape) def shape(self): return Shape(self.width, self.signed) @@ -645,10 +631,9 @@ C = Const # shorthand class AnyValue(Value, DUID): def __init__(self, shape, *, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) - self.width, self.signed = Shape.cast(shape, src_loc_at=1 + src_loc_at) - if not isinstance(self.width, int) or self.width < 0: - raise TypeError("Width must be a non-negative integer, not {!r}" - .format(self.width)) + shape = Shape.cast(shape, src_loc_at=1 + src_loc_at) + self.width = shape.width + self.signed = shape.signed def shape(self): return Shape(self.width, self.signed) @@ -678,55 +663,53 @@ class Operator(Value): def shape(self): def _bitwise_binary_shape(a_shape, b_shape): - a_bits, a_sign = a_shape - b_bits, b_sign = b_shape - if not a_sign and not b_sign: + if not a_shape.signed and not b_shape.signed: # both operands unsigned - return Shape(max(a_bits, b_bits), False) - elif a_sign and b_sign: + return unsigned(max(a_shape.width, b_shape.width)) + elif a_shape.signed and b_shape.signed: # both operands signed - return Shape(max(a_bits, b_bits), True) - elif not a_sign and b_sign: + return signed(max(a_shape.width, b_shape.width)) + elif not a_shape.signed and b_shape.signed: # first operand unsigned (add sign bit), second operand signed - return Shape(max(a_bits + 1, b_bits), True) + return signed(max(a_shape.width + 1, b_shape.width)) else: # first signed, second operand unsigned (add sign bit) - return Shape(max(a_bits, b_bits + 1), True) + return signed(max(a_shape.width, b_shape.width + 1)) op_shapes = list(map(lambda x: x.shape(), self.operands)) if len(op_shapes) == 1: - (a_width, a_signed), = op_shapes + a_shape, = op_shapes if self.operator in ("+", "~"): - return Shape(a_width, a_signed) + return Shape(a_shape.width, a_shape.signed) if self.operator == "-": - return Shape(a_width + 1, True) + return Shape(a_shape.width + 1, True) if self.operator in ("b", "r|", "r&", "r^"): return Shape(1, False) if self.operator == "u": - return Shape(a_width, False) + return Shape(a_shape.width, False) if self.operator == "s": - return Shape(a_width, True) + return Shape(a_shape.width, True) elif len(op_shapes) == 2: - (a_width, a_signed), (b_width, b_signed) = op_shapes + a_shape, b_shape = op_shapes if self.operator in ("+", "-"): - width, signed = _bitwise_binary_shape(*op_shapes) - return Shape(width + 1, signed) + o_shape = _bitwise_binary_shape(*op_shapes) + return Shape(o_shape.width + 1, o_shape.signed) if self.operator == "*": - return Shape(a_width + b_width, a_signed or b_signed) + return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed) if self.operator == "//": - return Shape(a_width + b_signed, a_signed or b_signed) + return Shape(a_shape.width + b_shape.signed, a_shape.signed or b_shape.signed) if self.operator == "%": - return Shape(b_width, b_signed) + return Shape(b_shape.width, b_shape.signed) if self.operator in ("<", "<=", "==", "!=", ">", ">="): return Shape(1, False) if self.operator in ("&", "^", "|"): return _bitwise_binary_shape(*op_shapes) if self.operator == "<<": - assert not b_signed - return Shape(a_width + 2 ** b_width - 1, a_signed) + assert not b_shape.signed + return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed) if self.operator == ">>": - assert not b_signed - return Shape(a_width, a_signed) + assert not b_shape.signed + return Shape(a_shape.width, a_shape.signed) elif len(op_shapes) == 3: if self.operator == "m": s_shape, a_shape, b_shape = op_shapes @@ -982,9 +965,13 @@ class Signal(Value, DUID): raise TypeError("Name must be a string, not {!r}".format(name)) self.name = name or tracer.get_var_name(depth=2 + src_loc_at, default="$signal") + orig_shape = shape if shape is None: shape = unsigned(1) - self.width, self.signed = Shape.cast(shape, src_loc_at=1 + src_loc_at) + else: + shape = Shape.cast(shape, src_loc_at=1 + src_loc_at) + self.width = shape.width + self.signed = shape.signed if isinstance(reset, Enum): reset = reset.value @@ -1003,8 +990,8 @@ class Signal(Value, DUID): self.attrs = OrderedDict(() if attrs is None else attrs) - if decoder is None and isinstance(shape, type) and issubclass(shape, Enum): - decoder = shape + if decoder is None and isinstance(orig_shape, type) and issubclass(orig_shape, Enum): + decoder = orig_shape if isinstance(decoder, type) and issubclass(decoder, Enum): def enum_decoder(value): try: @@ -1231,13 +1218,13 @@ class ArrayProxy(Value): def shape(self): unsigned_width = signed_width = 0 has_unsigned = has_signed = False - for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()): - if elem_signed: + for elem_shape in (elem.shape() for elem in self._iter_as_values()): + if elem_shape.signed: has_signed = True - signed_width = max(signed_width, elem_width) + signed_width = max(signed_width, elem_shape.width) else: has_unsigned = True - unsigned_width = max(unsigned_width, elem_width) + unsigned_width = max(unsigned_width, elem_shape.width) # The shape of the proxy must be such that it preserves the mathematical value of the array # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree. # To ensure this holds, if the array contains both signed and unsigned values, make sure diff --git a/amaranth/hdl/dsl.py b/amaranth/hdl/dsl.py index 279e7c5..7ad0e3e 100644 --- a/amaranth/hdl/dsl.py +++ b/amaranth/hdl/dsl.py @@ -210,8 +210,7 @@ class Module(_ModuleBuilderRoot, Elaboratable): def _check_signed_cond(self, cond): cond = Value.cast(cond) - width, signed = cond.shape() - if signed: + if cond.shape().signed: warnings.warn("Signed values in If/Elif conditions usually result from inverting " "Python booleans with ~, which leads to unexpected results. " "Replace `~flag` with `not flag`. (If this is a false positive, " diff --git a/docs/changes.rst b/docs/changes.rst index b30a7bb..9de9e26 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -24,6 +24,7 @@ Language changes .. currentmodule:: amaranth.hdl +* Removed: casting of :class:`Shape` to and from a ``(width, signed)`` tuple. * Added: :class:`ShapeCastable`, similar to :class:`ValueCastable`. * Added: :meth:`Value.as_signed` and :meth:`Value.as_unsigned` can be used on left-hand side of assignment (with no difference in behavior). * Changed: :meth:`Value.cast` casts :class:`ValueCastable` objects recursively. diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index b0cd9b2..1f14385 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -53,17 +53,18 @@ class ShapeTestCase(FHDLTestCase): def test_compare_tuple_wrong(self): with self.assertRaisesRegex(TypeError, - r"^Shapes may be compared with other Shapes and \(int, bool\) tuples, not \(2, 3\)$"): + r"^Shapes may be compared with other Shapes and \(int, bool\) tuples, " + r"not \(2, 3\)$"): Shape(1, True) == (2, 3) def test_repr(self): self.assertEqual(repr(Shape()), "unsigned(1)") self.assertEqual(repr(Shape(2, True)), "signed(2)") - def test_tuple(self): - width, signed = Shape() - self.assertEqual(width, 1) - self.assertEqual(signed, False) + def test_convert_tuple_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^cannot unpack non-iterable Shape object$"): + width, signed = Shape() def test_unsigned(self): s1 = unsigned(2) @@ -95,19 +96,10 @@ class ShapeTestCase(FHDLTestCase): r"^Width must be a non-negative integer, not -1$"): Shape.cast(-1) - def test_cast_tuple(self): - with warnings.catch_warnings(): - warnings.filterwarnings(action="ignore", category=DeprecationWarning) - s1 = Shape.cast((1, True)) - self.assertEqual(s1.width, 1) - self.assertEqual(s1.signed, True) - def test_cast_tuple_wrong(self): - with warnings.catch_warnings(): - warnings.filterwarnings(action="ignore", category=DeprecationWarning) - with self.assertRaisesRegex(TypeError, - r"^Width must be a non-negative integer, not -1$"): - Shape.cast((-1, True)) + with self.assertRaisesRegex(TypeError, + r"^Object \(1, True\) cannot be converted to an Amaranth shape$"): + Shape.cast((1, True)) def test_cast_range(self): s1 = Shape.cast(range(0, 8))