hdl._nir: add combinational cycle detection.

Fixes #704.
Fixes #1143.
This commit is contained in:
Wanda 2024-04-13 13:38:47 +02:00 committed by Catherine
parent 3fbed68365
commit 877a1062a6
4 changed files with 172 additions and 7 deletions

View file

@ -9,6 +9,7 @@ from ._cd import DomainError, ClockDomain
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
from ._ir import Instance, IOBufferInstance
from ._mem import FrozenMemory, MemoryData, MemoryInstance, Memory, ReadPort, WritePort, DummyPort
from ._nir import CombinationalCycle
from ._rec import Record
from ._xfrm import DomainRenamer, ResetInserter, EnableInserter
@ -28,6 +29,8 @@ __all__ = [
# _ir
"UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment",
"Instance", "IOBufferInstance",
# _nir
"CombinationalCycle",
# _mem
"FrozenMemory", "MemoryData", "MemoryInstance", "Memory", "ReadPort", "WritePort", "DummyPort",
# _rec

View file

@ -709,7 +709,7 @@ class NetlistEmitter:
def emit_signal(self, signal) -> _nir.Value:
if signal in self.netlist.signals:
return self.netlist.signals[signal]
value = self.netlist.alloc_late_value(len(signal))
value = self.netlist.alloc_late_value(signal)
self.netlist.signals[signal] = value
for bit, net in enumerate(value):
self.late_net_to_signal[net] = (signal, bit)
@ -1738,6 +1738,7 @@ def build_netlist(fragment, ports=(), *, name="top", all_undef_to_ff=False, **kw
design = fragment.prepare(ports=ports, hierarchy=(name,), **kwargs)
netlist = _nir.Netlist()
_emit_netlist(netlist, design, all_undef_to_ff=all_undef_to_ff)
netlist.check_comb_cycles()
netlist.resolve_all_nets()
_compute_net_flows(netlist)
_compute_ports(netlist)

View file

@ -1,4 +1,4 @@
from typing import Iterable
from typing import Iterable, Any
import enum
from ._ast import SignalDict
@ -7,8 +7,9 @@ from . import _ast
__all__ = [
# Netlist core
"CombinationalCycle",
"Net", "Value", "IONet", "IOValue",
"FormatValue", "Format",
"FormatValue", "Format", "SignalField",
"Netlist", "ModuleNetFlow", "IODirection", "Module", "Cell", "Top",
# Computation cells
"Operator", "Part",
@ -25,6 +26,10 @@ __all__ = [
]
class CombinationalCycle(Exception):
pass
class Net(int):
__slots__ = ()
@ -335,6 +340,7 @@ class Netlist:
modules : list of ``Module``
cells : list of ``Cell``
connections : dict of (negative) int to int
late_to_signal : dict of (late) Net to its Signal and bit number
io_ports : list of ``IOPort``
signals : dict of Signal to ``Value``
signal_fields: dict of Signal to dict of tuple[str | int] to SignalField
@ -344,6 +350,7 @@ class Netlist:
self.modules: list[Module] = []
self.cells: list[Cell] = [Top()]
self.connections: dict[Net, Net] = {}
self.late_to_signal: dict[Net, (_ast.Signal, int)] = {}
self.io_ports: list[_ast.IOPort] = []
self.signals = SignalDict()
self.signal_fields = SignalDict()
@ -405,9 +412,12 @@ class Netlist:
cell_idx = self.add_cell(cell)
return Value(Net.from_cell(cell_idx, bit) for bit in range(width))
def alloc_late_value(self, width: int):
self.last_late_net -= width
return Value(Net.from_late(self.last_late_net + bit) for bit in range(width))
def alloc_late_value(self, signal: _ast.Signal):
self.last_late_net -= len(signal)
value = Value(Net.from_late(self.last_late_net + bit) for bit in range(len(signal)))
for bit, net in enumerate(value):
self.late_to_signal[net] = signal, bit
return value
@property
def top(self):
@ -415,6 +425,62 @@ class Netlist:
assert isinstance(top, Top)
return top
def check_comb_cycles(self):
class Cycle:
def __init__(self, start):
self.start = start
self.path = []
checked = set()
busy = set()
def traverse(net):
if net in checked:
return None
if net in busy:
return Cycle(net)
busy.add(net)
cycle = None
if net.is_const:
pass
elif net.is_late:
cycle = traverse(self.connections[net])
if cycle is not None:
sig, bit = self.late_to_signal[net]
cycle.path.append((sig, bit, sig.src_loc))
else:
for src, src_loc in self.cells[net.cell].comb_edges_to(net.bit):
cycle = traverse(src)
if cycle is not None:
cycle.path.append((self.cells[net.cell], net.bit, src_loc))
break
if cycle is not None and cycle.start == net:
msg = ["Combinational cycle detected, path:\n"]
for obj, bit, src_loc in reversed(cycle.path):
if isinstance(obj, _ast.Signal):
obj = f"signal {obj.name}"
elif isinstance(obj, Operator):
obj = f"operator {obj.operator}"
else:
obj = f"cell {obj.__class__.__name__}"
src_loc = "<unknown>:0" if src_loc is None else f"{src_loc[0]}:{src_loc[1]}"
msg.append(f" {src_loc}: {obj} bit {bit}\n")
raise CombinationalCycle("".join(msg))
busy.remove(net)
checked.add(net)
return cycle
for cell_idx, cell in enumerate(self.cells):
for net in cell.output_nets(cell_idx):
assert traverse(net) is None
for value in self.signals.values():
for net in value:
assert traverse(net) is None
class ModuleNetFlow(enum.Enum):
"""Describes how a given Net flows into or out of a Module.
@ -509,6 +575,9 @@ class Cell:
def resolve_nets(self, netlist: Netlist):
raise NotImplementedError
def comb_edges_to(self, bit: int) -> "Iterable[(Net, Any)]":
raise NotImplementedError
class Top(Cell):
"""A special cell type representing top-level non-IO ports. Must be present in the netlist exactly
@ -558,6 +627,9 @@ class Top(Cell):
ports = "".join(ports)
return f"(top{ports})"
def comb_edges_to(self, bit):
return []
class Operator(Cell):
"""Roughly corresponds to ``hdl.ast.Operator``.
@ -627,6 +699,28 @@ class Operator(Cell):
inputs = " ".join(repr(input) for input in self.inputs)
return f"({self.operator} {inputs})"
def comb_edges_to(self, bit):
if len(self.inputs) == 1:
if self.operator == "~":
yield (self.inputs[0][bit], self.src_loc)
else:
for net in self.inputs[0]:
yield (net, self.src_loc)
elif len(self.inputs) == 2:
if self.operator in ("&", "|", "^"):
yield (self.inputs[0][bit], self.src_loc)
yield (self.inputs[1][bit], self.src_loc)
else:
for net in self.inputs[0]:
yield (net, self.src_loc)
for net in self.inputs[1]:
yield (net, self.src_loc)
else:
assert self.operator == "m"
yield (self.inputs[0][0], self.src_loc)
yield (self.inputs[1][bit], self.src_loc)
yield (self.inputs[2][bit], self.src_loc)
class Part(Cell):
"""Corresponds to ``hdl.ast.Part``.
@ -666,6 +760,12 @@ class Part(Cell):
value_signed = "signed" if self.value_signed else "unsigned"
return f"(part {self.value} {value_signed} {self.offset} {self.width} {self.stride})"
def comb_edges_to(self, bit):
for net in self.value:
yield (net, self.src_loc)
for net in self.offset:
yield (net, self.src_loc)
class Matches(Cell):
"""A combinatorial cell performing a comparison like ``Value.matches``
@ -698,6 +798,10 @@ class Matches(Cell):
patterns = " ".join(self.patterns)
return f"(matches {self.value} {patterns})"
def comb_edges_to(self, bit):
for net in self.value:
yield (net, self.src_loc)
class PriorityMatch(Cell):
"""Used to represent a single switch on the control plane of processes.
@ -733,6 +837,11 @@ class PriorityMatch(Cell):
def __repr__(self):
return f"(priority_match {self.en} {self.inputs})"
def comb_edges_to(self, bit):
yield (self.en, self.src_loc)
for net in self.inputs[:bit + 1]:
yield (net, self.src_loc)
class Assignment:
"""A single assignment in an ``AssignmentList``.
@ -809,6 +918,13 @@ class AssignmentList(Cell):
assignments = " ".join(repr(assign) for assign in self.assignments)
return f"(assignment_list {self.default} {assignments})"
def comb_edges_to(self, bit):
yield (self.default[bit], self.src_loc)
for assign in self.assignments:
yield (assign.cond, assign.src_loc)
if bit >= assign.start and bit < assign.start + len(assign.value):
yield (assign.value[bit - assign.start], assign.src_loc)
class FlipFlop(Cell):
"""A flip-flop. ``data`` is the data input. ``init`` is the initial and async reset value.
@ -853,6 +969,10 @@ class FlipFlop(Cell):
attributes = "".join(f" (attr {key} {val!r})" for key, val in self.attributes.items())
return f"(flipflop {self.data} {self.init} {self.clk_edge} {self.clk} {self.arst}{attributes})"
def comb_edges_to(self, bit):
yield (self.clk, self.src_loc)
yield (self.arst, self.src_loc)
class Memory(Cell):
"""Corresponds to ``Memory``. ``init`` must have length equal to ``depth``.
@ -960,6 +1080,10 @@ class AsyncReadPort(Cell):
def __repr__(self):
return f"(read_port {self.memory} {self.width} {self.addr})"
def comb_edges_to(self, bit):
for net in self.addr:
yield (net, self.src_loc)
class SyncReadPort(Cell):
"""A single synchronous read port of a memory. The cell output is the data port.
@ -1004,6 +1128,9 @@ class SyncReadPort(Cell):
transparent_for = " ".join(str(port) for port in self.transparent_for)
return f"(read_port {self.memory} {self.width} {self.addr} {self.en} {self.clk_edge} {self.clk} ({transparent_for}))"
def comb_edges_to(self, bit):
return []
class AsyncPrint(Cell):
"""Corresponds to ``Print`` in the "comb" domain.
@ -1087,6 +1214,9 @@ class Initial(Cell):
def __repr__(self):
return f"(initial)"
def comb_edges_to(self, bit):
return []
class AnyValue(Cell):
"""Corresponds to ``AnyConst`` or ``AnySeq``. ``kind`` must be either ``'anyconst'``
@ -1117,6 +1247,9 @@ class AnyValue(Cell):
def __repr__(self):
return f"({self.kind} {self.width})"
def comb_edges_to(self, bit):
return []
class AsyncProperty(Cell):
"""Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain.
@ -1274,6 +1407,10 @@ class Instance(Cell):
items = " ".join(items)
return f"(instance {self.type!r} {self.name!r} {items})"
def comb_edges_to(self, bit):
# don't ask me, I'm a housecat
return []
class IOBuffer(Cell):
"""An IO buffer cell. This cell does two things:
@ -1328,3 +1465,8 @@ class IOBuffer(Cell):
return f"(iob {self.dir.value} {self.port})"
else:
return f"(iob {self.dir.value} {self.port} {self.o} {self.oe})"
def comb_edges_to(self, bit):
if self.dir is not IODirection.Input:
yield (self.o[bit], self.src_loc)
yield (self.oe, self.src_loc)

View file

@ -7,7 +7,7 @@ from amaranth.hdl._cd import *
from amaranth.hdl._dsl import *
from amaranth.hdl._ir import *
from amaranth.hdl._mem import *
from amaranth.hdl._nir import SignalField
from amaranth.hdl._nir import SignalField, CombinationalCycle
from amaranth.lib import enum, data
@ -3542,3 +3542,22 @@ class FieldsTestCase(FHDLTestCase):
self.assertEqual(nl.signal_fields[s4], {
(): SignalField(nl.signals[s4], signed=False),
})
class CycleTestCase(FHDLTestCase):
def test_cycle(self):
a = Signal()
b = Signal()
m = Module()
m.d.comb += [
a.eq(~b),
b.eq(~a),
]
with self.assertRaisesRegex(CombinationalCycle,
r"^Combinational cycle detected, path:\n"
r".*test_hdl_ir.py:\d+: operator ~ bit 0\n"
r".*test_hdl_ir.py:\d+: signal b bit 0\n"
r".*test_hdl_ir.py:\d+: operator ~ bit 0\n"
r".*test_hdl_ir.py:\d+: signal a bit 0\n"
r"$"):
build_netlist(Fragment.get(m, None), [])