diff --git a/amaranth/hdl/__init__.py b/amaranth/hdl/__init__.py index 7227775..130ccb6 100644 --- a/amaranth/hdl/__init__.py +++ b/amaranth/hdl/__init__.py @@ -9,6 +9,7 @@ from ._cd import DomainError, ClockDomain from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment from ._ir import Instance, IOBufferInstance from ._mem import FrozenMemory, MemoryData, MemoryInstance, Memory, ReadPort, WritePort, DummyPort +from ._nir import CombinationalCycle from ._rec import Record from ._xfrm import DomainRenamer, ResetInserter, EnableInserter @@ -28,6 +29,8 @@ __all__ = [ # _ir "UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment", "Instance", "IOBufferInstance", + # _nir + "CombinationalCycle", # _mem "FrozenMemory", "MemoryData", "MemoryInstance", "Memory", "ReadPort", "WritePort", "DummyPort", # _rec diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index cc77e4a..075991a 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -709,7 +709,7 @@ class NetlistEmitter: def emit_signal(self, signal) -> _nir.Value: if signal in self.netlist.signals: return self.netlist.signals[signal] - value = self.netlist.alloc_late_value(len(signal)) + value = self.netlist.alloc_late_value(signal) self.netlist.signals[signal] = value for bit, net in enumerate(value): self.late_net_to_signal[net] = (signal, bit) @@ -1738,6 +1738,7 @@ def build_netlist(fragment, ports=(), *, name="top", all_undef_to_ff=False, **kw design = fragment.prepare(ports=ports, hierarchy=(name,), **kwargs) netlist = _nir.Netlist() _emit_netlist(netlist, design, all_undef_to_ff=all_undef_to_ff) + netlist.check_comb_cycles() netlist.resolve_all_nets() _compute_net_flows(netlist) _compute_ports(netlist) diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py index 595e31f..9c9db82 100644 --- a/amaranth/hdl/_nir.py +++ b/amaranth/hdl/_nir.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, Any import enum from ._ast import SignalDict @@ -7,8 +7,9 @@ from . import _ast __all__ = [ # Netlist core + "CombinationalCycle", "Net", "Value", "IONet", "IOValue", - "FormatValue", "Format", + "FormatValue", "Format", "SignalField", "Netlist", "ModuleNetFlow", "IODirection", "Module", "Cell", "Top", # Computation cells "Operator", "Part", @@ -25,6 +26,10 @@ __all__ = [ ] +class CombinationalCycle(Exception): + pass + + class Net(int): __slots__ = () @@ -335,6 +340,7 @@ class Netlist: modules : list of ``Module`` cells : list of ``Cell`` connections : dict of (negative) int to int + late_to_signal : dict of (late) Net to its Signal and bit number io_ports : list of ``IOPort`` signals : dict of Signal to ``Value`` signal_fields: dict of Signal to dict of tuple[str | int] to SignalField @@ -344,6 +350,7 @@ class Netlist: self.modules: list[Module] = [] self.cells: list[Cell] = [Top()] self.connections: dict[Net, Net] = {} + self.late_to_signal: dict[Net, (_ast.Signal, int)] = {} self.io_ports: list[_ast.IOPort] = [] self.signals = SignalDict() self.signal_fields = SignalDict() @@ -405,9 +412,12 @@ class Netlist: cell_idx = self.add_cell(cell) return Value(Net.from_cell(cell_idx, bit) for bit in range(width)) - def alloc_late_value(self, width: int): - self.last_late_net -= width - return Value(Net.from_late(self.last_late_net + bit) for bit in range(width)) + def alloc_late_value(self, signal: _ast.Signal): + self.last_late_net -= len(signal) + value = Value(Net.from_late(self.last_late_net + bit) for bit in range(len(signal))) + for bit, net in enumerate(value): + self.late_to_signal[net] = signal, bit + return value @property def top(self): @@ -415,6 +425,62 @@ class Netlist: assert isinstance(top, Top) return top + def check_comb_cycles(self): + class Cycle: + def __init__(self, start): + self.start = start + self.path = [] + + checked = set() + busy = set() + + def traverse(net): + if net in checked: + return None + + if net in busy: + return Cycle(net) + busy.add(net) + + cycle = None + if net.is_const: + pass + elif net.is_late: + cycle = traverse(self.connections[net]) + if cycle is not None: + sig, bit = self.late_to_signal[net] + cycle.path.append((sig, bit, sig.src_loc)) + else: + for src, src_loc in self.cells[net.cell].comb_edges_to(net.bit): + cycle = traverse(src) + if cycle is not None: + cycle.path.append((self.cells[net.cell], net.bit, src_loc)) + break + + if cycle is not None and cycle.start == net: + msg = ["Combinational cycle detected, path:\n"] + for obj, bit, src_loc in reversed(cycle.path): + if isinstance(obj, _ast.Signal): + obj = f"signal {obj.name}" + elif isinstance(obj, Operator): + obj = f"operator {obj.operator}" + else: + obj = f"cell {obj.__class__.__name__}" + src_loc = ":0" if src_loc is None else f"{src_loc[0]}:{src_loc[1]}" + msg.append(f" {src_loc}: {obj} bit {bit}\n") + raise CombinationalCycle("".join(msg)) + + busy.remove(net) + checked.add(net) + return cycle + + for cell_idx, cell in enumerate(self.cells): + for net in cell.output_nets(cell_idx): + assert traverse(net) is None + for value in self.signals.values(): + for net in value: + assert traverse(net) is None + class ModuleNetFlow(enum.Enum): """Describes how a given Net flows into or out of a Module. @@ -509,6 +575,9 @@ class Cell: def resolve_nets(self, netlist: Netlist): raise NotImplementedError + def comb_edges_to(self, bit: int) -> "Iterable[(Net, Any)]": + raise NotImplementedError + class Top(Cell): """A special cell type representing top-level non-IO ports. Must be present in the netlist exactly @@ -558,6 +627,9 @@ class Top(Cell): ports = "".join(ports) return f"(top{ports})" + def comb_edges_to(self, bit): + return [] + class Operator(Cell): """Roughly corresponds to ``hdl.ast.Operator``. @@ -627,6 +699,28 @@ class Operator(Cell): inputs = " ".join(repr(input) for input in self.inputs) return f"({self.operator} {inputs})" + def comb_edges_to(self, bit): + if len(self.inputs) == 1: + if self.operator == "~": + yield (self.inputs[0][bit], self.src_loc) + else: + for net in self.inputs[0]: + yield (net, self.src_loc) + elif len(self.inputs) == 2: + if self.operator in ("&", "|", "^"): + yield (self.inputs[0][bit], self.src_loc) + yield (self.inputs[1][bit], self.src_loc) + else: + for net in self.inputs[0]: + yield (net, self.src_loc) + for net in self.inputs[1]: + yield (net, self.src_loc) + else: + assert self.operator == "m" + yield (self.inputs[0][0], self.src_loc) + yield (self.inputs[1][bit], self.src_loc) + yield (self.inputs[2][bit], self.src_loc) + class Part(Cell): """Corresponds to ``hdl.ast.Part``. @@ -666,6 +760,12 @@ class Part(Cell): value_signed = "signed" if self.value_signed else "unsigned" return f"(part {self.value} {value_signed} {self.offset} {self.width} {self.stride})" + def comb_edges_to(self, bit): + for net in self.value: + yield (net, self.src_loc) + for net in self.offset: + yield (net, self.src_loc) + class Matches(Cell): """A combinatorial cell performing a comparison like ``Value.matches`` @@ -698,6 +798,10 @@ class Matches(Cell): patterns = " ".join(self.patterns) return f"(matches {self.value} {patterns})" + def comb_edges_to(self, bit): + for net in self.value: + yield (net, self.src_loc) + class PriorityMatch(Cell): """Used to represent a single switch on the control plane of processes. @@ -733,6 +837,11 @@ class PriorityMatch(Cell): def __repr__(self): return f"(priority_match {self.en} {self.inputs})" + def comb_edges_to(self, bit): + yield (self.en, self.src_loc) + for net in self.inputs[:bit + 1]: + yield (net, self.src_loc) + class Assignment: """A single assignment in an ``AssignmentList``. @@ -809,6 +918,13 @@ class AssignmentList(Cell): assignments = " ".join(repr(assign) for assign in self.assignments) return f"(assignment_list {self.default} {assignments})" + def comb_edges_to(self, bit): + yield (self.default[bit], self.src_loc) + for assign in self.assignments: + yield (assign.cond, assign.src_loc) + if bit >= assign.start and bit < assign.start + len(assign.value): + yield (assign.value[bit - assign.start], assign.src_loc) + class FlipFlop(Cell): """A flip-flop. ``data`` is the data input. ``init`` is the initial and async reset value. @@ -853,6 +969,10 @@ class FlipFlop(Cell): attributes = "".join(f" (attr {key} {val!r})" for key, val in self.attributes.items()) return f"(flipflop {self.data} {self.init} {self.clk_edge} {self.clk} {self.arst}{attributes})" + def comb_edges_to(self, bit): + yield (self.clk, self.src_loc) + yield (self.arst, self.src_loc) + class Memory(Cell): """Corresponds to ``Memory``. ``init`` must have length equal to ``depth``. @@ -960,6 +1080,10 @@ class AsyncReadPort(Cell): def __repr__(self): return f"(read_port {self.memory} {self.width} {self.addr})" + def comb_edges_to(self, bit): + for net in self.addr: + yield (net, self.src_loc) + class SyncReadPort(Cell): """A single synchronous read port of a memory. The cell output is the data port. @@ -1004,6 +1128,9 @@ class SyncReadPort(Cell): transparent_for = " ".join(str(port) for port in self.transparent_for) return f"(read_port {self.memory} {self.width} {self.addr} {self.en} {self.clk_edge} {self.clk} ({transparent_for}))" + def comb_edges_to(self, bit): + return [] + class AsyncPrint(Cell): """Corresponds to ``Print`` in the "comb" domain. @@ -1087,6 +1214,9 @@ class Initial(Cell): def __repr__(self): return f"(initial)" + def comb_edges_to(self, bit): + return [] + class AnyValue(Cell): """Corresponds to ``AnyConst`` or ``AnySeq``. ``kind`` must be either ``'anyconst'`` @@ -1117,6 +1247,9 @@ class AnyValue(Cell): def __repr__(self): return f"({self.kind} {self.width})" + def comb_edges_to(self, bit): + return [] + class AsyncProperty(Cell): """Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain. @@ -1274,6 +1407,10 @@ class Instance(Cell): items = " ".join(items) return f"(instance {self.type!r} {self.name!r} {items})" + def comb_edges_to(self, bit): + # don't ask me, I'm a housecat + return [] + class IOBuffer(Cell): """An IO buffer cell. This cell does two things: @@ -1328,3 +1465,8 @@ class IOBuffer(Cell): return f"(iob {self.dir.value} {self.port})" else: return f"(iob {self.dir.value} {self.port} {self.o} {self.oe})" + + def comb_edges_to(self, bit): + if self.dir is not IODirection.Input: + yield (self.o[bit], self.src_loc) + yield (self.oe, self.src_loc) diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index 4795551..a3326e5 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -7,7 +7,7 @@ from amaranth.hdl._cd import * from amaranth.hdl._dsl import * from amaranth.hdl._ir import * from amaranth.hdl._mem import * -from amaranth.hdl._nir import SignalField +from amaranth.hdl._nir import SignalField, CombinationalCycle from amaranth.lib import enum, data @@ -3542,3 +3542,22 @@ class FieldsTestCase(FHDLTestCase): self.assertEqual(nl.signal_fields[s4], { (): SignalField(nl.signals[s4], signed=False), }) + + +class CycleTestCase(FHDLTestCase): + def test_cycle(self): + a = Signal() + b = Signal() + m = Module() + m.d.comb += [ + a.eq(~b), + b.eq(~a), + ] + with self.assertRaisesRegex(CombinationalCycle, + r"^Combinational cycle detected, path:\n" + r".*test_hdl_ir.py:\d+: operator ~ bit 0\n" + r".*test_hdl_ir.py:\d+: signal b bit 0\n" + r".*test_hdl_ir.py:\d+: operator ~ bit 0\n" + r".*test_hdl_ir.py:\d+: signal a bit 0\n" + r"$"): + build_netlist(Fragment.get(m, None), [])