diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index aaa0a06..2cebc1b 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -208,6 +208,11 @@ class Layout(ShapeCastable, metaclass=ABCMeta): an all-zero value and had every field assigned to the corresponding value in the order in which they appear in :py:`init`. """ + if isinstance(init, Const): + if Layout.cast(init.shape()) != self: + raise ValueError(f"Const layout {init.shape()!r} differs from shape layout " + f"{self!r}") + return init if init is None: iterator = iter(()) elif isinstance(init, Mapping): @@ -1139,6 +1144,11 @@ class _AggregateMeta(ShapeCastable, type): return super().__call__(cls, target) def const(cls, init): + if isinstance(init, Const): + if Layout.cast(init.shape()) != Layout.cast(cls.__layout): + raise ValueError(f"Const layout {init.shape()!r} differs from shape layout " + f"{cls.__layout!r}") + return init if cls.__layout_cls is UnionLayout: if init is not None and len(init) > 1: raise ValueError("Initializer for at most one field can be provided for " diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index 0426053..1dd4399 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -387,6 +387,7 @@ class LayoutTestCase(FHDLTestCase): }) self.assertRepr(sl.const(None).as_value(), "(const 3'd0)") self.assertRepr(sl.const({"a": 0b1, "b": 0b10}).as_value(), "(const 3'd5)") + self.assertRepr(sl.const(sl.const({"a": 0b1, "b": 0b10})).as_value(), "(const 3'd5)") fl = data.FlexibleLayout(2, { "a": data.Field(unsigned(1), 0), @@ -408,6 +409,10 @@ class LayoutTestCase(FHDLTestCase): r"^Layout constant initializer must be a mapping or a sequence, not " r"<.+?object.+?>$"): sl.const(object()) + sl2 = data.StructLayout({"f": unsigned(2)}) + with self.assertRaisesRegex(ValueError, + r"^Const layout StructLayout.* differs from shape layout StructLayout.*$"): + sl2.const(sl.const({})) def test_const_field_shape_castable(self): class CastableFromHex(ShapeCastable): @@ -1013,6 +1018,21 @@ class StructTestCase(FHDLTestCase): self.assertEqual(v2.as_value().init, 0b010011) v3 = Signal(S, init=dict(p=0b0011, q=0b00)) self.assertEqual(v3.as_value().init, 0b000011) + v3 = Signal(S, init=S.const({"p": 0b0011, "q": 0b00})) + self.assertEqual(v3.as_value().init, 0b000011) + + def test_const_wrong(self): + class S(data.Struct): + p: 4 + q: 2 = 1 + + class S2(data.Struct): + p: 2 + q: 4 + + with self.assertRaisesRegex(ValueError, + f"^Const layout StructLayout.* differs from shape layout StructLayout.*$"): + S.const(S2.const({"p": 0b11, "q": 0b0000})) def test_shape_undefined_wrong(self): class S(data.Struct):