From 262e24b56455bb8256e84d07f576bb015f885dde Mon Sep 17 00:00:00 2001 From: Wanda Date: Wed, 3 Apr 2024 23:54:42 +0200 Subject: [PATCH] hdl._ir: Remove uses of `_[lr]hs_signals` and `_ioports`. --- amaranth/hdl/_ast.py | 13 ------ amaranth/hdl/_ir.py | 98 +++++++++++++++++++++++++++++++------------ tests/test_hdl_ast.py | 10 ----- 3 files changed, 71 insertions(+), 50 deletions(-) diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index a0d4ebe..6d91b70 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -3000,10 +3000,6 @@ class IOValue(metaclass=ABCMeta): else: raise TypeError(f"Cannot index IO value with {key!r}") - @abstractmethod - def _ioports(self): - raise NotImplementedError # :nocov: - @final class IOPort(IOValue): @@ -3035,9 +3031,6 @@ class IOPort(IOValue): def metadata(self): return self._metadata - def _ioports(self): - return {self} - def __repr__(self): return f"(io-port {self.name})" @@ -3059,9 +3052,6 @@ class IOConcat(IOValue): def metadata(self): return tuple(obj for part in self._parts for obj in part.metadata) - def _ioports(self): - return {port for part in self._parts for port in part._ioports()} - def __repr__(self): return "(io-cat {})".format(" ".join(map(repr, self.parts))) @@ -3115,9 +3105,6 @@ class IOSlice(IOValue): def metadata(self): return self._value.metadata[self.start:self.stop] - def _ioports(self): - return self.value._ioports() - def __repr__(self): return f"(io-slice {self.value!r} {self.start}:{self.stop})" diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 9f797ba..393ed46 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -498,50 +498,96 @@ class Design: if frag_info.parent is not None: self._use_io_port(frag_info.parent, port) + def _collect_used_signals_format(self, fragment: Fragment, fmt: _ast.Format): + for chunk in fmt._chunks: + if not isinstance(chunk, str): + obj, _spec = chunk + self._collect_used_signals_value(fragment, obj) + + def _collect_used_signals_value(self, fragment: Fragment, value: _ast.Value): + if isinstance(value, (_ast.Const, _ast.Initial, _ast.AnyValue)): + pass + elif isinstance(value, _ast.Signal): + self._use_signal(fragment, value) + elif isinstance(value, _ast.Operator): + for op in value.operands: + self._collect_used_signals_value(fragment, op) + elif isinstance(value, _ast.Slice): + self._collect_used_signals_value(fragment, value.value) + elif isinstance(value, _ast.Part): + self._collect_used_signals_value(fragment, value.value) + self._collect_used_signals_value(fragment, value.offset) + elif isinstance(value, _ast.SwitchValue): + self._collect_used_signals_value(fragment, value.test) + for _patterns, elem in value.cases: + self._collect_used_signals_value(fragment, elem) + elif isinstance(value, _ast.Concat): + for part in value.parts: + self._collect_used_signals_value(fragment, part) + else: + raise NotImplementedError # :nocov: + + def _collect_used_signals_io_value(self, fragment: Fragment, value: _ast.IOValue): + if isinstance(value, _ast.IOPort): + self._use_io_port(fragment, value) + elif isinstance(value, _ast.IOSlice): + self._collect_used_signals_io_value(fragment, value.value) + elif isinstance(value, _ast.IOConcat): + for part in value.parts: + self._collect_used_signals_io_value(fragment, part) + else: + raise NotImplementedError # :nocov: + + def _collect_used_signals_stmt(self, fragment: Fragment, stmt: _ast.Statement): + if isinstance(stmt, _ast.Assign): + self._collect_used_signals_value(fragment, stmt.lhs) + self._collect_used_signals_value(fragment, stmt.rhs) + elif isinstance(stmt, _ast.Print): + self._collect_used_signals_format(fragment, stmt.message) + elif isinstance(stmt, _ast.Property): + self._collect_used_signals_value(fragment, stmt.test) + if stmt.message is not None: + self._collect_used_signals_format(fragment, stmt.message) + elif isinstance(stmt, _ast.Switch): + self._collect_used_signals_value(fragment, stmt.test) + for _patterns, stmts, _src_loc in stmt.cases: + for s in stmts: + self._collect_used_signals_stmt(fragment, s) + else: + raise NotImplementedError # :nocov: + def _collect_used_signals(self, fragment: Fragment): """Collects used signals and IO ports for a fragment and all its subfragments.""" from . import _mem if isinstance(fragment, _ir.Instance): for conn, kind in fragment.ports.values(): if isinstance(conn, _ast.IOValue): - for port in conn._ioports(): - self._use_io_port(fragment, port) + self._collect_used_signals_io_value(fragment, conn) elif isinstance(conn, _ast.Value): - for signal in conn._rhs_signals(): - self._use_signal(fragment, signal) + self._collect_used_signals_value(fragment, conn) else: assert False # :nocov: elif isinstance(fragment, _ir.IOBufferInstance): - for port in fragment.port._ioports(): - self._use_io_port(fragment, port) + self._collect_used_signals_io_value(fragment, fragment.port) if fragment.i is not None: - for signal in fragment.i._rhs_signals(): - self._use_signal(fragment, signal) + self._collect_used_signals_value(fragment, fragment.i) if fragment.o is not None: - for signal in fragment.o._rhs_signals(): - self._use_signal(fragment, signal) - for signal in fragment.oe._rhs_signals(): - self._use_signal(fragment, signal) + self._collect_used_signals_value(fragment, fragment.o) + self._collect_used_signals_value(fragment, fragment.oe) elif isinstance(fragment, _mem.MemoryInstance): for port in fragment._read_ports: - for signal in port._addr._rhs_signals(): - self._use_signal(fragment, signal) - for signal in port._data._rhs_signals(): - self._use_signal(fragment, signal) - for signal in port._en._rhs_signals(): - self._use_signal(fragment, signal) + self._collect_used_signals_value(fragment, port._addr) + self._collect_used_signals_value(fragment, port._data) + self._collect_used_signals_value(fragment, port._en) if port._domain != "comb": domain = fragment.domains[port._domain] self._use_signal(fragment, domain.clk) if domain.rst is not None: self._use_signal(fragment, domain.rst) for port in fragment._write_ports: - for signal in port._addr._rhs_signals(): - self._use_signal(fragment, signal) - for signal in port._data._rhs_signals(): - self._use_signal(fragment, signal) - for signal in port._en._rhs_signals(): - self._use_signal(fragment, signal) + self._collect_used_signals_value(fragment, port._addr) + self._collect_used_signals_value(fragment, port._data) + self._collect_used_signals_value(fragment, port._en) domain = fragment.domains[port._domain] self._use_signal(fragment, domain.clk) if domain.rst is not None: @@ -554,9 +600,7 @@ class Design: if domain.rst is not None: self._use_signal(fragment, domain.rst) for statement in statements: - for signal in statement._lhs_signals() | statement._rhs_signals(): - if not isinstance(signal, (_ast.ClockSignal, _ast.ResetSignal)): - self._use_signal(fragment, signal) + self._collect_used_signals_stmt(fragment, statement) for subfragment, _name, _src_loc in fragment.subfragments: self._collect_used_signals(subfragment) diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 1e9737d..2b6c3cb 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -1821,13 +1821,11 @@ class IOValueTestCase(FHDLTestCase): self.assertEqual(len(a), 4) self.assertEqual(a.attrs, {}) self.assertEqual(a.metadata, (None, None, None, None)) - self.assertEqual(a._ioports(), {a}) self.assertRepr(a, "(io-port a)") b = IOPort(3, name="b", attrs={"a": "b"}, metadata=["x", "y", "z"]) self.assertEqual(len(b), 3) self.assertEqual(b.attrs, {"a": "b"}) self.assertEqual(b.metadata, ("x", "y", "z")) - self.assertEqual(b._ioports(), {b}) self.assertRepr(b, "(io-port b)") def test_ioport_wrong(self): @@ -1849,32 +1847,26 @@ class IOValueTestCase(FHDLTestCase): s = a[2:5] self.assertEqual(len(s), 3) self.assertEqual(s.metadata, ("c", "d", "e")) - self.assertEqual(s._ioports(), {a}) self.assertRepr(s, "(io-slice (io-port a) 2:5)") s = a[-5:-2] self.assertEqual(len(s), 3) self.assertEqual(s.metadata, ("d", "e", "f")) - self.assertEqual(s._ioports(), {a}) self.assertRepr(s, "(io-slice (io-port a) 3:6)") s = IOSlice(a, -5, -2) self.assertEqual(len(s), 3) self.assertEqual(s.metadata, ("d", "e", "f")) - self.assertEqual(s._ioports(), {a}) self.assertRepr(s, "(io-slice (io-port a) 3:6)") s = a[5] self.assertEqual(len(s), 1) self.assertEqual(s.metadata, ("f",)) - self.assertEqual(s._ioports(), {a}) self.assertRepr(s, "(io-slice (io-port a) 5:6)") s = a[-1] self.assertEqual(len(s), 1) self.assertEqual(s.metadata, ("h",)) - self.assertEqual(s._ioports(), {a}) self.assertRepr(s, "(io-slice (io-port a) 7:8)") s = a[::2] self.assertEqual(len(s), 4) self.assertEqual(s.metadata, ("a", "c", "e", "g")) - self.assertEqual(s._ioports(), {a}) self.assertRepr(s, "(io-cat (io-slice (io-port a) 0:1) (io-slice (io-port a) 2:3) (io-slice (io-port a) 4:5) (io-slice (io-port a) 6:7))") def test_ioslice_wrong(self): @@ -1902,12 +1894,10 @@ class IOValueTestCase(FHDLTestCase): c = Cat(a, b) self.assertEqual(len(c), 5) self.assertEqual(c.metadata, ("a", "b", "c", "x", "y")) - self.assertEqual(c._ioports(), {a, b}) self.assertRepr(c, "(io-cat (io-port a) (io-port b))") c = Cat(a, Cat()) self.assertEqual(len(c), 3) self.assertEqual(c.metadata, ("a", "b", "c")) - self.assertEqual(c._ioports(), {a}) self.assertRepr(c, "(io-cat (io-port a) (io-cat ))") c = Cat(a, Cat()[:]) self.assertEqual(len(c), 3)