From f95fe451861066ae7e3ae6c9eb85778de5e949e0 Mon Sep 17 00:00:00 2001 From: Catherine Date: Wed, 23 Aug 2023 07:48:33 +0000 Subject: [PATCH] Implement RFC 22: Add `ValueCastable.shape()`. Fixes #794. Closes #876. --- amaranth/hdl/ast.py | 9 ++++++++- amaranth/lib/data.py | 24 ++++++++++-------------- docs/changes.rst | 6 ++++-- tests/test_hdl_ast.py | 26 ++++++++++++++++++++++++-- tests/test_lib_data.py | 35 +++++++++++++++++++++-------------- 5 files changed, 67 insertions(+), 33 deletions(-) diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index 0a728c6..e1ee694 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -1110,7 +1110,11 @@ class Signal(Value, DUID, metaclass=_SignalMeta): new_name = other.name + str(name_suffix) else: new_name = tracer.get_var_name(depth=2 + src_loc_at, default="$like") - kw = dict(shape=Value.cast(other).shape(), name=new_name) + if isinstance(other, ValueCastable): + shape = other.shape() + else: + shape = Value.cast(other).shape() + kw = dict(shape=shape, name=new_name) if isinstance(other, Signal): kw.update(reset=other.reset, reset_less=other.reset_less, attrs=other.attrs, decoder=other.decoder) @@ -1363,6 +1367,9 @@ class ValueCastable: if not hasattr(cls, "as_value"): raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override " "the `as_value` method") + if not hasattr(cls, "shape"): + raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override " + "the `shape` method") if not hasattr(cls.as_value, "_ValueCastable__memoized"): raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate " "the `as_value` method with the `ValueCastable.lowermethod` decorator") diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index 605b1fc..fb231ef 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -115,20 +115,6 @@ class Layout(ShapeCastable, metaclass=ABCMeta): raise TypeError("Object {!r} cannot be converted to a data layout" .format(obj)) - @staticmethod - def of(obj): - """Extract the layout that was used to create a view. - - Raises - ------ - TypeError - If ``obj`` is not a :class:`View` instance. - """ - if not isinstance(obj, View): - raise TypeError("Object {!r} is not a data view" - .format(obj)) - return obj._View__orig_layout - @abstractmethod def __iter__(self): """Iterate fields in the layout. @@ -611,6 +597,16 @@ class View(ValueCastable): self.__layout = cast_layout self.__target = cast_target + def shape(self): + """Get layout of this view. + + Returns + ------- + :class:`Layout` + The ``layout`` provided when constructing the view. + """ + return self.__orig_layout + @ValueCastable.lowermethod def as_value(self): """Get underlying value. diff --git a/docs/changes.rst b/docs/changes.rst index 0b57f04..66e278e 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -42,18 +42,20 @@ Implemented RFCs .. _RFC 9: https://amaranth-lang.org/rfcs/0009-const-init-shape-castable.html .. _RFC 10: https://amaranth-lang.org/rfcs/0010-move-repl-to-value.html .. _RFC 15: https://amaranth-lang.org/rfcs/0015-lifting-shape-castables.html +.. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html * `RFC 1`_: Aggregate data structure library * `RFC 3`_: Enumeration shapes * `RFC 4`_: Constant-castable expressions -* `RFC 5`_: Remove Const.normalize +* `RFC 5`_: Remove ``Const.normalize`` * `RFC 6`_: CRC generator * `RFC 8`_: Aggregate extensibility * `RFC 9`_: Constant initialization for shape-castable objects * `RFC 8`_: Aggregate extensibility * `RFC 9`_: Constant initialization for shape-castable objects -* `RFC 10`_: Move Repl to Value.replicate +* `RFC 10`_: Move ``Repl`` to ``Value.replicate`` * `RFC 15`_: Lifting shape-castable objects +* `RFC 22`_: Define ``ValueCastable.shape()`` Language changes diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 3177807..0cb8865 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -1183,6 +1183,9 @@ class MockValueCastable(ValueCastable): def __init__(self, dest): self.dest = dest + def shape(self): + return Value.cast(self.dest).shape() + @ValueCastable.lowermethod def as_value(self): return self.dest @@ -1192,6 +1195,9 @@ class MockValueCastableChanges(ValueCastable): def __init__(self, width=0): self.width = width + def shape(self): + return unsigned(self.width) + @ValueCastable.lowermethod def as_value(self): return Signal(self.width) @@ -1201,6 +1207,9 @@ class MockValueCastableCustomGetattr(ValueCastable): def __init__(self): pass + def shape(self): + assert False + @ValueCastable.lowermethod def as_value(self): return Const(0) @@ -1218,17 +1227,30 @@ class ValueCastableTestCase(FHDLTestCase): def __init__(self): pass + def shape(self): + pass + def as_value(self): return Signal() def test_no_override(self): with self.assertRaisesRegex(TypeError, - r"^Class 'MockValueCastableNoOverride' deriving from `ValueCastable` must " + r"^Class 'MockValueCastableNoOverrideAsValue' deriving from `ValueCastable` must " r"override the `as_value` method$"): - class MockValueCastableNoOverride(ValueCastable): + class MockValueCastableNoOverrideAsValue(ValueCastable): def __init__(self): pass + with self.assertRaisesRegex(TypeError, + r"^Class 'MockValueCastableNoOverrideShapec' deriving from `ValueCastable` must " + r"override the `shape` method$"): + class MockValueCastableNoOverrideShapec(ValueCastable): + def __init__(self): + pass + + def as_value(self): + return Signal() + def test_memoized(self): vc = MockValueCastableChanges(1) sig1 = vc.as_value() diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index 2dd1928..7c62177 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -365,11 +365,6 @@ class LayoutTestCase(FHDLTestCase): r"^Shape-castable object <.+> casts to itself$"): Layout.cast(sc) - def test_of_wrong(self): - with self.assertRaisesRegex(TypeError, - r"^Object <.+> is not a data view$"): - Layout.of(object()) - def test_eq_wrong_recur(self): sc = MockShapeCastable(None) sc.shape = sc @@ -379,7 +374,7 @@ class LayoutTestCase(FHDLTestCase): sl = StructLayout({"f": unsigned(1)}) s = Signal(1) v = sl(s) - self.assertIs(Layout.of(v), sl) + self.assertIs(v.shape(), sl) self.assertIs(v.as_value(), s) def test_const(self): @@ -621,6 +616,11 @@ class ViewTestCase(FHDLTestCase): r"and may only be accessed by indexing$"): Signal(StructLayout({"_c": signed(1)}))._c + def test_signal_like(self): + s1 = Signal(StructLayout({"a": unsigned(1)})) + s2 = Signal.like(s1) + self.assertEqual(s2.shape(), StructLayout({"a": unsigned(1)})) + def test_bug_837_array_layout_getitem_str(self): with self.assertRaisesRegex(TypeError, r"^Views with array layout may only be indexed with an integer or a value, " @@ -646,7 +646,7 @@ class StructTestCase(FHDLTestCase): })) v = Signal(S) - self.assertEqual(Layout.of(v), S) + self.assertEqual(v.shape(), S) self.assertEqual(Value.cast(v).shape(), Shape.cast(S)) self.assertEqual(Value.cast(v).name, "v") self.assertRepr(v.a, "(slice (sig v) 0:1)") @@ -666,11 +666,11 @@ class StructTestCase(FHDLTestCase): self.assertEqual(Shape.cast(S), unsigned(9)) v = Signal(S) - self.assertIs(Layout.of(v), S) + self.assertIs(v.shape(), S) self.assertIsInstance(v, S) - self.assertIs(Layout.of(v.b), R) + self.assertIs(v.b.shape(), R) self.assertIsInstance(v.b, R) - self.assertIs(Layout.of(v.b.q), Q) + self.assertIs(v.b.q.shape(), Q) self.assertIsInstance(v.b.q, View) self.assertRepr(v.b.p, "(slice (slice (sig v) 1:9) 0:4)") self.assertRepr(v.b.q.as_value(), "(slice (slice (sig v) 1:9) 4:8)") @@ -747,10 +747,17 @@ class StructTestCase(FHDLTestCase): b: int c: str = "x" - self.assertEqual(Layout.of(Signal(S)), StructLayout({"a": unsigned(1)})) + self.assertEqual(Layout.cast(S), StructLayout({"a": unsigned(1)})) self.assertEqual(S.__annotations__, {"b": int, "c": str}) self.assertEqual(S.c, "x") + def test_signal_like(self): + class S(Struct): + a: 1 + s1 = Signal(S) + s2 = Signal.like(s1) + self.assertEqual(s2.shape(), S) + class UnionTestCase(FHDLTestCase): def test_construct(self): @@ -765,7 +772,7 @@ class UnionTestCase(FHDLTestCase): })) v = Signal(U) - self.assertEqual(Layout.of(v), U) + self.assertEqual(v.shape(), U) self.assertEqual(Value.cast(v).shape(), Shape.cast(U)) self.assertRepr(v.a, "(slice (sig v) 0:1)") self.assertRepr(v.b, "(s (slice (sig v) 0:3))") @@ -887,7 +894,7 @@ class RFCExamplesTestCase(TestCase): view1 = Signal(layout1) self.assertIsInstance(view1, View) - self.assertEqual(Layout.of(view1), layout1) + self.assertEqual(view1.shape(), layout1) self.assertEqual(view1.as_value().shape(), unsigned(3)) m1 = Module() @@ -933,4 +940,4 @@ class RFCExamplesTestCase(TestCase): self.assertEqual(layout1, Layout.cast(SomeVariant)) - self.assertIs(SomeVariant, Layout.of(view2)) + self.assertIs(SomeVariant, view2.shape())