Implement RFC 22: Add ValueCastable.shape().

Fixes #794.
Closes #876.
This commit is contained in:
Catherine 2023-08-23 07:48:33 +00:00
parent 7714ce329a
commit f95fe45186
5 changed files with 67 additions and 33 deletions

View file

@ -1110,7 +1110,11 @@ class Signal(Value, DUID, metaclass=_SignalMeta):
new_name = other.name + str(name_suffix) new_name = other.name + str(name_suffix)
else: else:
new_name = tracer.get_var_name(depth=2 + src_loc_at, default="$like") new_name = tracer.get_var_name(depth=2 + src_loc_at, default="$like")
kw = dict(shape=Value.cast(other).shape(), name=new_name) if isinstance(other, ValueCastable):
shape = other.shape()
else:
shape = Value.cast(other).shape()
kw = dict(shape=shape, name=new_name)
if isinstance(other, Signal): if isinstance(other, Signal):
kw.update(reset=other.reset, reset_less=other.reset_less, kw.update(reset=other.reset, reset_less=other.reset_less,
attrs=other.attrs, decoder=other.decoder) attrs=other.attrs, decoder=other.decoder)
@ -1363,6 +1367,9 @@ class ValueCastable:
if not hasattr(cls, "as_value"): if not hasattr(cls, "as_value"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override " raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override "
"the `as_value` method") "the `as_value` method")
if not hasattr(cls, "shape"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override "
"the `shape` method")
if not hasattr(cls.as_value, "_ValueCastable__memoized"): if not hasattr(cls.as_value, "_ValueCastable__memoized"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate " raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate "
"the `as_value` method with the `ValueCastable.lowermethod` decorator") "the `as_value` method with the `ValueCastable.lowermethod` decorator")

View file

@ -115,20 +115,6 @@ class Layout(ShapeCastable, metaclass=ABCMeta):
raise TypeError("Object {!r} cannot be converted to a data layout" raise TypeError("Object {!r} cannot be converted to a data layout"
.format(obj)) .format(obj))
@staticmethod
def of(obj):
"""Extract the layout that was used to create a view.
Raises
------
TypeError
If ``obj`` is not a :class:`View` instance.
"""
if not isinstance(obj, View):
raise TypeError("Object {!r} is not a data view"
.format(obj))
return obj._View__orig_layout
@abstractmethod @abstractmethod
def __iter__(self): def __iter__(self):
"""Iterate fields in the layout. """Iterate fields in the layout.
@ -611,6 +597,16 @@ class View(ValueCastable):
self.__layout = cast_layout self.__layout = cast_layout
self.__target = cast_target self.__target = cast_target
def shape(self):
"""Get layout of this view.
Returns
-------
:class:`Layout`
The ``layout`` provided when constructing the view.
"""
return self.__orig_layout
@ValueCastable.lowermethod @ValueCastable.lowermethod
def as_value(self): def as_value(self):
"""Get underlying value. """Get underlying value.

View file

@ -42,18 +42,20 @@ Implemented RFCs
.. _RFC 9: https://amaranth-lang.org/rfcs/0009-const-init-shape-castable.html .. _RFC 9: https://amaranth-lang.org/rfcs/0009-const-init-shape-castable.html
.. _RFC 10: https://amaranth-lang.org/rfcs/0010-move-repl-to-value.html .. _RFC 10: https://amaranth-lang.org/rfcs/0010-move-repl-to-value.html
.. _RFC 15: https://amaranth-lang.org/rfcs/0015-lifting-shape-castables.html .. _RFC 15: https://amaranth-lang.org/rfcs/0015-lifting-shape-castables.html
.. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html
* `RFC 1`_: Aggregate data structure library * `RFC 1`_: Aggregate data structure library
* `RFC 3`_: Enumeration shapes * `RFC 3`_: Enumeration shapes
* `RFC 4`_: Constant-castable expressions * `RFC 4`_: Constant-castable expressions
* `RFC 5`_: Remove Const.normalize * `RFC 5`_: Remove ``Const.normalize``
* `RFC 6`_: CRC generator * `RFC 6`_: CRC generator
* `RFC 8`_: Aggregate extensibility * `RFC 8`_: Aggregate extensibility
* `RFC 9`_: Constant initialization for shape-castable objects * `RFC 9`_: Constant initialization for shape-castable objects
* `RFC 8`_: Aggregate extensibility * `RFC 8`_: Aggregate extensibility
* `RFC 9`_: Constant initialization for shape-castable objects * `RFC 9`_: Constant initialization for shape-castable objects
* `RFC 10`_: Move Repl to Value.replicate * `RFC 10`_: Move ``Repl`` to ``Value.replicate``
* `RFC 15`_: Lifting shape-castable objects * `RFC 15`_: Lifting shape-castable objects
* `RFC 22`_: Define ``ValueCastable.shape()``
Language changes Language changes

