diff --git a/amaranth/lib/wiring.py b/amaranth/lib/wiring.py index c8a4287..6534541 100644 --- a/amaranth/lib/wiring.py +++ b/amaranth/lib/wiring.py @@ -371,6 +371,33 @@ class Signature(metaclass=SignatureMeta): self.members.freeze() return self + def flatten(self, obj): + for name, member in self.members.items(): + path = (name,) + value = getattr(obj, name) + + def iter_member(value, *, path): + if member.is_port: + yield path, Member(member.flow, member.shape, reset=member.reset), value + elif member.is_signature: + for sub_path, sub_member, sub_value in member.signature.flatten(value): + if member.flow == In: + sub_member = sub_member.flip() + yield ((*path, *sub_path), sub_member, sub_value) + else: + assert False # :nocov: + + def iter_dimensions(value, dimensions, *, path): + if not dimensions: + yield from iter_member(value, path=path) + else: + dimension, *rest_of_dimensions = dimensions + for index in range(dimension): + yield from iter_dimensions(value[index], rest_of_dimensions, + path=(path, index)) + + yield from iter_dimensions(value, dimensions=member.dimensions, path=path) + def is_compliant(self, obj, *, reasons=None, path=("obj",)): def check_attr_value(member, attr_value, *, path): if member.is_port: diff --git a/tests/test_lib_wiring.py b/tests/test_lib_wiring.py index d33ac7e..ccdc3d1 100644 --- a/tests/test_lib_wiring.py +++ b/tests/test_lib_wiring.py @@ -372,6 +372,36 @@ class SignatureTestCase(unittest.TestCase): r"^Cannot add members to a frozen signature$"): sig.members += {"b": Out(1)} + def assertFlattenedSignature(self, actual, expected): + for (a_path, a_member, a_value), (b_path, b_member, b_value) in zip(actual, expected): + self.assertEqual(a_path, b_path) + self.assertEqual(a_member, b_member) + self.assertIs(a_value, b_value) + + def test_flatten(self): + sig = Signature({"a": In(1), "b": Out(2).array(2)}) + intf = sig.create() + self.assertFlattenedSignature(sig.flatten(intf), [ + (("a",), In(1), intf.a), + ((("b",), 0), Out(2), intf.b[0]), + ((("b",), 1), Out(2), intf.b[1]) + ]) + + def test_flatten_sig(self): + sig = Signature({ + "a": Out(Signature({"p": Out(1)})), + "b": Out(Signature({"q": In (1)})), + "c": In( Signature({"r": Out(1)})), + "d": In( Signature({"s": In (1)})), + }) + intf = sig.create() + self.assertFlattenedSignature(sig.flatten(intf), [ + (("a", "p"), Out(1), intf.a.p), + (("b", "q"), In (1), intf.b.q), + (("c", "r"), Out(1), intf.c.r), + (("d", "s"), In (1), intf.d.s), + ]) + def assertNotCompliant(self, reason_regex, sig, obj): self.assertFalse(sig.is_compliant(obj)) reasons = []