From 54d5c4c047f0d8dd010a599dc1fccbeafe026995 Mon Sep 17 00:00:00 2001 From: Catherine Date: Fri, 12 May 2023 22:04:35 +0000 Subject: [PATCH] Implement RFC 9: Constant initialization for shape-castable objects. See amaranth-lang/rfcs#9 and #771. --- amaranth/hdl/ast.py | 29 ++++++++++++++----- amaranth/lib/data.py | 56 ++++++++++++++++++++++-------------- amaranth/lib/enum.py | 11 +++++++ tests/test_hdl_ast.py | 27 ++++++++++++++++++ tests/test_hdl_dsl.py | 2 +- tests/test_lib_data.py | 65 ++++++++++++++++++++++++++++++++++++++---- tests/test_lib_enum.py | 33 +++++++++++++++------ 7 files changed, 181 insertions(+), 42 deletions(-) diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index bf3f526..7fe1d56 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -45,6 +45,9 @@ class ShapeCastable: if not hasattr(cls, "as_shape"): raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override " f"the `as_shape` method") + if not hasattr(cls, "const"): + raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override " + f"the `const` method") class Shape: @@ -988,7 +991,7 @@ class Signal(Value, DUID): decoder : function """ - def __init__(self, shape=None, *, name=None, reset=0, reset_less=False, + def __init__(self, shape=None, *, name=None, reset=None, reset_less=False, attrs=None, decoder=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) @@ -1005,12 +1008,24 @@ class Signal(Value, DUID): self.signed = shape.signed orig_reset = reset - try: - reset = Const.cast(reset) - except TypeError: - raise TypeError("Reset value must be a constant-castable expression, not {!r}" - .format(orig_reset)) - if orig_reset not in (0, -1): # Avoid false positives for all-zeroes and all-ones + if isinstance(orig_shape, ShapeCastable): + try: + reset = Const.cast(orig_shape.const(reset)) + except Exception: + raise TypeError("Reset value must be a constant initializer of {!r}" + .format(orig_shape)) + if reset.shape() != Shape.cast(orig_shape): + raise ValueError("Constant returned by {!r}.const() must have the shape that " + "it casts to, {!r}, and not {!r}" + .format(orig_shape, Shape.cast(orig_shape), + reset.shape())) + else: + try: + reset = Const.cast(reset or 0) + except TypeError: + raise TypeError("Reset value must be a constant-castable expression, not {!r}" + .format(orig_reset)) + if orig_reset not in (None, 0, -1): # Avoid false positives for all-zeroes and all-ones if reset.shape().signed and not self.signed: warnings.warn( message="Reset value {!r} is signed, but the signal shape is {!r}" diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index f2f1c63..981ca82 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -201,29 +201,44 @@ class Layout(ShapeCastable, metaclass=ABCMeta): """ return View(self, target) - def _convert_to_int(self, value): - """Convert ``value``, which may be a dict or an array of field values, to an integer using - the representation defined by this layout. + def const(self, init): + """Convert a constant initializer to a constant. - This method is private because Amaranth does not currently have a concept of - a constant initializer; this requires an RFC. It will be renamed or removed - in a future version. + Converts ``init``, which may be a sequence or a mapping of field values, to a constant. + + Returns + ------- + :class:`Const` + A constant that has the same value as a view with this layout that was initialized with + an all-zero value and had every field assigned to the corresponding value in the order + in which they appear in ``init``. """ - if isinstance(value, Mapping): - iterator = value.items() - elif isinstance(value, Sequence): - iterator = enumerate(value) + if init is None: + iterator = iter(()) + elif isinstance(init, Mapping): + iterator = init.items() + elif isinstance(init, Sequence): + iterator = enumerate(init) else: - raise TypeError("Layout initializer must be a mapping or a sequence, not {!r}" - .format(value)) + raise TypeError("Layout constant initializer must be a mapping or a sequence, not {!r}" + .format(init)) int_value = 0 for key, key_value in iterator: field = self[key] - if isinstance(field.shape, Layout): - key_value = field.shape._convert_to_int(key_value) - int_value |= Const(key_value, Shape.cast(field.shape)).value << field.offset - return int_value + cast_field_shape = Shape.cast(field.shape) + if isinstance(field.shape, ShapeCastable): + key_value = Const.cast(field.shape.const(key_value)) + if key_value.shape() != cast_field_shape: + raise ValueError("Constant returned by {!r}.const() must have the shape that " + "it casts to, {!r}, and not {!r}" + .format(field.shape, cast_field_shape, + key_value.shape())) + else: + key_value = Const(key_value, cast_field_shape) + int_value &= ~(((1 << cast_field_shape.width) - 1) << field.offset) + int_value |= key_value.value << field.offset + return Const(int_value, self.as_shape()) class StructLayout(Layout): @@ -617,13 +632,9 @@ class View(ValueCastable): "the {} bit(s) wide view layout" .format(len(cast_target), cast_layout.size)) else: - if reset is None: - reset = 0 - else: - reset = cast_layout._convert_to_int(reset) if reset_less is None: reset_less = False - cast_target = Signal(cast_layout, name=name, reset=reset, reset_less=reset_less, + cast_target = Signal(layout, name=name, reset=reset, reset_less=reset_less, attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1) self.__orig_layout = layout self.__layout = cast_layout @@ -774,6 +785,9 @@ class _AggregateMeta(ShapeCastable, type): .format(cls.__module__, cls.__qualname__)) return cls.__layout + def const(cls, init): + return cls.as_shape().const(init) + class _Aggregate(View, metaclass=_AggregateMeta): def __init__(self, target=None, *, name=None, reset=None, reset_less=None, diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index 62f6036..a0382b5 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -137,6 +137,17 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): return value return super().__call__(value) + def const(cls, init): + # Same considerations apply as above. + if init is None: + # Signal with unspecified reset value passes ``None`` to :meth:`const`. + # Before RFC 9 was implemented, the unspecified reset value was 0, so this keeps + # the old behavior intact. + member = cls(0) + else: + member = cls(init) + return Const(member.value, cls.as_shape()) + class Enum(py_enum.Enum, metaclass=EnumMeta): """Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 7d95e20..88209bb 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -2,6 +2,7 @@ import warnings from enum import Enum from amaranth.hdl.ast import * +from amaranth.lib.enum import Enum as AmaranthEnum from .utils import * from amaranth._utils import _ignore_deprecated @@ -144,6 +145,9 @@ class MockShapeCastable(ShapeCastable): def as_shape(self): return self.dest + def const(self, obj): + return Const(obj, self.dest) + class ShapeCastableTestCase(FHDLTestCase): def test_no_override(self): @@ -995,6 +999,29 @@ class SignalTestCase(FHDLTestCase): r"not $"): Signal(1, reset=StringEnum.FOO) + def test_reset_shape_castable_const(self): + class CastableFromHex(ShapeCastable): + def as_shape(self): + return unsigned(8) + + def const(self, init): + return int(init, 16) + + s1 = Signal(CastableFromHex(), reset="aa") + self.assertEqual(s1.reset, 0xaa) + + with self.assertRaisesRegex(ValueError, + r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape " + r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"): + Signal(CastableFromHex(), reset="01") + + def test_reset_shape_castable_enum_wrong(self): + class EnumA(AmaranthEnum): + X = 1 + with self.assertRaisesRegex(TypeError, + r"^Reset value must be a constant initializer of $"): + Signal(EnumA) # implied reset=0 + def test_reset_signed_mismatch(self): with self.assertWarnsRegex(SyntaxWarning, r"^Reset value -2 is signed, but the signal shape is unsigned\(2\)$"): diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index c3ea0fa..d08878e 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -436,7 +436,7 @@ class DSLTestCase(FHDLTestCase): RED = 1 BLUE = 2 m = Module() - se = Signal(Color) + se = Signal(Color, reset=Color.RED) with m.Switch(se): with m.Case(Color.RED): m.d.comb += self.c1.eq(1) diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index ff72e9c..af54527 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -16,6 +16,9 @@ class MockShapeCastable(ShapeCastable): def as_shape(self): return self.shape + def const(self, init): + return Const(init, self.shape) + class FieldTestCase(TestCase): def test_construct(self): @@ -332,7 +335,7 @@ class FlexibleLayoutTestCase(TestCase): il[object()] -class LayoutTestCase(TestCase): +class LayoutTestCase(FHDLTestCase): def test_cast(self): sl = StructLayout({}) self.assertIs(Layout.cast(sl), sl) @@ -371,6 +374,53 @@ class LayoutTestCase(TestCase): self.assertIs(Layout.of(v), sl) self.assertIs(v.as_value(), s) + def test_const(self): + sl = StructLayout({ + "a": unsigned(1), + "b": unsigned(2) + }) + self.assertRepr(sl.const(None), "(const 3'd0)") + self.assertRepr(sl.const({"a": 0b1, "b": 0b10}), "(const 3'd5)") + + ul = UnionLayout({ + "a": unsigned(1), + "b": unsigned(2) + }) + self.assertRepr(ul.const({"a": 0b11}), "(const 2'd1)") + self.assertRepr(ul.const({"b": 0b10}), "(const 2'd2)") + self.assertRepr(ul.const({"a": 0b1, "b": 0b10}), "(const 2'd2)") + + def test_const_wrong(self): + sl = StructLayout({"f": unsigned(1)}) + with self.assertRaisesRegex(TypeError, + r"^Layout constant initializer must be a mapping or a sequence, not " + r"<.+?object.+?>$"): + sl.const(object()) + + def test_const_field_shape_castable(self): + class CastableFromHex(ShapeCastable): + def as_shape(self): + return unsigned(8) + + def const(self, init): + return int(init, 16) + + sl = StructLayout({"f": CastableFromHex()}) + self.assertRepr(sl.const({"f": "aa"}), "(const 8'd170)") + + with self.assertRaisesRegex(ValueError, + r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape " + r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"): + sl.const({"f": "01"}) + + def test_signal_reset(self): + sl = StructLayout({ + "a": unsigned(1), + "b": unsigned(2) + }) + self.assertEqual(Signal(sl).reset, 0) + self.assertEqual(Signal(sl, reset={"a": 0b1, "b": 0b10}).reset, 5) + class ViewTestCase(FHDLTestCase): def test_construct(self): @@ -434,7 +484,7 @@ class ViewTestCase(FHDLTestCase): def test_signal_reset_wrong(self): with self.assertRaisesRegex(TypeError, - r"^Layout initializer must be a mapping or a sequence, not 1$"): + r"^Reset value must be a constant initializer of StructLayout\({}\)$"): View(StructLayout({}), reset=0b1) def test_target_signal_wrong(self): @@ -483,6 +533,9 @@ class ViewTestCase(FHDLTestCase): def __call__(self, value): return value[::-1] + def const(self, init): + return Const(init, 2) + v = View(StructLayout({ "f": Reverser() })) @@ -497,13 +550,15 @@ class ViewTestCase(FHDLTestCase): def __call__(self, value): pass + def const(self, init): + return Const(init, 2) + v = View(StructLayout({ "f": WrongCastable() })) with self.assertRaisesRegex(TypeError, - r"^" - r"\.WrongCastable object at 0x.+?>\.__call__\(\) must return a value or " - r"a value-castable object, not None$"): + r"^<.+?\.WrongCastable.+?>\.__call__\(\) must return a value or a value-castable " + r"object, not None$"): v.f def test_index_wrong_missing(self): diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 5f86945..2bf2d19 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -5,18 +5,12 @@ from .utils import * class EnumTestCase(FHDLTestCase): - def test_non_int_members(self): + def test_members_non_int(self): # Mustn't raise to be a drop-in replacement for Enum. class EnumA(Enum): A = "str" - def test_non_const_non_int_members_wrong(self): - with self.assertRaisesRegex(TypeError, - r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"): - class EnumA(Enum, shape=unsigned(1)): - A = "str" - - def test_const_non_int_members(self): + def test_members_const_non_int(self): class EnumA(Enum): A = C(0) B = C(1) @@ -59,6 +53,12 @@ class EnumTestCase(FHDLTestCase): B = -5 self.assertEqual(Shape.cast(EnumD), signed(4)) + def test_shape_members_non_const_non_int_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"): + class EnumA(Enum, shape=unsigned(1)): + A = "str" + def test_shape_explicit_wrong_signed_mismatch(self): with self.assertWarnsRegex(SyntaxWarning, r"^Value -1 of enumeration member 'A' is signed, but the enumeration " @@ -88,6 +88,23 @@ class EnumTestCase(FHDLTestCase): A = 1 self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)") + def test_const_no_shape(self): + class EnumA(Enum): + Z = 0 + A = 10 + B = 20 + self.assertRepr(EnumA.const(None), "(const 5'd0)") + self.assertRepr(EnumA.const(10), "(const 5'd10)") + self.assertRepr(EnumA.const(EnumA.A), "(const 5'd10)") + + def test_const_shape(self): + class EnumA(Enum, shape=8): + Z = 0 + A = 10 + self.assertRepr(EnumA.const(None), "(const 8'd0)") + self.assertRepr(EnumA.const(10), "(const 8'd10)") + self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)") + def test_shape_implicit_wrong_in_concat(self): class EnumA(Enum): A = 0