View file

@ -1183,6 +1183,9 @@ class MockValueCastable(ValueCastable):
def __init__(self, dest): def __init__(self, dest):
self.dest = dest self.dest = dest
def shape(self):
return Value.cast(self.dest).shape()
@ValueCastable.lowermethod @ValueCastable.lowermethod
def as_value(self): def as_value(self):
return self.dest return self.dest
@ -1192,6 +1195,9 @@ class MockValueCastableChanges(ValueCastable):
def __init__(self, width=0): def __init__(self, width=0):
self.width = width self.width = width
def shape(self):
return unsigned(self.width)
@ValueCastable.lowermethod @ValueCastable.lowermethod
def as_value(self): def as_value(self):
return Signal(self.width) return Signal(self.width)
@ -1201,6 +1207,9 @@ class MockValueCastableCustomGetattr(ValueCastable):
def __init__(self): def __init__(self):
pass pass
def shape(self):
assert False
@ValueCastable.lowermethod @ValueCastable.lowermethod
def as_value(self): def as_value(self):
return Const(0) return Const(0)
@ -1218,17 +1227,30 @@ class ValueCastableTestCase(FHDLTestCase):
def __init__(self): def __init__(self):
pass pass
def shape(self):
pass
def as_value(self): def as_value(self):
return Signal() return Signal()
def test_no_override(self): def test_no_override(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Class 'MockValueCastableNoOverride' deriving from `ValueCastable` must " r"^Class 'MockValueCastableNoOverrideAsValue' deriving from `ValueCastable` must "
r"override the `as_value` method$"): r"override the `as_value` method$"):
class MockValueCastableNoOverride(ValueCastable): class MockValueCastableNoOverrideAsValue(ValueCastable):
def __init__(self): def __init__(self):
pass pass
with self.assertRaisesRegex(TypeError,
r"^Class 'MockValueCastableNoOverrideShapec' deriving from `ValueCastable` must "
r"override the `shape` method$"):
class MockValueCastableNoOverrideShapec(ValueCastable):
def __init__(self):
pass
def as_value(self):
return Signal()
def test_memoized(self): def test_memoized(self):
vc = MockValueCastableChanges(1) vc = MockValueCastableChanges(1)
sig1 = vc.as_value() sig1 = vc.as_value()

View file

