diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index f469ac8..275512e 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -431,10 +431,15 @@ class _AggregateMeta(ShapeCastable, type): class _Aggregate(View, metaclass=_AggregateMeta): def __init__(self, target=None, *, name=None, reset=None, reset_less=None, attrs=None, decoder=None, src_loc_at=0): + if self.__class__._AggregateMeta__layout_cls is UnionLayout: + if reset is not None and len(reset) > 1: + raise ValueError("Reset value for at most one field can be provided for " + "a union class (specified: {})" + .format(", ".join(reset.keys()))) if target is None and hasattr(self.__class__, "_AggregateMeta__reset"): if reset is None: reset = self.__class__._AggregateMeta__reset - else: + elif self.__class__._AggregateMeta__layout_cls is not UnionLayout: reset = {**self.__class__._AggregateMeta__reset, **reset} super().__init__(self.__class__, target, name=name, reset=reset, reset_less=reset_less, attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1) diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index 633dd58..b8e71bd 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -669,6 +669,24 @@ class UnionTestCase(FHDLTestCase): self.assertEqual(s.attrs, {"debug": 1}) self.assertEqual(s.decoder, decoder) + def test_construct_reset_two_wrong(self): + class U(Union): + a: unsigned(1) + b: unsigned(2) + + with self.assertRaisesRegex(ValueError, + r"^Reset value for at most one field can be provided for a union class " + r"\(specified: a, b\)$"): + U(reset=dict(a=1, b=2)) + + def test_construct_reset_override(self): + class U(Union): + a: unsigned(1) = 1 + b: unsigned(2) + + self.assertEqual(U().as_value().reset, 0b01) + self.assertEqual(U(reset=dict(b=0b10)).as_value().reset, 0b10) + # Examples from https://github.com/amaranth-lang/amaranth/issues/693 class RFCExamplesTestCase(TestCase):