hdl.ast: implement ShapeCastable (like ValueCastable).

Refs #693.
This commit is contained in:
Catherine 2022-04-05 21:55:50 +00:00
parent 0723f6bac9
commit bf16acf2f0
2 changed files with 105 additions and 40 deletions

View file

@ -12,7 +12,7 @@ from .._unused import *
__all__ = [ __all__ = [
"Shape", "signed", "unsigned", "Shape", "signed", "unsigned", "ShapeCastable",
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
"Array", "ArrayProxy", "Array", "ArrayProxy",
"Signal", "ClockSignal", "ResetSignal", "Signal", "ClockSignal", "ResetSignal",
@ -32,6 +32,22 @@ class DUID:
DUID.__next_uid += 1 DUID.__next_uid += 1
class ShapeCastable:
"""Interface of user-defined objects that can be cast to :class:`Shape`s.
An object deriving from ``ShapeCastable`` is automatically converted to a ``Shape`` when it is
used in a context where a ``Shape`` is expected. Such objects can contain a richer description
of the shape than what is supported by the core Amaranth language, yet still be transparently
used with it.
"""
def __new__(cls, *args, **kwargs):
self = super().__new__(cls)
if not hasattr(self, "as_shape"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
f"the `as_shape` method")
return self
class Shape: class Shape:
"""Bit width and signedness of a value. """Bit width and signedness of a value.
@ -62,30 +78,33 @@ class Shape:
self.width = width self.width = width
self.signed = signed self.signed = signed
# TODO(nmigen-0.4): remove
def __iter__(self): def __iter__(self):
return iter((self.width, self.signed)) return iter((self.width, self.signed))
@staticmethod @staticmethod
def cast(obj, *, src_loc_at=0): def cast(obj, *, src_loc_at=0):
while True:
if isinstance(obj, Shape): if isinstance(obj, Shape):
return obj return obj
if isinstance(obj, int): elif isinstance(obj, int):
return Shape(obj) return Shape(obj)
if isinstance(obj, tuple): # TODO(nmigen-0.4): remove
elif isinstance(obj, tuple):
width, signed = obj width, signed = obj
warnings.warn("instead of `{tuple}`, use `{constructor}({width})`" warnings.warn("instead of `{tuple}`, use `{constructor}({width})`"
.format(constructor="signed" if signed else "unsigned", width=width, .format(constructor="signed" if signed else "unsigned", width=width,
tuple=obj), tuple=obj),
DeprecationWarning, stacklevel=2 + src_loc_at) DeprecationWarning, stacklevel=2 + src_loc_at)
return Shape(width, signed) return Shape(width, signed)
if isinstance(obj, range): elif isinstance(obj, range):
if len(obj) == 0: if len(obj) == 0:
return Shape(0, obj.start < 0) return Shape(0, obj.start < 0)
signed = obj.start < 0 or (obj.stop - obj.step) < 0 signed = obj.start < 0 or (obj.stop - obj.step) < 0
width = max(bits_for(obj.start, signed), width = max(bits_for(obj.start, signed),
bits_for(obj.stop - obj.step, signed)) bits_for(obj.stop - obj.step, signed))
return Shape(width, signed) return Shape(width, signed)
if isinstance(obj, type) and issubclass(obj, Enum): elif isinstance(obj, type) and issubclass(obj, Enum):
min_value = min(member.value for member in obj) min_value = min(member.value for member in obj)
max_value = max(member.value for member in obj) max_value = max(member.value for member in obj)
if not isinstance(min_value, int) or not isinstance(max_value, int): if not isinstance(min_value, int) or not isinstance(max_value, int):
@ -94,7 +113,13 @@ class Shape:
signed = min_value < 0 or max_value < 0 signed = min_value < 0 or max_value < 0
width = max(bits_for(min_value, signed), bits_for(max_value, signed)) width = max(bits_for(min_value, signed), bits_for(max_value, signed))
return Shape(width, signed) return Shape(width, signed)
raise TypeError("Object {!r} cannot be used as value shape".format(obj)) elif isinstance(obj, ShapeCastable):
new_obj = obj.as_shape()
else:
raise TypeError("Object {!r} cannot be converted to an Amaranth shape".format(obj))
if new_obj is obj:
raise RecursionError("Shape-castable object {!r} casts to itself".format(obj))
obj = new_obj
def __repr__(self): def __repr__(self):
if self.signed: if self.signed:
@ -103,6 +128,7 @@ class Shape:
return "unsigned({})".format(self.width) return "unsigned({})".format(self.width)
def __eq__(self, other): def __eq__(self, other):
# TODO(nmigen-0.4): remove
if isinstance(other, tuple) and len(other) == 2: if isinstance(other, tuple) and len(other) == 2:
width, signed = other width, signed = other
if isinstance(width, int) and isinstance(signed, bool): if isinstance(width, int) and isinstance(signed, bool):
@ -112,9 +138,11 @@ class Shape:
"not {!r}" "not {!r}"
.format(other)) .format(other))
if not isinstance(other, Shape): if not isinstance(other, Shape):
raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, " try:
"not {!r}" other = self.__class__.cast(other)
.format(other)) except TypeError as e:
raise TypeError("Shapes may be compared with shape-castable objects, not {!r}"
.format(other)) from e
return self.width == other.width and self.signed == other.signed return self.width == other.width and self.signed == other.signed
@ -1282,12 +1310,14 @@ class UserValue(Value):
class ValueCastable: class ValueCastable:
"""Interface of objects can be cast to :class:`Value`s. """Interface of user-defined objects that can be cast to :class:`Value`s.
A ``ValueCastable`` can be cast to ``Value``, meaning its precise representation does not have An object deriving from ``ValueCastable`` is automatically converted to a ``Value`` when it is
to be immediately known. This is useful in certain metaprogramming scenarios. Instead of used in a context where a ``Value`` is expected. Such objects can implement different or
providing fixed semantics upfront, it is kept abstract for as long as possible, only being richer semantics than what is supported by the core Amaranth language, yet still be
cast to a concrete Amaranth value when required. transparently used with it as long as the final underlying representation is a single Amaranth
``Value``. These objects also need not commit to a specific representation until they are
converted to a concrete Amaranth value.
Note that it is necessary to ensure that Amaranth's view of representation of all values stays Note that it is necessary to ensure that Amaranth's view of representation of all values stays
internally consistent. The class deriving from ``ValueCastable`` must decorate the ``as_value`` internally consistent. The class deriving from ``ValueCastable`` must decorate the ``as_value``
@ -1301,7 +1331,6 @@ class ValueCastable:
if not hasattr(self, "as_value"): if not hasattr(self, "as_value"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override " raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override "
"the `as_value` method") "the `as_value` method")
if not hasattr(self.as_value, "_ValueCastable__memoized"): if not hasattr(self.as_value, "_ValueCastable__memoized"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate " raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate "
"the `as_value` method with the `ValueCastable.lowermethod` decorator") "the `as_value` method with the `ValueCastable.lowermethod` decorator")

