diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 075991a..f9bec3d 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -71,6 +71,7 @@ class Fragment: self.src_loc = src_loc self.origins = None self.domains_propagated_up = {} + self.domain_renames = {} def add_domains(self, *domains): for domain in flatten(domains): @@ -655,6 +656,19 @@ class Design: subfragment_name = _add_name(frag_info.assigned_names, subfragment_name) self._assign_names(subfragment, hierarchy=(*hierarchy, subfragment_name)) + def lookup_domain(self, domain, context): + if domain == "comb": + raise KeyError("comb") + if context is not None: + try: + fragment = self.elaboratables[context] + except KeyError: + raise ValueError(f"Elaboratable {context!r} is not a part of the design") + else: + fragment = self.fragment + domain = fragment.domain_renames.get(domain, domain) + return fragment.domains[domain] + ############################################################################################### >:3 diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index c1d12a1..062cddc 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -252,6 +252,9 @@ class FragmentTransformer: for domain, statements in fragment.statements.items(): new_fragment.add_statements(domain, statements) + def map_domain_renames(self, fragment, new_fragment): + new_fragment.domain_renames = dict(fragment.domain_renames) + def map_memory_ports(self, fragment, new_fragment): if hasattr(self, "on_value"): for port in new_fragment._read_ports: @@ -318,6 +321,7 @@ class FragmentTransformer: self.map_subfragments(fragment, new_fragment) self.map_domains(fragment, new_fragment) self.map_statements(fragment, new_fragment) + self.map_domain_renames(fragment, new_fragment) return new_fragment def __call__(self, value, *, src_loc_at=0): @@ -513,6 +517,15 @@ class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer) map(self.on_statement, statements) ) + def map_domain_renames(self, fragment, new_fragment): + new_fragment.domain_renames = { + src: self.domain_map.get(dst, dst) + for src, dst in fragment.domain_renames.items() + } + for src, dst in self.domain_map.items(): + if src not in new_fragment.domain_renames: + new_fragment.domain_renames[src] = dst + def map_memory_ports(self, fragment, new_fragment): super().map_memory_ports(fragment, new_fragment) for port in new_fragment._read_ports: diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index a3326e5..34d5987 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -8,6 +8,7 @@ from amaranth.hdl._dsl import * from amaranth.hdl._ir import * from amaranth.hdl._mem import * from amaranth.hdl._nir import SignalField, CombinationalCycle +from amaranth.hdl._xfrm import * from amaranth.lib import enum, data @@ -3561,3 +3562,39 @@ class CycleTestCase(FHDLTestCase): r".*test_hdl_ir.py:\d+: signal a bit 0\n" r"$"): build_netlist(Fragment.get(m, None), []) + + +class DomainLookupTestCase(FHDLTestCase): + def test_domain_lookup(self): + m1 = Module() + m1_a = m1.domains.a = ClockDomain("a") + m1_b = m1.domains.b = ClockDomain("b") + m1_c = m1.domains.c = ClockDomain("c") + m2 = Module() + m3 = Module() + m3.d.sync += Print("m3") + m4 = Module() + m4.d.sync += Print("m4") + m4_d = m4.domains.d = ClockDomain("d") + m5 = Module() + m5.d.sync += Print("m5") + m5_d = m5.domains.d = ClockDomain("d") + + m1.submodules.m2 = xm2 = DomainRenamer({"a": "b"})(m2) + m2.submodules.m3 = xm3 = DomainRenamer("a")(m3) + m2.submodules.m4 = xm4 = DomainRenamer("b")(m4) + m2.submodules.m5 = xm5 = DomainRenamer("c")(m5) + + design = Fragment.get(m1, None).prepare() + + self.assertIs(design.lookup_domain("a", m1), m1_a) + self.assertIs(design.lookup_domain("b", m1), m1_b) + self.assertIs(design.lookup_domain("c", m1), m1_c) + self.assertIs(design.lookup_domain("a", xm2), m1_b) + self.assertIs(design.lookup_domain("b", xm2), m1_b) + self.assertIs(design.lookup_domain("c", xm2), m1_c) + self.assertIs(design.lookup_domain("sync", xm3), m1_b) + self.assertIs(design.lookup_domain("sync", xm4), m1_b) + self.assertIs(design.lookup_domain("sync", xm5), m1_c) + self.assertIs(design.lookup_domain("d", xm4), m4_d) + self.assertIs(design.lookup_domain("d", xm5), m5_d)