diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index 44831d6..f469ac8 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -323,7 +323,7 @@ class View(ValueCastable): try: cast_layout = Layout.cast(layout) except TypeError as e: - raise TypeError("View layout must be a Layout instance, not {!r}" + raise TypeError("View layout must be a layout, not {!r}" .format(layout)) from e if target is not None: if (name is not None or reset is not None or reset_less is not None or @@ -398,28 +398,51 @@ class View(ValueCastable): class _AggregateMeta(ShapeCastable, type): - def __new__(metacls, name, bases, namespace, *, _layout_cls=None, **kwargs): - cls = type.__new__(metacls, name, bases, namespace, **kwargs) - if _layout_cls is not None: - cls.__layout_cls = _layout_cls - if "__annotations__" in namespace: + def __new__(metacls, name, bases, namespace): + if "__annotations__" not in namespace: + # This is a base class without its own layout. It is not shape-castable, and cannot + # be instantiated. It can be used to share behavior. + return type.__new__(metacls, name, bases, namespace) + elif all(not hasattr(base, "_AggregateMeta__layout") for base in bases): + # This is a leaf class with its own layout. It is shape-castable and can + # be instantiated. It can also be subclassed, and used to share layout and behavior. + reset = dict() + for name in namespace["__annotations__"]: + if name in namespace: + reset[name] = namespace.pop(name) + cls = type.__new__(metacls, name, bases, namespace) cls.__layout = cls.__layout_cls(namespace["__annotations__"]) - return cls + cls.__reset = reset + return cls + else: + # This is a class that has a base class with a layout and annotations. Such a class + # is not well-formed. + raise TypeError("Aggregate class '{}' must either inherits or specify a layout, " + "not both" + .format(name)) def as_shape(cls): + if not hasattr(cls, "_AggregateMeta__layout"): + raise TypeError("Aggregate class '{}.{}' does not have a defined shape" + .format(cls.__module__, cls.__qualname__)) return cls.__layout class _Aggregate(View, metaclass=_AggregateMeta): def __init__(self, target=None, *, name=None, reset=None, reset_less=None, attrs=None, decoder=None, src_loc_at=0): + if target is None and hasattr(self.__class__, "_AggregateMeta__reset"): + if reset is None: + reset = self.__class__._AggregateMeta__reset + else: + reset = {**self.__class__._AggregateMeta__reset, **reset} super().__init__(self.__class__, target, name=name, reset=reset, reset_less=reset_less, attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1) -class Struct(_Aggregate, _layout_cls=StructLayout): - pass +class Struct(_Aggregate): + _AggregateMeta__layout_cls = StructLayout -class Union(_Aggregate, _layout_cls=UnionLayout): - pass +class Union(_Aggregate): + _AggregateMeta__layout_cls = UnionLayout diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index 1d5a3cb..633dd58 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -411,7 +411,7 @@ class ViewTestCase(FHDLTestCase): def test_layout_wrong(self): with self.assertRaisesRegex(TypeError, - r"^View layout must be a Layout instance, not <.+?>$"): + r"^View layout must be a layout, not <.+?>$"): View(object(), Signal(1)) def test_target_wrong_type(self): @@ -575,6 +575,70 @@ class StructTestCase(FHDLTestCase): self.assertEqual(s.attrs, {"debug": 1}) self.assertEqual(s.decoder, decoder) + def test_construct_reset(self): + class S(Struct): + p: 4 + q: 2 = 1 + + with self.assertRaises(AttributeError): + S.q + + v1 = S() + self.assertEqual(v1.as_value().reset, 0b010000) + v2 = S(reset=dict(p=0b0011)) + self.assertEqual(v2.as_value().reset, 0b010011) + v3 = S(reset=dict(p=0b0011, q=0b00)) + self.assertEqual(v3.as_value().reset, 0b000011) + + def test_shape_undefined_wrong(self): + class S(Struct): + pass + + with self.assertRaisesRegex(TypeError, + r"^Aggregate class '.+?\.S' does not have a defined shape$"): + Shape.cast(S) + + def test_base_class_1(self): + class Sb(Struct): + def add(self): + return self.a + self.b + + class Sb1(Sb): + a: 1 + b: 1 + + class Sb2(Sb): + a: 2 + b: 2 + + self.assertEqual(Sb1().add().shape(), unsigned(2)) + self.assertEqual(Sb2().add().shape(), unsigned(3)) + + def test_base_class_2(self): + class Sb(Struct): + a: 2 + b: 2 + + class Sb1(Sb): + def do(self): + return Cat(self.a, self.b) + + class Sb2(Sb): + def do(self): + return self.a + self.b + + self.assertEqual(Sb1().do().shape(), unsigned(4)) + self.assertEqual(Sb2().do().shape(), unsigned(3)) + + def test_layout_redefined_wrong(self): + class Sb(Struct): + a: 1 + + with self.assertRaisesRegex(TypeError, + r"^Aggregate class 'Sd' must either inherits or specify a layout, not both$"): + class Sd(Sb): + b: 1 + class UnionTestCase(FHDLTestCase): def test_construct(self):