@ -365,11 +365,6 @@ class LayoutTestCase(FHDLTestCase):
r"^Shape-castable object <.+> casts to itself$"): r"^Shape-castable object <.+> casts to itself$"):
Layout.cast(sc) Layout.cast(sc)
def test_of_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Object <.+> is not a data view$"):
Layout.of(object())
def test_eq_wrong_recur(self): def test_eq_wrong_recur(self):
sc = MockShapeCastable(None) sc = MockShapeCastable(None)
sc.shape = sc sc.shape = sc
@ -379,7 +374,7 @@ class LayoutTestCase(FHDLTestCase):
sl = StructLayout({"f": unsigned(1)}) sl = StructLayout({"f": unsigned(1)})
s = Signal(1) s = Signal(1)
v = sl(s) v = sl(s)
self.assertIs(Layout.of(v), sl) self.assertIs(v.shape(), sl)
self.assertIs(v.as_value(), s) self.assertIs(v.as_value(), s)
def test_const(self): def test_const(self):
@ -621,6 +616,11 @@ class ViewTestCase(FHDLTestCase):
r"and may only be accessed by indexing$"): r"and may only be accessed by indexing$"):
Signal(StructLayout({"_c": signed(1)}))._c Signal(StructLayout({"_c": signed(1)}))._c
def test_signal_like(self):
s1 = Signal(StructLayout({"a": unsigned(1)}))
s2 = Signal.like(s1)
self.assertEqual(s2.shape(), StructLayout({"a": unsigned(1)}))
def test_bug_837_array_layout_getitem_str(self): def test_bug_837_array_layout_getitem_str(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Views with array layout may only be indexed with an integer or a value, " r"^Views with array layout may only be indexed with an integer or a value, "
@ -646,7 +646,7 @@ class StructTestCase(FHDLTestCase):
})) }))
v = Signal(S) v = Signal(S)
self.assertEqual(Layout.of(v), S) self.assertEqual(v.shape(), S)
self.assertEqual(Value.cast(v).shape(), Shape.cast(S)) self.assertEqual(Value.cast(v).shape(), Shape.cast(S))
self.assertEqual(Value.cast(v).name, "v") self.assertEqual(Value.cast(v).name, "v")
self.assertRepr(v.a, "(slice (sig v) 0:1)") self.assertRepr(v.a, "(slice (sig v) 0:1)")
@ -666,11 +666,11 @@ class StructTestCase(FHDLTestCase):
self.assertEqual(Shape.cast(S), unsigned(9)) self.assertEqual(Shape.cast(S), unsigned(9))
v = Signal(S) v = Signal(S)
self.assertIs(Layout.of(v), S) self.assertIs(v.shape(), S)
self.assertIsInstance(v, S) self.assertIsInstance(v, S)
self.assertIs(Layout.of(v.b), R) self.assertIs(v.b.shape(), R)
self.assertIsInstance(v.b, R) self.assertIsInstance(v.b, R)
self.assertIs(Layout.of(v.b.q), Q) self.assertIs(v.b.q.shape(), Q)
self.assertIsInstance(v.b.q, View) self.assertIsInstance(v.b.q, View)
self.assertRepr(v.b.p, "(slice (slice (sig v) 1:9) 0:4)") self.assertRepr(v.b.p, "(slice (slice (sig v) 1:9) 0:4)")
self.assertRepr(v.b.q.as_value(), "(slice (slice (sig v) 1:9) 4:8)") self.assertRepr(v.b.q.as_value(), "(slice (slice (sig v) 1:9) 4:8)")
@ -747,10 +747,17 @@ class StructTestCase(FHDLTestCase):
b: int b: int
c: str = "x" c: str = "x"
self.assertEqual(Layout.of(Signal(S)), StructLayout({"a": unsigned(1)})) self.assertEqual(Layout.cast(S), StructLayout({"a": unsigned(1)}))
self.assertEqual(S.__annotations__, {"b": int, "c": str}) self.assertEqual(S.__annotations__, {"b": int, "c": str})
self.assertEqual(S.c, "x") self.assertEqual(S.c, "x")
def test_signal_like(self):
class S(Struct):
a: 1
s1 = Signal(S)
s2 = Signal.like(s1)
self.assertEqual(s2.shape(), S)
class UnionTestCase(FHDLTestCase): class UnionTestCase(FHDLTestCase):
def test_construct(self): def test_construct(self):
@ -765,7 +772,7 @@ class UnionTestCase(FHDLTestCase):
})) }))
v = Signal(U) v = Signal(U)
self.assertEqual(Layout.of(v), U) self.assertEqual(v.shape(), U)
self.assertEqual(Value.cast(v).shape(), Shape.cast(U)) self.assertEqual(Value.cast(v).shape(), Shape.cast(U))
self.assertRepr(v.a, "(slice (sig v) 0:1)") self.assertRepr(v.a, "(slice (sig v) 0:1)")
self.assertRepr(v.b, "(s (slice (sig v) 0:3))") self.assertRepr(v.b, "(s (slice (sig v) 0:3))")
@ -887,7 +894,7 @@ class RFCExamplesTestCase(TestCase):
view1 = Signal(layout1) view1 = Signal(layout1)
self.assertIsInstance(view1, View) self.assertIsInstance(view1, View)
self.assertEqual(Layout.of(view1), layout1) self.assertEqual(view1.shape(), layout1)
self.assertEqual(view1.as_value().shape(), unsigned(3)) self.assertEqual(view1.as_value().shape(), unsigned(3))
m1 = Module() m1 = Module()
@ -933,4 +940,4 @@ class RFCExamplesTestCase(TestCase):
self.assertEqual(layout1, Layout.cast(SomeVariant)) self.assertEqual(layout1, Layout.cast(SomeVariant))
self.assertIs(SomeVariant, Layout.of(view2)) self.assertIs(SomeVariant, view2.shape())