From bf16acf2f0bf7fa01ca027e60f4a2a9b6a92f1f5 Mon Sep 17 00:00:00 2001 From: Catherine Date: Tue, 5 Apr 2022 21:55:50 +0000 Subject: [PATCH] hdl.ast: implement ShapeCastable (like ValueCastable). Refs #693. --- amaranth/hdl/ast.py | 105 +++++++++++++++++++++++++++--------------- tests/test_hdl_ast.py | 40 +++++++++++++++- 2 files changed, 105 insertions(+), 40 deletions(-) diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index d99958b..d61f2dc 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -12,7 +12,7 @@ from .._unused import * __all__ = [ - "Shape", "signed", "unsigned", + "Shape", "signed", "unsigned", "ShapeCastable", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Array", "ArrayProxy", "Signal", "ClockSignal", "ResetSignal", @@ -32,6 +32,22 @@ class DUID: DUID.__next_uid += 1 +class ShapeCastable: + """Interface of user-defined objects that can be cast to :class:`Shape`s. + + An object deriving from ``ShapeCastable`` is automatically converted to a ``Shape`` when it is + used in a context where a ``Shape`` is expected. Such objects can contain a richer description + of the shape than what is supported by the core Amaranth language, yet still be transparently + used with it. + """ + def __new__(cls, *args, **kwargs): + self = super().__new__(cls) + if not hasattr(self, "as_shape"): + raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override " + f"the `as_shape` method") + return self + + class Shape: """Bit width and signedness of a value. @@ -62,39 +78,48 @@ class Shape: self.width = width self.signed = signed + # TODO(nmigen-0.4): remove def __iter__(self): return iter((self.width, self.signed)) @staticmethod def cast(obj, *, src_loc_at=0): - if isinstance(obj, Shape): - return obj - if isinstance(obj, int): - return Shape(obj) - if isinstance(obj, tuple): - width, signed = obj - warnings.warn("instead of `{tuple}`, use `{constructor}({width})`" - .format(constructor="signed" if signed else "unsigned", width=width, - tuple=obj), - DeprecationWarning, stacklevel=2 + src_loc_at) - return Shape(width, signed) - if isinstance(obj, range): - if len(obj) == 0: - return Shape(0, obj.start < 0) - signed = obj.start < 0 or (obj.stop - obj.step) < 0 - width = max(bits_for(obj.start, signed), - bits_for(obj.stop - obj.step, signed)) - return Shape(width, signed) - if isinstance(obj, type) and issubclass(obj, Enum): - min_value = min(member.value for member in obj) - max_value = max(member.value for member in obj) - if not isinstance(min_value, int) or not isinstance(max_value, int): - raise TypeError("Only enumerations with integer values can be used " - "as value shapes") - signed = min_value < 0 or max_value < 0 - width = max(bits_for(min_value, signed), bits_for(max_value, signed)) - return Shape(width, signed) - raise TypeError("Object {!r} cannot be used as value shape".format(obj)) + while True: + if isinstance(obj, Shape): + return obj + elif isinstance(obj, int): + return Shape(obj) + # TODO(nmigen-0.4): remove + elif isinstance(obj, tuple): + width, signed = obj + warnings.warn("instead of `{tuple}`, use `{constructor}({width})`" + .format(constructor="signed" if signed else "unsigned", width=width, + tuple=obj), + DeprecationWarning, stacklevel=2 + src_loc_at) + return Shape(width, signed) + elif isinstance(obj, range): + if len(obj) == 0: + return Shape(0, obj.start < 0) + signed = obj.start < 0 or (obj.stop - obj.step) < 0 + width = max(bits_for(obj.start, signed), + bits_for(obj.stop - obj.step, signed)) + return Shape(width, signed) + elif isinstance(obj, type) and issubclass(obj, Enum): + min_value = min(member.value for member in obj) + max_value = max(member.value for member in obj) + if not isinstance(min_value, int) or not isinstance(max_value, int): + raise TypeError("Only enumerations with integer values can be used " + "as value shapes") + signed = min_value < 0 or max_value < 0 + width = max(bits_for(min_value, signed), bits_for(max_value, signed)) + return Shape(width, signed) + elif isinstance(obj, ShapeCastable): + new_obj = obj.as_shape() + else: + raise TypeError("Object {!r} cannot be converted to an Amaranth shape".format(obj)) + if new_obj is obj: + raise RecursionError("Shape-castable object {!r} casts to itself".format(obj)) + obj = new_obj def __repr__(self): if self.signed: @@ -103,6 +128,7 @@ class Shape: return "unsigned({})".format(self.width) def __eq__(self, other): + # TODO(nmigen-0.4): remove if isinstance(other, tuple) and len(other) == 2: width, signed = other if isinstance(width, int) and isinstance(signed, bool): @@ -112,9 +138,11 @@ class Shape: "not {!r}" .format(other)) if not isinstance(other, Shape): - raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, " - "not {!r}" - .format(other)) + try: + other = self.__class__.cast(other) + except TypeError as e: + raise TypeError("Shapes may be compared with shape-castable objects, not {!r}" + .format(other)) from e return self.width == other.width and self.signed == other.signed @@ -1282,12 +1310,14 @@ class UserValue(Value): class ValueCastable: - """Interface of objects can be cast to :class:`Value`s. + """Interface of user-defined objects that can be cast to :class:`Value`s. - A ``ValueCastable`` can be cast to ``Value``, meaning its precise representation does not have - to be immediately known. This is useful in certain metaprogramming scenarios. Instead of - providing fixed semantics upfront, it is kept abstract for as long as possible, only being - cast to a concrete Amaranth value when required. + An object deriving from ``ValueCastable`` is automatically converted to a ``Value`` when it is + used in a context where a ``Value`` is expected. Such objects can implement different or + richer semantics than what is supported by the core Amaranth language, yet still be + transparently used with it as long as the final underlying representation is a single Amaranth + ``Value``. These objects also need not commit to a specific representation until they are + converted to a concrete Amaranth value. Note that it is necessary to ensure that Amaranth's view of representation of all values stays internally consistent. The class deriving from ``ValueCastable`` must decorate the ``as_value`` @@ -1301,7 +1331,6 @@ class ValueCastable: if not hasattr(self, "as_value"): raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override " "the `as_value` method") - if not hasattr(self.as_value, "_ValueCastable__memoized"): raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate " "the `as_value` method with the `ValueCastable.lowermethod` decorator") diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index d0945e1..55e5a88 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -42,7 +42,7 @@ class ShapeTestCase(FHDLTestCase): def test_compare_wrong(self): with self.assertRaisesRegex(TypeError, - r"^Shapes may be compared with other Shapes and \(int, bool\) tuples, not 'hi'$"): + r"^Shapes may be compared with shape-castable objects, not 'hi'$"): Shape(1, True) == 'hi' def test_compare_tuple_wrong(self): @@ -141,10 +141,46 @@ class ShapeTestCase(FHDLTestCase): def test_cast_bad(self): with self.assertRaisesRegex(TypeError, - r"^Object 'foo' cannot be used as value shape$"): + r"^Object 'foo' cannot be converted to an Amaranth shape$"): Shape.cast("foo") +class MockShapeCastable(ShapeCastable): + def __init__(self, dest): + self.dest = dest + + def as_shape(self): + return self.dest + + +class MockShapeCastableNoOverride(ShapeCastable): + def __init__(self): + pass + + +class ShapeCastableTestCase(FHDLTestCase): + def test_no_override(self): + with self.assertRaisesRegex(TypeError, + r"^Class 'MockShapeCastableNoOverride' deriving from `ShapeCastable` must " + r"override the `as_shape` method$"): + sc = MockShapeCastableNoOverride() + + def test_cast(self): + sc = MockShapeCastable(unsigned(2)) + self.assertEqual(Shape.cast(sc), unsigned(2)) + + def test_recurse_bad(self): + sc = MockShapeCastable(None) + sc.dest = sc + with self.assertRaisesRegex(RecursionError, + r"^Shape-castable object <.+> casts to itself$"): + Shape.cast(sc) + + def test_recurse(self): + sc = MockShapeCastable(MockShapeCastable(unsigned(1))) + self.assertEqual(Shape.cast(sc), unsigned(1)) + + class ValueTestCase(FHDLTestCase): def test_cast(self): self.assertIsInstance(Value.cast(0), Const)