View file

@ -42,7 +42,7 @@ class ShapeTestCase(FHDLTestCase):
def test_compare_wrong(self): def test_compare_wrong(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Shapes may be compared with other Shapes and \(int, bool\) tuples, not 'hi'$"): r"^Shapes may be compared with shape-castable objects, not 'hi'$"):
Shape(1, True) == 'hi' Shape(1, True) == 'hi'
def test_compare_tuple_wrong(self): def test_compare_tuple_wrong(self):
@ -141,10 +141,46 @@ class ShapeTestCase(FHDLTestCase):
def test_cast_bad(self): def test_cast_bad(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Object 'foo' cannot be used as value shape$"): r"^Object 'foo' cannot be converted to an Amaranth shape$"):
Shape.cast("foo") Shape.cast("foo")
class MockShapeCastable(ShapeCastable):
def __init__(self, dest):
self.dest = dest
def as_shape(self):
return self.dest
class MockShapeCastableNoOverride(ShapeCastable):
def __init__(self):
pass
class ShapeCastableTestCase(FHDLTestCase):
def test_no_override(self):
with self.assertRaisesRegex(TypeError,
r"^Class 'MockShapeCastableNoOverride' deriving from `ShapeCastable` must "
r"override the `as_shape` method$"):
sc = MockShapeCastableNoOverride()
def test_cast(self):
sc = MockShapeCastable(unsigned(2))
self.assertEqual(Shape.cast(sc), unsigned(2))
def test_recurse_bad(self):
sc = MockShapeCastable(None)
sc.dest = sc
with self.assertRaisesRegex(RecursionError,
r"^Shape-castable object <.+> casts to itself$"):
Shape.cast(sc)
def test_recurse(self):
sc = MockShapeCastable(MockShapeCastable(unsigned(1)))
self.assertEqual(Shape.cast(sc), unsigned(1))
class ValueTestCase(FHDLTestCase): class ValueTestCase(FHDLTestCase):
def test_cast(self): def test_cast(self):
self.assertIsInstance(Value.cast(0), Const) self.assertIsInstance(Value.cast(0), Const)