From 6f44438e585dd54a89c0112732710b389e25a71b Mon Sep 17 00:00:00 2001 From: Catherine Date: Mon, 21 Aug 2023 05:23:15 +0000 Subject: [PATCH] hdl._ir,hdl._nir,back.rtlil: new intermediate representation. The new intermediate representation will enable global analyses on Amaranth code without lowering it to another representation such as RTLIL. This commit also changes the RTLIL builder to use the new IR. Co-authored-by: Wanda --- amaranth/back/rtlil.py | 1495 +++++++++++++++++++--------------------- amaranth/hdl/_ast.py | 12 +- amaranth/hdl/_ir.py | 836 ++++++++++++++++++++-- amaranth/hdl/_nir.py | 1003 +++++++++++++++++++++++++++ amaranth/hdl/_xfrm.py | 107 +-- amaranth/hdl/xfrm.py | 1 - tests/test_hdl_xfrm.py | 130 ---- 7 files changed, 2536 insertions(+), 1048 deletions(-) create mode 100644 amaranth/hdl/_nir.py diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 0879453..a442238 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -1,13 +1,9 @@ +from typing import Iterable import io -from collections import OrderedDict -from contextlib import contextmanager -import warnings -import re -from .._utils import flatten from ..utils import bits_for -from ..hdl import _ast, _ir, _mem, _xfrm, _repr from ..lib import wiring +from ..hdl import _repr, _ast, _ir, _nir __all__ = ["convert", "convert_fragment"] @@ -89,11 +85,6 @@ class _BufferedBuilder: self._buffer.write(fmt.format(*args, **kwargs)) -class _ProxiedBuilder: - def _append(self, *args, **kwargs): - self.rtlil._append(*args, **kwargs) - - class _AttrBuilder: def __init__(self, emit_src, *args, **kwargs): super().__init__(*args, **kwargs) @@ -183,10 +174,7 @@ class _ModuleBuilder(_AttrBuilder, _BufferedBuilder, _Namer): self._append(" parameter \\{} {}\n", param, _const(value)) for port, wire in ports.items(): - # By convention, Yosys ports named $\d+ are positional. Amaranth does not support - # connecting cell ports by position. See amaranth-lang/amaranth#733. - assert not re.match(r"^\$\d+$", port) - self._append(" connect {} {}\n", port, wire) + self._append(" connect \\{} {}\n", port, wire) self._append(" end\n") return name @@ -216,11 +204,14 @@ class _ProcessBuilder(_AttrBuilder, _BufferedBuilder): return _CaseBuilder(self, indent=2) -class _CaseBuilder(_ProxiedBuilder): +class _CaseBuilder: def __init__(self, rtlil, indent): self.rtlil = rtlil self.indent = indent + def _append(self, *args, **kwargs): + self.rtlil._append(*args, **kwargs) + def __enter__(self): return self @@ -234,7 +225,7 @@ class _CaseBuilder(_ProxiedBuilder): return _SwitchBuilder(self.rtlil, cond, attrs, src, self.indent) -class _SwitchBuilder(_AttrBuilder, _ProxiedBuilder): +class _SwitchBuilder(_AttrBuilder): def __init__(self, rtlil, cond, attrs, src, indent): super().__init__(emit_src=rtlil.emit_src) self.rtlil = rtlil @@ -243,6 +234,9 @@ class _SwitchBuilder(_AttrBuilder, _ProxiedBuilder): self.src = src self.indent = indent + def _append(self, *args, **kwargs): + self.rtlil._append(*args, **kwargs) + def __enter__(self): self._attributes(self.attrs, src=self.src, indent=self.indent) self._append("{}switch {}\n", " " * self.indent, self.cond) @@ -268,777 +262,738 @@ def _src(src_loc): return f"{file}:{line}" -class _LegalizeValue(Exception): - def __init__(self, value, branches, src_loc): - self.value = value - self.branches = list(branches) - self.src_loc = src_loc - - -class _ValueCompilerState: - def __init__(self, rtlil): - self.rtlil = rtlil - self.wires = _ast.SignalDict() - self.driven = _ast.SignalDict() - self.ports = _ast.SignalDict() - self.anys = _ast.ValueDict() - - self.expansions = _ast.ValueDict() - - def add_driven(self, signal, sync): - self.driven[signal] = sync - - def add_port(self, signal, kind): - assert kind in ("i", "o", "io") - if kind == "i": - kind = "input" - elif kind == "o": - kind = "output" - elif kind == "io": - kind = "inout" - self.ports[signal] = (len(self.ports), kind) - - def resolve(self, signal, prefix=None): - if len(signal) == 0: - return "{ }", "{ }" - - if signal in self.wires: - return self.wires[signal] - - if signal in self.ports: - port_id, port_kind = self.ports[signal] - else: - port_id = port_kind = None - if prefix is not None: - wire_name = f"{prefix}_{signal.name}" - else: - wire_name = signal.name - - is_sync_driven = signal in self.driven and self.driven[signal] - - attrs = dict(signal.attrs) - for repr in signal._value_repr: - if repr.path == () and isinstance(repr.format, _repr.FormatEnum): - enum = repr.format.enum - attrs["enum_base_type"] = enum.__name__ - for value in enum: - attrs["enum_value_{:0{}b}".format(value.value, signal.width)] = value.name - - # For every signal in the sync domain, assign \sig's initial value (using the \init reg - # attribute) to the reset value. - if is_sync_driven: - attrs["init"] = _ast.Const(signal.reset, signal.width) - - wire_curr = self.rtlil.wire(width=signal.width, name=wire_name, - port_id=port_id, port_kind=port_kind, - attrs=attrs, src=_src(signal.src_loc)) - if is_sync_driven: - wire_next = self.rtlil.wire(width=signal.width, name=wire_curr + "$next", - src=_src(signal.src_loc)) - else: - wire_next = None - self.wires[signal] = (wire_curr, wire_next) - - return wire_curr, wire_next - - def resolve_curr(self, signal, prefix=None): - wire_curr, wire_next = self.resolve(signal, prefix) - return wire_curr - - def expand(self, value): - if not self.expansions: - return value - return self.expansions.get(value, value) - - @contextmanager - def expand_to(self, value, expansion): - try: - assert value not in self.expansions - self.expansions[value] = expansion - yield - finally: - del self.expansions[value] - - -class _ValueCompiler(_xfrm.ValueVisitor): - def __init__(self, state): - self.s = state - - def on_unknown(self, value): - if value is None: - return None - else: - super().on_unknown(value) - - def on_ClockSignal(self, value): - raise NotImplementedError # :nocov: - - def on_ResetSignal(self, value): - raise NotImplementedError # :nocov: - - def on_Cat(self, value): - return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.parts]))) - - def _prepare_value_for_Slice(self, value): - raise NotImplementedError # :nocov: - - def on_Slice(self, value): - if value.start == 0 and value.stop == len(value.value): - return self(value.value) - - sigspec = self._prepare_value_for_Slice(value.value) - if value.start == value.stop: - return "{}" - elif value.start + 1 == value.stop: - return f"{sigspec} [{value.start}]" - else: - return f"{sigspec} [{value.stop - 1}:{value.start}]" - - def on_ArrayProxy(self, value): - index = self.s.expand(value.index) - if isinstance(index, _ast.Const): - if index.value < len(value.elems): - elem = value.elems[index.value] - else: - elem = value.elems[-1] - return self.match_shape(elem, value.shape()) - else: - max_index = 1 << len(value.index) - max_elem = len(value.elems) - raise _LegalizeValue(value.index, range(min(max_index, max_elem)), value.src_loc) - - -class _RHSValueCompiler(_ValueCompiler): - operator_map = { - (1, "~"): "$not", - (1, "-"): "$neg", - (1, "b"): "$reduce_bool", - (1, "r|"): "$reduce_or", - (1, "r&"): "$reduce_and", - (1, "r^"): "$reduce_xor", - (2, "+"): "$add", - (2, "-"): "$sub", - (2, "*"): "$mul", - (2, "//"): "$divfloor", - (2, "%"): "$modfloor", - (2, "**"): "$pow", - (2, "<<"): "$sshl", - (2, ">>"): "$sshr", - (2, "&"): "$and", - (2, "^"): "$xor", - (2, "|"): "$or", - (2, "=="): "$eq", - (2, "!="): "$ne", - (2, "<"): "$lt", - (2, "<="): "$le", - (2, ">"): "$gt", - (2, ">="): "$ge", - (3, "m"): "$mux", - } - - def on_value(self, value): - return super().on_value(self.s.expand(value)) - - def on_Const(self, value): - return _const(value) - - def on_AnyValue(self, value): - if value in self.s.anys: - return self.s.anys[value] - - res_shape = value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$" + value.kind.value, ports={ - "\\Y": res, - }, params={ - "WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - self.s.anys[value] = res - return res - - def on_Initial(self, value): - res = self.s.rtlil.wire(width=1, src=_src(value.src_loc)) - self.s.rtlil.cell("$initstate", ports={ - "\\Y": res, - }, src=_src(value.src_loc)) - return res - - def on_Signal(self, value): - wire_curr, wire_next = self.s.resolve(value) - return wire_curr - - def on_Operator_unary(self, value): - arg, = value.operands - if value.operator in ("u", "s"): - # These operators don't change the bit pattern, only its interpretation. - return self(arg) - - arg_shape, res_shape = arg.shape(), value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell(self.operator_map[(1, value.operator)], ports={ - "\\A": self(arg), - "\\Y": res, - }, params={ - "A_SIGNED": arg_shape.signed, - "A_WIDTH": arg_shape.width, - "Y_WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - return res - - def match_shape(self, value, new_shape): - if isinstance(value, _ast.Const): - return self(_ast.Const(value.value, new_shape)) - - value_shape = value.shape() - if new_shape.width <= value_shape.width: - return self(_ast.Slice(value, 0, new_shape.width)) - - res = self.s.rtlil.wire(width=new_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$pos", ports={ - "\\A": self(value), - "\\Y": res, - }, params={ - "A_SIGNED": value_shape.signed, - "A_WIDTH": value_shape.width, - "Y_WIDTH": new_shape.width, - }, src=_src(value.src_loc)) - return res - - def on_Operator_binary(self, value): - lhs, rhs = value.operands - lhs_shape, rhs_shape, res_shape = lhs.shape(), rhs.shape(), value.shape() - if lhs_shape.signed == rhs_shape.signed or value.operator in ("<<", ">>", "**"): - lhs_wire = self(lhs) - rhs_wire = self(rhs) - else: - lhs_shape = rhs_shape = _ast.signed(max(lhs_shape.width + rhs_shape.signed, - rhs_shape.width + lhs_shape.signed)) - lhs_wire = self.match_shape(lhs, lhs_shape) - rhs_wire = self.match_shape(rhs, rhs_shape) - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell(self.operator_map[(2, value.operator)], ports={ - "\\A": lhs_wire, - "\\B": rhs_wire, - "\\Y": res, - }, params={ - "A_SIGNED": lhs_shape.signed, - "A_WIDTH": lhs_shape.width, - "B_SIGNED": rhs_shape.signed, - "B_WIDTH": rhs_shape.width, - "Y_WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - if value.operator in ("//", "%"): - # RTLIL leaves division by zero undefined, but we require it to return zero. - divmod_res = res - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$mux", ports={ - "\\A": divmod_res, - "\\B": self(_ast.Const(0, res_shape)), - "\\S": self(rhs == 0), - "\\Y": res, - }, params={ - "WIDTH": res_shape.width - }, src=_src(value.src_loc)) - return res - - def on_Operator_mux(self, value): - sel, val1, val0 = value.operands - if len(sel) != 1: - sel = sel.bool() - res_shape = value.shape() - val1_wire = self.match_shape(val1, res_shape) - val0_wire = self.match_shape(val0, res_shape) - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$mux", ports={ - "\\A": val0_wire, - "\\B": val1_wire, - "\\S": self(sel), - "\\Y": res, - }, params={ - "WIDTH": res_shape.width - }, src=_src(value.src_loc)) - return res - - def on_Operator(self, value): - if len(value.operands) == 1: - return self.on_Operator_unary(value) - elif len(value.operands) == 2: - return self.on_Operator_binary(value) - elif len(value.operands) == 3: - assert value.operator == "m" - return self.on_Operator_mux(value) - else: - raise TypeError # :nocov: - - def _prepare_value_for_Slice(self, value): - if isinstance(value, (_ast.Signal, _ast.Slice, _ast.Cat)): - sigspec = self(value) - else: - sigspec = self.s.rtlil.wire(len(value), src=_src(value.src_loc)) - self.s.rtlil.connect(sigspec, self(value)) - return sigspec - - def on_Part(self, value): - lhs, rhs = value.value, value.offset - if value.stride != 1: - rhs *= value.stride - lhs_shape, rhs_shape, res_shape = lhs.shape(), rhs.shape(), value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - # Note: Verilog's x[o+:w] construct produces a $shiftx cell, not a $shift cell. - # However, Amaranth's semantics defines the out-of-range bits to be zero, so it is correct - # to use a $shift cell here instead, even though it produces less idiomatic Verilog. - self.s.rtlil.cell("$shift", ports={ - "\\A": self(lhs), - "\\B": self(rhs), - "\\Y": res, - }, params={ - "A_SIGNED": lhs_shape.signed, - "A_WIDTH": lhs_shape.width, - "B_SIGNED": rhs_shape.signed, - "B_WIDTH": rhs_shape.width, - "Y_WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - return res - - -class _LHSValueCompiler(_ValueCompiler): - def on_Const(self, value): - raise TypeError # :nocov: - - def on_AnyValue(self, value): - raise TypeError # :nocov: - - def on_Initial(self, value): - raise TypeError # :nocov: - - def on_Operator(self, value): - if value.operator in ("u", "s"): - # These operators are transparent on the LHS. - arg, = value.operands - return self(arg) - - raise TypeError # :nocov: - - def match_shape(self, value, new_shape): - value_shape = value.shape() - if new_shape.width == value_shape.width: - return self(value) - elif new_shape.width < value_shape.width: - return self(_ast.Slice(value, 0, new_shape.width)) - else: # new_shape.width > value_shape.width - dummy_bits = new_shape.width - value_shape.width - dummy_wire = self.s.rtlil.wire(dummy_bits) - return f"{{ {dummy_wire} {self(value)} }}" - - def on_Signal(self, value): - if value not in self.s.driven: - raise ValueError(f"No LHS wire for non-driven signal {value!r}") - wire_curr, wire_next = self.s.resolve(value) - return wire_next or wire_curr - - def _prepare_value_for_Slice(self, value): - assert isinstance(value, (_ast.Signal, _ast.Slice, _ast.Cat, _ast.Part)) - return self(value) - - def on_Part(self, value): - offset = self.s.expand(value.offset) - if isinstance(offset, _ast.Const): - start = offset.value * value.stride - stop = start + value.width - slice = self(_ast.Slice(value.value, start, min(len(value.value), stop))) - if len(value.value) >= stop: - return slice - else: - dummy_wire = self.s.rtlil.wire(stop - len(value.value)) - return f"{{ {dummy_wire} {slice} }}" - else: - # Only so many possible parts. The amount of branches is exponential; if value.offset - # is large (e.g. 32-bit wide), trying to naively legalize it is likely to exhaust - # system resources. - max_branches = len(value.value) // value.stride + 1 - raise _LegalizeValue(value.offset, - range(1 << len(value.offset))[:max_branches], - value.src_loc) - - -class _StatementCompiler(_xfrm.StatementVisitor): - def __init__(self, state, rhs_compiler, lhs_compiler): - self.state = state - self.rhs_compiler = rhs_compiler - self.lhs_compiler = lhs_compiler - - self._domain = None - self._case = None - self._test_cache = {} - self._has_rhs = False - self._wrap_assign = False - - @contextmanager - def case(self, switch, values, attrs={}, src=""): - try: - old_case = self._case - with switch.case(*values, attrs=attrs, src=src) as self._case: - yield - finally: - self._case = old_case - - def _check_rhs(self, value): - if self._has_rhs or next(iter(value._rhs_signals()), None) is not None: - self._has_rhs = True - - def on_Assign(self, stmt): - self._check_rhs(stmt.rhs) - - lhs_shape, rhs_shape = stmt.lhs.shape(), stmt.rhs.shape() - if lhs_shape.width == rhs_shape.width: - rhs_sigspec = self.rhs_compiler(stmt.rhs) - else: - # In RTLIL, LHS and RHS of assignment must have exactly same width. - rhs_sigspec = self.rhs_compiler.match_shape(stmt.rhs, lhs_shape) - if self._wrap_assign: - # In RTLIL, all assigns are logically sequenced before all switches, even if they are - # interleaved in the source. In Amaranth, the source ordering is used. To handle this - # mismatch, we wrap all assigns following a switch in a dummy switch. - with self._case.switch("{ }") as wrap_switch: - with wrap_switch.case() as wrap_case: - wrap_case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec) - else: - self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec) - - def on_Property(self, stmt): - self(stmt._check.eq(stmt.test)) - self(stmt._en.eq(1)) - - en_wire = self.rhs_compiler(stmt._en) - check_wire = self.rhs_compiler(stmt._check) - self.state.rtlil.cell("$" + stmt.kind.value, ports={ - "\\A": check_wire, - "\\EN": en_wire, - }, src=_src(stmt.src_loc), name=stmt.name) - - def on_Switch(self, stmt): - self._check_rhs(stmt.test) - - if not self.state.expansions: - # We repeatedly translate the same switches over and over (see the LHSGroupAnalyzer - # related code below), and translating the switch test only once helps readability. - if stmt not in self._test_cache: - self._test_cache[stmt] = self.rhs_compiler(stmt.test) - test_sigspec = self._test_cache[stmt] - else: - # However, if the switch test contains an illegal value, then it may not be cached - # (since the illegal value will be repeatedly replaced with different constants), so - # don't cache anything in that case. - test_sigspec = self.rhs_compiler(stmt.test) - - with self._case.switch(test_sigspec, src=_src(stmt.src_loc)) as switch: - for values, stmts in stmt.cases.items(): - case_attrs = {} - case_src = None - if values in stmt.case_src_locs: - case_src = _src(stmt.case_src_locs[values]) - if isinstance(stmt.test, _ast.Signal) and stmt.test.decoder: - decoded_values = [] - for value in values: - if "-" in value: - decoded_values.append("") - else: - decoded_values.append(stmt.test.decoder(int(value, 2))) - case_attrs["amaranth.decoding"] = "|".join(decoded_values) - with self.case(switch, values, attrs=case_attrs, src=case_src): - self._wrap_assign = False - self.on_statements(stmts) - self._wrap_assign = True - - def on_statement(self, stmt): - try: - super().on_statement(stmt) - except _LegalizeValue as legalize: - with self._case.switch(self.rhs_compiler(legalize.value), - src=_src(legalize.src_loc)) as switch: - shape = legalize.value.shape() - tests = ["{:0{}b}".format(v, shape.width) for v in legalize.branches] - if tests: - tests[-1] = "-" * shape.width - for branch, test in zip(legalize.branches, tests): - with self.case(switch, (test,)): - self._wrap_assign = False - branch_value = _ast.Const(branch, shape) - with self.state.expand_to(legalize.value, branch_value): - self.on_statement(stmt) - self._wrap_assign = True - - def on_statements(self, stmts): - for stmt in stmts: - self.on_statement(stmt) - - -def _convert_fragment(builder, fragment, name_map, hierarchy): - if isinstance(fragment, _ir.Instance): - port_map = OrderedDict() - for port_name, (value, dir) in fragment.named_ports.items(): - port_map[f"\\{port_name}"] = value - - params = OrderedDict(fragment.parameters) - - if fragment.type[0] == "$": - return fragment.type, port_map, params - else: - return f"\\{fragment.type}", port_map, params - - if isinstance(fragment, _mem.MemoryInstance): - init = "".join(format(_ast.Const(elem, _ast.unsigned(fragment._width)).value, f"0{fragment._width}b") for elem in reversed(fragment._init)) - init = _ast.Const(int(init or "0", 2), fragment._depth * fragment._width) - rd_clk = [] - rd_clk_enable = 0 - rd_clk_polarity = 0 - rd_transparency_mask = 0 - for index, port in enumerate(fragment._read_ports): - if port._domain != "comb": - cd = fragment.domains[port._domain] - rd_clk.append(cd.clk) - if cd.clk_edge == "pos": - rd_clk_polarity |= 1 << index - rd_clk_enable |= 1 << index - for write_index in port._transparency: - rd_transparency_mask |= 1 << (index * len(fragment._write_ports) + write_index) - else: - rd_clk.append(_ast.Const(0, 1)) - wr_clk = [] - wr_clk_enable = 0 - wr_clk_polarity = 0 - for index, port in enumerate(fragment._write_ports): - cd = fragment.domains[port._domain] - wr_clk.append(cd.clk) - wr_clk_enable |= 1 << index - if cd.clk_edge == "pos": - wr_clk_polarity |= 1 << index - params = { - "MEMID": builder._make_name(hierarchy[-1], local=False), - "SIZE": fragment._depth, - "OFFSET": 0, - "ABITS": _ast.Shape.cast(range(fragment._depth)).width, - "WIDTH": fragment._width, - "INIT": init, - "RD_PORTS": len(fragment._read_ports), - "RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment._read_ports))), - "RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment._read_ports))), - "RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment._read_ports) * len(fragment._write_ports))), - "RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment._read_ports) * len(fragment._write_ports))), - "RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._read_ports))), - "RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment._read_ports))), - "RD_ARST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), - "RD_SRST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), - "RD_INIT_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), - "WR_PORTS": len(fragment._write_ports), - "WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment._write_ports))), - "WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment._write_ports))), - "WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment._write_ports) * len(fragment._write_ports))), - "WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._write_ports))), - } - def make_en(port): - if len(port._data) == 0: - return _ast.Const(0, 0) - granularity = len(port._data) // len(port._en) - return _ast.Cat(en_bit.replicate(granularity) for en_bit in port._en) - port_map = { - "\\RD_CLK": _ast.Cat(rd_clk), - "\\RD_EN": _ast.Cat(port._en for port in fragment._read_ports), - "\\RD_ARST": _ast.Const(0, len(fragment._read_ports)), - "\\RD_SRST": _ast.Const(0, len(fragment._read_ports)), - "\\RD_ADDR": _ast.Cat(port._addr for port in fragment._read_ports), - "\\RD_DATA": _ast.Cat(port._data for port in fragment._read_ports), - "\\WR_CLK": _ast.Cat(wr_clk), - "\\WR_EN": _ast.Cat(make_en(port) for port in fragment._write_ports), - "\\WR_ADDR": _ast.Cat(port._addr for port in fragment._write_ports), - "\\WR_DATA": _ast.Cat(port._data for port in fragment._write_ports), - } - return "$mem_v2", port_map, params - - module_name = ".".join(name or "anonymous" for name in hierarchy) - module_attrs = OrderedDict() - if len(hierarchy) == 1: - module_attrs["top"] = 1 - - with builder.module(module_name, attrs=module_attrs) as module: - compiler_state = _ValueCompilerState(module) - rhs_compiler = _RHSValueCompiler(compiler_state) - lhs_compiler = _LHSValueCompiler(compiler_state) - stmt_compiler = _StatementCompiler(compiler_state, rhs_compiler, lhs_compiler) - - # Register all signals driven in the current fragment. This must be done first, as it - # affects further codegen; e.g. whether \sig$next signals will be generated and used. - for domain, statements in fragment.statements.items(): - for signal in statements._lhs_signals(): - compiler_state.add_driven(signal, sync=domain != "comb") - - # Transform all signals used as ports in the current fragment eagerly and outside of - # any hierarchy, to make sure they get sensible (non-prefixed) names. - for signal in fragment.ports: - compiler_state.add_port(signal, fragment.ports[signal]) - compiler_state.resolve_curr(signal) - - # Transform all clocks clocks and resets eagerly and outside of any hierarchy, to make - # sure they get sensible (non-prefixed) names. This does not affect semantics. - for domain, _ in fragment.iter_sync(): - cd = fragment.domains[domain] - compiler_state.resolve_curr(cd.clk) - if cd.rst is not None: - compiler_state.resolve_curr(cd.rst) - - # Transform all subfragments to their respective cells. Transforming signals connected - # to their ports into wires eagerly makes sure they get sensible (prefixed with submodule - # name) names. - for subfragment, sub_name in fragment.subfragments: - if not (subfragment.ports or subfragment.statements or subfragment.subfragments or - isinstance(subfragment, (_ir.Instance, _mem.MemoryInstance))): - # If the fragment is completely empty, skip translating it, otherwise synthesis - # tools (including Yosys and Vivado) will treat it as a black box when it is - # loaded after conversion to Verilog. - continue - - if sub_name is None: - sub_name = module.anonymous() - - sub_type, sub_port_map, sub_params = \ - _convert_fragment(builder, subfragment, name_map, - hierarchy=hierarchy + (sub_name,)) - - sub_ports = OrderedDict() - for port, value in sub_port_map.items(): - if not isinstance(subfragment, (_ir.Instance, _mem.MemoryInstance)): - for signal in value._rhs_signals(): - compiler_state.resolve_curr(signal, prefix=sub_name) - if len(value) > 0 or sub_type == "$mem_v2": - sub_ports[port] = rhs_compiler(value) - - if isinstance(subfragment, _ir.Instance): - src = _src(subfragment.src_loc) - elif isinstance(subfragment, _mem.MemoryInstance): - src = _src(subfragment._src_loc) - else: - src = "" - - module.cell(sub_type, name=sub_name, ports=sub_ports, params=sub_params, - attrs=subfragment.attrs, src=src) - - # If we emit all of our combinatorial logic into a single RTLIL process, Verilog - # simulators will break horribly, because Yosys write_verilog transforms RTLIL processes - # into always @* blocks with blocking assignment, and that does not create delta cycles. +class MemoryInfo: + def __init__(self, memid): + self.memid = memid + self.num_write_ports = 0 + self.write_port_ids = {} + + +class ModuleEmitter: + def __init__(self, builder, netlist, module, name_map, empty_checker): + self.builder = builder + self.netlist = netlist + self.module = module + self.name_map = name_map + self.empty_checker = empty_checker + + # Internal state of the emitter. This conceptually consists of three parts: + # (1) memory information; + # (2) name and attribute preferences for wires corresponding to signals; + # (3) mapping of Amaranth netlist entities to RTLIL netlist entities. + # Value names are preferences: they are candidate names for values that may or may not get + # used for cell outputs. Attributes are mandatory: they are always emitted, but can be + # squashed if several signals end up aliasing the same driven wire. + self.memories = {} # cell idx -> MemoryInfo + self.value_names = {} # value -> signal or port name + self.value_attrs = {} # value -> dict + self.sigport_wires = {} # signal or port name -> (wire, value) + self.driven_sigports = set() # set of signal or port name + self.nets = {} # net -> (wire name, bit idx) + self.cell_wires = {} # cell idx -> wire name + self.instance_wires = {} # (cell idx, output name) -> wire name + + def emit(self): + self.collect_memory_info() + self.assign_value_names() + self.collect_init_attrs() + self.emit_signal_wires() + self.emit_port_wires() + self.emit_cell_wires() + self.emit_submodule_wires() + self.emit_connects() + self.emit_submodules() + self.emit_cells() + + def collect_memory_info(self): + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.Memory): + self.memories[cell_idx] = MemoryInfo( + self.builder.memory(cell.width, cell.depth, name=cell.name, + attrs=cell.attributes, src=_src(cell.src_loc))) + + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.SyncWritePort): + memory_info = self.memories[cell.memory] + memory_info.write_port_ids[cell_idx] = memory_info.num_write_ports + memory_info.num_write_ports += 1 + + def assign_value_names(self): + for signal, name in self.module.signal_names.items(): + value = self.netlist.signals[signal] + if value not in self.value_names: + self.value_names[value] = name + + def collect_init_attrs(self): + # Flip-flops are special in Yosys; the initial value is stored not as a cell parameter but + # as an attribute of a wire connected to the output of the flip-flop. The claimed benefit + # of this arrangement is that fine cells, which cannot have parameters (so that certain + # backends, like BLIF, which cannot represent parameters--or attributes--can be used to + # emit these cells), then do not need to have 3x more variants (one for initialized to 0, + # one for 1, one for X). # - # Therefore, we translate the fragment as many times as there are independent groups - # of signals (a group is a transitive closure of signals that appear together on LHS), - # splitting them into many RTLIL (and thus Verilog) processes. - for domain, statements in fragment.statements.items(): - lhs_grouper = _xfrm.LHSGroupAnalyzer() - lhs_grouper.on_statements(statements) + # At the time of writing, 2024-02-11, Yosys has 125 (one hundred twenty five) fine FF cells, + # which are generated by a Python script because they have gotten completely out of hand + # long ago and no one could keep track of them manually. This list features such beauties + # as $_DFFSRE_PPPN_ and its other 7 cousins. + # + # These are real cells, used by real Yosys developers! Look at what they have done for us, + # with all the subtly unsynthesizable Verilog we sent them and all of the incompatibilities + # with vendor toolchains we reported! + # + # Nothing is fine about these cells. The decision to have `init` as a wire attribute is + # quite possibly the single worst design decision in Yosys, and not having to dealing with + # that bullshit again is enough of a reason to implement an FPGA toolchain from scratch. + # + # Just have 375 fine cells, bro. Trust me bro. You will certainly not regret having 375 + # fine cells in your toolchain. Or at least you will be able to process netlists without + # having to special-case this one godforsaken attribute every time you look at a wire. + # + # -- @whitequark + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.FlipFlop): + width = len(cell.data) + attrs = {"init": _ast.Const(cell.init, width), **cell.attributes} + value = _nir.Value(_nir.Net.from_cell(cell_idx, bit) for bit in range(width)) + self.value_attrs[value] = attrs - for group, group_signals in lhs_grouper.groups().items(): - lhs_group_filter = _xfrm.LHSGroupFilter(group_signals) - group_stmts = lhs_group_filter(statements) + def emit_signal_wires(self): + for signal, name in self.module.signal_names.items(): + value = self.netlist.signals[signal] - with module.process(name=f"$group_{group}") as process: - with process.case() as case: - # For every signal in comb domain, assign \sig$next to the reset value. - # For every signal in sync domains, assign \sig$next to the current - # value (\sig). - for signal in group_signals: - if domain == "comb": - prev_value = _ast.Const(signal.reset, signal.width) + # One of: (1) empty and created here, (2) `init` filled in by `collect_init_attrs`, + # (3) populated by some other signal aliasing the same nets. In the last case, we will + # glue attributes for these signals together, but synthesizers (including Yosys, when + # the design is flattened) will do that anyway, so it doesn't matter. + attrs = self.value_attrs.setdefault(value, {}) + attrs.update(signal.attrs) + + for repr in signal._value_repr: + if repr.path == () and isinstance(repr.format, _repr.FormatEnum): + enum = repr.format.enum + attrs["enum_base_type"] = enum.__name__ + for enum_value in enum: + attrs["enum_value_{:0{}b}".format(enum_value.value, signal.width)] = enum_value.name + + if name in self.module.ports: + port_value, _flow = self.module.ports[name] + assert value == port_value + self.name_map[signal] = (*self.module.name, f"\\{name}") + else: + wire = self.builder.wire(width=signal.width, name=name, attrs=attrs, + src=_src(signal.src_loc)) + self.sigport_wires[name] = (wire, value) + self.name_map[signal] = (*self.module.name, wire) + + def emit_port_wires(self): + for port_id, (name, (value, flow)) in enumerate(self.module.ports.items()): + wire = self.builder.wire(width=len(value), port_id=port_id, port_kind=flow.value, + name=name, attrs=self.value_attrs.get(value, {})) + self.sigport_wires[name] = (wire, value) + if flow == _nir.ModuleNetFlow.OUTPUT: + continue + # If we just emitted an input or inout port, it is driving the value. + self.driven_sigports.add(name) + for bit, net in enumerate(value): + self.nets[net] = (wire, bit) + + def emit_driven_wire(self, value): + # Emits a wire for a value, in preparation for driving it. + if value in self.value_names: + # If there is a signal or port matching this value, reuse its wire as the canonical + # wire of the nets involved. + name = self.value_names[value] + wire, named_value = self.sigport_wires[name] + assert value == named_value, \ + f"Inconsistent values {value!r}, {named_value!r} for wire {name!r}" + self.driven_sigports.add(name) + else: + # Otherwise, make an anonymous wire. + wire = self.builder.wire(len(value), attrs=self.value_attrs.get(value, {})) + for bit, net in enumerate(value): + self.nets[net] = (wire, bit) + return wire + + def emit_cell_wires(self): + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.Top): + continue + elif isinstance(cell, _nir.Instance): + for name, (start, width) in cell.ports_o.items(): + nets = [_nir.Net.from_cell(cell_idx, start + bit) for bit in range(width)] + wire = self.emit_driven_wire(_nir.Value(nets)) + self.instance_wires[cell_idx, name] = wire + continue # Instances use one wire per output, not per cell. + elif isinstance(cell, (_nir.PriorityMatch, _nir.Matches)): + continue # Inlined into assignment lists. + elif isinstance(cell, (_nir.SyncProperty, _nir.AsyncProperty, _nir.Memory, + _nir.SyncWritePort)): + continue # No outputs. + elif isinstance(cell, _nir.AssignmentList): + width = len(cell.default) + elif isinstance(cell, (_nir.Operator, _nir.Part, _nir.ArrayMux, _nir.AnyValue, + _nir.SyncReadPort, _nir.AsyncReadPort)): + width = cell.width + elif isinstance(cell, _nir.FlipFlop): + width = len(cell.data) + elif isinstance(cell, _nir.Initial): + width = 1 + elif isinstance(cell, _nir.IOBuffer): + width = len(cell.pad) + else: + assert False # :nocov: + # Single output cell connected to a wire. + nets = [_nir.Net.from_cell(cell_idx, bit) for bit in range(width)] + wire = self.emit_driven_wire(_nir.Value(nets)) + self.cell_wires[cell_idx] = wire + + def emit_submodule_wires(self): + for submodule_idx in self.module.submodules: + submodule = self.netlist.modules[submodule_idx] + for _name, (value, flow) in submodule.ports.items(): + if flow == _nir.ModuleNetFlow.OUTPUT: + self.emit_driven_wire(value) + + def sigspec(self, *parts: '_nir.Net | Iterable[_nir.Net]'): + value = _nir.Value() + for part in parts: + value += _nir.Value(part) + + chunks = [] + begin_pos = 0 + while begin_pos < len(value): + end_pos = begin_pos + if value[begin_pos].is_const: + while end_pos < len(value) and value[end_pos].is_const: + end_pos += 1 + width = end_pos - begin_pos + bits = "".join(str(net.const) for net in value[begin_pos:end_pos]) + chunks.append(f"{width}'{bits[::-1]}") + else: + wire, start_bit = self.nets[value[begin_pos]] + bit = start_bit + while (end_pos < len(value) and + not value[end_pos].is_const and + self.nets[value[end_pos]] == (wire, bit)): + end_pos += 1 + bit += 1 + width = end_pos - begin_pos + if width == 1: + chunks.append(f"{wire} [{start_bit}]") + else: + chunks.append(f"{wire} [{start_bit + width - 1}:{start_bit}]") + begin_pos = end_pos + + if len(chunks) == 1: + return chunks[0] + return "{ " + " ".join(reversed(chunks)) + " }" + + def emit_connects(self): + for name, (wire, value) in self.sigport_wires.items(): + if name not in self.driven_sigports: + self.builder.connect(wire, self.sigspec(value)) + + def emit_submodules(self): + for submodule_idx in self.module.submodules: + submodule = self.netlist.modules[submodule_idx] + if not self.empty_checker.is_empty(submodule_idx): + dotted_name = ".".join(submodule.name) + self.builder.cell(f"\\{dotted_name}", submodule.name[-1], ports={ + name: self.sigspec(value) + for name, (value, _flow) in submodule.ports.items() + }) + + def emit_assignment_list(self, cell_idx, cell): + def emit_assignments(case, cond): + # Emits assignments from the assignment list into the given case. + # ``cond`` is the net which is the condition for ``case`` being active. + # Returns once it hits an assignment whose condition is not nested within ``cond``, + # letting parent invocation take care of the remaining assignments. + nonlocal pos + + emitted_switch = False + while pos < len(cell.assignments): + assign = cell.assignments[pos] + if assign.cond == cond and not emitted_switch: + # Not nested, and we didn't emit a switch yet, so emit the assignment. + case.assign(self.sigspec(lhs[assign.start:assign.start + len(assign.value)]), + self.sigspec(assign.value)) + pos += 1 + elif assign.cond == cond: + # Not nested, but we emitted a subswitch. Wrap the assignments in a dummy + # switch. This is necessary because Yosys executes all assignments before all + # subswitches (but allows you to mix asssignments and switches in RTLIL, for + # maximum confusion). + with case.switch("{ }") as switch: + with switch.case("") as subcase: + while pos < len(cell.assignments): + assign = cell.assignments[pos] + if assign.cond == cond: + subcase.assign(self.sigspec(lhs[assign.start:assign.start + + len(assign.value)]), + self.sigspec(assign.value)) + pos += 1 + else: + break + else: + # Condition doesn't match this case's condition — either we encountered + # a nested condition, or we should break out. Try to find out exactly + # how we are nested. + search_cond = assign.cond + while True: + if search_cond == cond: + # We have found the PriorityMatch cell that we should enter. + break + if search_cond == _nir.Net.from_const(1): + # If this isn't nested condition, go back to parent invocation. + return + # Grab the PriorityMatch cell that is on the next level of nesting. + priority_cell_idx = search_cond.cell + priority_cell = self.netlist.cells[priority_cell_idx] + assert isinstance(priority_cell, _nir.PriorityMatch) + search_cond = priority_cell.en + # We assume that: + # 1. PriorityMatch inputs can only be Match cell outputs, or constant 1. + # 2. All Match cells driving a given PriorityMatch cell test the same value. + # Grab the tested value from a random Match cell. + test = _nir.Value() + for net in priority_cell.inputs: + if net != _nir.Net.from_const(1): + matches_cell = self.netlist.cells[net.cell] + assert isinstance(matches_cell, _nir.Matches) + test = matches_cell.value + break + # Now emit cases for all PriorityMatch inputs, in sequence. Consume as many + # assignments as possible along the way. + with case.switch(self.sigspec(test)) as switch: + for bit, net in enumerate(priority_cell.inputs): + subcond = _nir.Net.from_cell(priority_cell_idx, bit) + if net == _nir.Net.from_const(1): + patterns = () else: - prev_value = signal - case.assign(lhs_compiler(signal), rhs_compiler(prev_value)) + # Validate the above assumptions. + matches_cell = self.netlist.cells[net.cell] + assert isinstance(matches_cell, _nir.Matches) + assert test == matches_cell.value + patterns = matches_cell.patterns + with switch.case(*patterns) as subcase: + emit_assignments(subcase, subcond) + emitted_switch = True - # Convert statements into decision trees. - stmt_compiler._domain = domain - stmt_compiler._case = case - stmt_compiler._has_rhs = False - stmt_compiler._wrap_assign = False - stmt_compiler(group_stmts) + lhs = _nir.Value(_nir.Net.from_cell(cell_idx, bit) for bit in range(len(cell.default))) + with self.builder.process(src=_src(cell.src_loc)) as proc: + with proc.case() as root_case: + root_case.assign(self.sigspec(lhs), self.sigspec(cell.default)) - # For every driven signal in the sync domain, create a flop of appropriate type. Which type - # is appropriate depends on the domain: for domains with sync reset, it is a $dff, for - # domains with async reset it is an $adff. The latter is directly provided with the reset - # value as a parameter to the cell, which is directly assigned during reset. - for domain, signal in fragment.iter_sync(): - cd = fragment.domains[domain] + pos = 0 # nonlocally used in `emit_assignments` + emit_assignments(root_case, _nir.Net.from_const(1)) + assert pos == len(cell.assignments) - wire_clk = compiler_state.resolve_curr(cd.clk) - wire_rst = compiler_state.resolve_curr(cd.rst) if cd.rst is not None else None - wire_curr, wire_next = compiler_state.resolve(signal) - - if not cd.async_reset: - # For sync reset flops, the reset value comes from logic inserted by - # `hdl.xfrm.DomainLowerer`. - module.cell("$dff", ports={ - "\\CLK": wire_clk, - "\\D": wire_next, - "\\Q": wire_curr + def emit_operator(self, cell_idx, cell): + UNARY_OPERATORS = { + "-": "$neg", + "~": "$not", + "b": "$reduce_bool", + "r|": "$reduce_or", + "r&": "$reduce_and", + "r^": "$reduce_xor", + } + BINARY_OPERATORS = { + # A_SIGNED, B_SIGNED + "+": ("$add", False, False), + "-": ("$sub", False, False), + "*": ("$mul", False, False), + "u//": ("$divfloor", False, False), + "s//": ("$divfloor", True, True), + "u%": ("$modfloor", False, False), + "s%": ("$modfloor", True, True), + "<<": ("$shl", False, False), + "u>>": ("$shr", False, False), + "s>>": ("$sshr", True, False), + "&": ("$and", False, False), + "|": ("$or", False, False), + "^": ("$xor", False, False), + "==": ("$eq", False, False), + "!=": ("$ne", False, False), + "u<": ("$lt", False, False), + "u>": ("$gt", False, False), + "u<=": ("$le", False, False), + "u>=": ("$ge", False, False), + "s<": ("$lt", True, True), + "s>": ("$gt", True, True), + "s<=": ("$le", True, True), + "s>=": ("$ge", True, True), + } + if len(cell.inputs) == 1: + cell_type = UNARY_OPERATORS[cell.operator] + operand, = cell.inputs + self.builder.cell(cell_type, ports={ + "A": self.sigspec(operand), + "Y": self.cell_wires[cell_idx] + }, params={ + "A_SIGNED": False, + "A_WIDTH": len(operand), + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) + elif len(cell.inputs) == 2: + cell_type, a_signed, b_signed = BINARY_OPERATORS[cell.operator] + operand_a, operand_b = cell.inputs + if cell.operator in ("u//", "s//", "u%", "s%"): + result = self.builder.wire(cell.width) + self.builder.cell(cell_type, ports={ + "A": self.sigspec(operand_a), + "B": self.sigspec(operand_b), + "Y": result, }, params={ - "CLK_POLARITY": int(cd.clk_edge == "pos"), - "WIDTH": signal.width - }) + "A_SIGNED": a_signed, + "B_SIGNED": b_signed, + "A_WIDTH": len(operand_a), + "B_WIDTH": len(operand_b), + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) + nonzero = self.builder.wire(1) + self.builder.cell("$reduce_bool", ports={ + "A": self.sigspec(operand_b), + "Y": nonzero, + }, params={ + "A_SIGNED": False, + "A_WIDTH": len(operand_b), + "Y_WIDTH": 1, + }, src=_src(cell.src_loc)) + self.builder.cell("$mux", ports={ + "S": nonzero, + "A": self.sigspec(_nir.Value.zeros(cell.width)), + "B": result, + "Y": self.cell_wires[cell_idx] + }, params={ + "WIDTH": cell.width, + }, src=_src(cell.src_loc)) else: - # For async reset flops, the reset value is provided directly to the cell. - module.cell("$adff", ports={ - "\\ARST": wire_rst, - "\\CLK": wire_clk, - "\\D": wire_next, - "\\Q": wire_curr + self.builder.cell(cell_type, ports={ + "A": self.sigspec(operand_a), + "B": self.sigspec(operand_b), + "Y": self.cell_wires[cell_idx], }, params={ - "ARST_POLARITY": _ast.Const(1), - "ARST_VALUE": _ast.Const(signal.reset, signal.width), - "CLK_POLARITY": int(cd.clk_edge == "pos"), - "WIDTH": signal.width - }) + "A_SIGNED": a_signed, + "B_SIGNED": b_signed, + "A_WIDTH": len(operand_a), + "B_WIDTH": len(operand_b), + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) + else: + assert cell.operator == "m" + condition, if_true, if_false = cell.inputs + self.builder.cell("$mux", ports={ + "S": self.sigspec(condition), + "A": self.sigspec(if_false), + "B": self.sigspec(if_true), + "Y": self.cell_wires[cell_idx] + }, params={ + "WIDTH": cell.width, + }, src=_src(cell.src_loc)) - # Any signals that are used but neither driven nor connected to an input port always - # assume their reset values. We need to assign the reset value explicitly, since only - # driven sync signals are handled by the logic above. - # - # Because this assignment is done at a late stage, a single Signal object can get assigned - # many times, once in each module it is used. This is a deliberate decision; the possible - # alternatives are to add ports for undriven signals (which requires choosing one module - # to drive it to reset value arbitrarily) or to replace them with their reset value (which - # removes valuable source location information). - driven = _ast.SignalSet() - for domain, statements in fragment.statements.items(): - driven.update(statements._lhs_signals()) - driven.update(fragment.iter_ports(dir="i")) - driven.update(fragment.iter_ports(dir="io")) - for subfragment, sub_name in fragment.subfragments: - driven.update(subfragment.iter_ports(dir="o")) - driven.update(subfragment.iter_ports(dir="io")) + def emit_part(self, cell_idx, cell): + if cell.stride == 1: + offset = self.sigspec(cell.offset) + offset_width = len(cell.offset) + else: + stride = _ast.Const(cell.stride) + offset_width = len(cell.offset) + stride.width + offset = self.builder.wire(offset_width) + self.builder.cell("$mul", ports={ + "A": self.sigspec(cell.offset), + "B": _const(stride), + "Y": offset, + }, params={ + "A_SIGNED": False, + "B_SIGNED": False, + "A_WIDTH": len(cell.offset), + "B_WIDTH": stride.width, + "Y_WIDTH": offset_width, + }, src=_src(cell.src_loc)) + self.builder.cell("$shift", ports={ + "A": self.sigspec(cell.value), + "B": offset, + "Y": self.cell_wires[cell_idx], + }, params={ + "A_SIGNED": cell.value_signed, + "B_SIGNED": False, + "A_WIDTH": len(cell.value), + "B_WIDTH": offset_width, + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) - for wire in compiler_state.wires: - if wire in driven: - continue - wire_curr, _ = compiler_state.wires[wire] - module.connect(wire_curr, rhs_compiler(_ast.Const(wire.reset, wire.width))) + def emit_array_mux(self, cell_idx, cell): + wire = self.cell_wires[cell_idx] + with self.builder.process(src=_src(cell.src_loc)) as proc: + with proc.case() as root_case: + with root_case.switch(self.sigspec(cell.index)) as switch: + for index, elem in enumerate(cell.elems): + if len(cell.index) > 0: + pattern = "{:0{}b}".format(index, len(cell.index)) + else: + pattern = "" + with switch.case(pattern) as case: + case.assign(wire, self.sigspec(elem)) + with switch.case() as case: + case.assign(wire, self.sigspec(cell.elems[0])) - # Collect the names we've given to our ports in RTLIL, and correlate these with the signals - # represented by these ports. If we are a submodule, this will be necessary to create a cell - # for us in the parent module. - port_map = OrderedDict() - for signal in fragment.ports: - port_map[compiler_state.resolve_curr(signal)] = signal + def emit_flip_flop(self, cell_idx, cell): + ports = { + "D": self.sigspec(cell.data), + "CLK": self.sigspec(cell.clk), + "Q": self.cell_wires[cell_idx] + } + params = { + "WIDTH": len(cell.data), + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge] + } + if cell.arst == _nir.Net.from_const(0): + cell_type = "$dff" + else: + cell_type = "$adff" + ports["ARST"] = self.sigspec(cell.arst) + params["ARST_POLARITY"] = True + params["ARST_VALUE"] = _ast.Const(cell.init, len(cell.data)) + self.builder.cell(cell_type, ports=ports, params=params, src=_src(cell.src_loc)) - # Finally, collect the names we've given to each wire in RTLIL, and provide these to - # the caller, to allow manipulating them in the toolchain. - for signal in compiler_state.wires: - wire_name = compiler_state.resolve_curr(signal) - if wire_name.startswith("\\"): - wire_name = wire_name[1:] - name_map[signal] = hierarchy + (wire_name,) + def emit_io_buffer(self, cell_idx, cell): + self.builder.cell("$tribuf", ports={ + "Y": self.sigspec(cell.pad), + "A": self.sigspec(cell.o), + "EN": self.sigspec(cell.oe), + }, params={ + "WIDTH": len(cell.pad), + }, src=_src(cell.src_loc)) + self.builder.connect(self.cell_wires[cell_idx], self.sigspec(cell.pad)) - return module.name, port_map, {} + def emit_memory(self, cell_idx, cell): + memory_info = self.memories[cell_idx] + self.builder.cell("$meminit_v2", ports={ + "ADDR": self.sigspec(), + "DATA": self.sigspec( + _nir.Net.from_const((row >> bit) & 1) + for row in cell.init + for bit in range(cell.width) + ), + "EN": self.sigspec(_nir.Value.ones(cell.width)), + }, params={ + "MEMID": memory_info.memid, + "ABITS": 0, + "WIDTH": cell.width, + "WORDS": cell.depth, + "PRIORITY": 0, + }, src=_src(cell.src_loc)) + + def emit_write_port(self, cell_idx, cell): + memory_info = self.memories[cell.memory] + ports = { + "ADDR": self.sigspec(cell.addr), + "DATA": self.sigspec(cell.data), + "EN": self.sigspec(cell.en), + "CLK": self.sigspec(cell.clk), + } + params = { + "MEMID": memory_info.memid, + "ABITS": len(cell.addr), + "WIDTH": len(cell.data), + "CLK_ENABLE": True, + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge], + "PORTID": memory_info.write_port_ids[cell_idx], + "PRIORITY_MASK": 0, + } + self.builder.cell(f"$memwr_v2", ports=ports, params=params, src=_src(cell.src_loc)) + + def emit_read_port(self, cell_idx, cell): + memory_info = self.memories[cell.memory] + ports = { + "ADDR": self.sigspec(cell.addr), + "DATA": self.cell_wires[cell_idx], + "ARST": self.sigspec(_nir.Net.from_const(0)), + "SRST": self.sigspec(_nir.Net.from_const(0)), + } + if isinstance(cell, _nir.AsyncReadPort): + transparency_mask = 0 + if isinstance(cell, _nir.SyncReadPort): + transparency_mask = sum( + 1 << memory_info.write_port_ids[write_port_cell_index] + for write_port_cell_index in cell.transparent_for + ) + params = { + "MEMID": memory_info.memid, + "ABITS": len(cell.addr), + "WIDTH": cell.width, + "TRANSPARENCY_MASK": _ast.Const(transparency_mask, memory_info.num_write_ports), + "COLLISION_X_MASK": _ast.Const(0, memory_info.num_write_ports), + "ARST_VALUE": _ast.Const(0, cell.width), + "SRST_VALUE": _ast.Const(0, cell.width), + "INIT_VALUE": _ast.Const(0, cell.width), + "CE_OVER_SRST": False, + } + if isinstance(cell, _nir.AsyncReadPort): + ports.update({ + "EN": self.sigspec(_nir.Net.from_const(1)), + "CLK": self.sigspec(_nir.Net.from_const(0)), + }) + params.update({ + "CLK_ENABLE": False, + "CLK_POLARITY": True, + }) + if isinstance(cell, _nir.SyncReadPort): + ports.update({ + "EN": self.sigspec(cell.en), + "CLK": self.sigspec(cell.clk), + }) + params.update({ + "CLK_ENABLE": True, + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge], + }) + self.builder.cell(f"$memrd_v2", ports=ports, params=params, src=_src(cell.src_loc)) + + def emit_property(self, cell_idx, cell): + if isinstance(cell, _nir.AsyncProperty): + ports = { + "A": self.sigspec(cell.test), + "EN": self.sigspec(cell.en), + } + if isinstance(cell, _nir.SyncProperty): + test = self.builder.wire(1, attrs={"init": _ast.Const(0, 1)}) + en = self.builder.wire(1, attrs={"init": _ast.Const(0, 1)}) + for (d, q) in [ + (cell.test, test), + (cell.en, en), + ]: + ports = { + "D": self.sigspec(d), + "Q": q, + "CLK": self.sigspec(cell.clk), + } + params = { + "WIDTH": 1, + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge], + } + self.builder.cell(f"$dff", ports=ports, params=params, src=_src(cell.src_loc)) + ports = { + "A": test, + "EN": en, + } + self.builder.cell(f"${cell.kind}", name=cell.name, ports=ports, src=_src(cell.src_loc)) + + def emit_any_value(self, cell_idx, cell): + self.builder.cell(f"${cell.kind}", ports={ + "Y": self.cell_wires[cell_idx], + }, params={ + "WIDTH": cell.width, + }, src=_src(cell.src_loc)) + + def emit_initial(self, cell_idx, cell): + self.builder.cell("$initstate", ports={ + "Y": self.cell_wires[cell_idx], + }, src=_src(cell.src_loc)) + + def emit_instance(self, cell_idx, cell): + ports = {} + for name, nets in cell.ports_i.items(): + ports[name] = self.sigspec(nets) + for name in cell.ports_o: + ports[name] = self.instance_wires[cell_idx, name] + for name, nets in cell.ports_io.items(): + ports[name] = self.sigspec(nets) + if cell.type.startswith("$"): + type = cell.type + else: + type = "\\" + cell.type + self.builder.cell(type, cell.name, ports=ports, params=cell.parameters, + attrs=cell.attributes, src=_src(cell.src_loc)) + + def emit_cells(self): + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.Top): + pass + elif isinstance(cell, _nir.Matches): + pass # Matches is only referenced from PriorityMatch cells and inlined there + elif isinstance(cell, _nir.PriorityMatch): + pass # PriorityMatch is only referenced from AssignmentList cells and inlined there + elif isinstance(cell, _nir.AssignmentList): + self.emit_assignment_list(cell_idx, cell) + elif isinstance(cell, _nir.Operator): + self.emit_operator(cell_idx, cell) + elif isinstance(cell, _nir.Part): + self.emit_part(cell_idx, cell) + elif isinstance(cell, _nir.ArrayMux): + self.emit_array_mux(cell_idx, cell) + elif isinstance(cell, _nir.FlipFlop): + self.emit_flip_flop(cell_idx, cell) + elif isinstance(cell, _nir.IOBuffer): + self.emit_io_buffer(cell_idx, cell) + elif isinstance(cell, _nir.Memory): + self.emit_memory(cell_idx, cell) + elif isinstance(cell, _nir.SyncWritePort): + self.emit_write_port(cell_idx, cell) + elif isinstance(cell, (_nir.AsyncReadPort, _nir.SyncReadPort)): + self.emit_read_port(cell_idx, cell) + elif isinstance(cell, (_nir.AsyncProperty, _nir.SyncProperty)): + self.emit_property(cell_idx, cell) + elif isinstance(cell, _nir.AnyValue): + self.emit_any_value(cell_idx, cell) + elif isinstance(cell, _nir.Initial): + self.emit_initial(cell_idx, cell) + elif isinstance(cell, _nir.Instance): + self.emit_instance(cell_idx, cell) + else: + assert False # :nocov: + + +# Empty modules are interpreted by some toolchains (Yosys, Xilinx, ...) as black boxes, and +# must not be emitted. +class EmptyModuleChecker: + def __init__(self, netlist): + self.netlist = netlist + self.empty = set() + self.check(0) + + def check(self, module_idx): + is_empty = not self.netlist.modules[module_idx].cells + for submodule in self.netlist.modules[module_idx].submodules: + is_empty &= self.check(submodule) + if is_empty: + self.empty.add(module_idx) + return is_empty + + def is_empty(self, module_idx): + return module_idx in self.empty def convert_fragment(fragment, name="top", *, emit_src=True): assert isinstance(fragment, _ir.Fragment) - builder = _Builder(emit_src=emit_src) name_map = _ast.SignalDict() - _convert_fragment(builder, fragment, name_map, hierarchy=(name,)) + netlist = _ir.build_netlist(fragment, name=name) + empty_checker = EmptyModuleChecker(netlist) + builder = _Builder(emit_src=emit_src) + for module_idx, module in enumerate(netlist.modules): + if empty_checker.is_empty(module_idx): + continue + attrs = {"top": 1} if module_idx == 0 else {} + with builder.module(".".join(module.name), attrs=attrs) as module_builder: + ModuleEmitter(module_builder, netlist, module, name_map, + empty_checker=empty_checker).emit() return str(builder), name_map @@ -1047,7 +1002,7 @@ def convert(elaboratable, name="top", platform=None, *, ports=None, emit_src=Tru hasattr(elaboratable, "signature") and isinstance(elaboratable.signature, wiring.Signature)): ports = [] - for path, member, value in elaboratable.signature.flatten(elaboratable): + for _path, _member, value in elaboratable.signature.flatten(elaboratable): if isinstance(value, _ast.ValueCastable): value = value.as_value() if isinstance(value, _ast.Value): @@ -1055,5 +1010,5 @@ def convert(elaboratable, name="top", platform=None, *, ports=None, emit_src=Tru elif ports is None: raise TypeError("The `convert()` function requires a `ports=` argument") fragment = _ir.Fragment.get(elaboratable, platform).prepare(ports=ports, **kwargs) - il_text, name_map = convert_fragment(fragment, name, emit_src=emit_src) + il_text, _name_map = convert_fragment(fragment, name, emit_src=emit_src) return il_text diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 25c557b..7abe4e3 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -1770,25 +1770,17 @@ class Property(Statement, MustUse): Assume = "assume" Cover = "cover" - def __init__(self, kind, test, *, _check=None, _en=None, name=None, src_loc_at=0): + def __init__(self, kind, test, *, name=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) self.kind = self.Kind(kind) self.test = Value.cast(test) - self._check = _check - self._en = _en self.name = name if not isinstance(self.name, str) and self.name is not None: raise TypeError("Property name must be a string or None, not {!r}" .format(self.name)) - if self._check is None: - self._check = Signal(reset_less=True, name=f"${self.kind.value}$check") - self._check.src_loc = self.src_loc - if _en is None: - self._en = Signal(reset_less=True, name=f"${self.kind.value}$en") - self._en.src_loc = self.src_loc def _lhs_signals(self): - return SignalSet((self._en, self._check)) + return set() def _rhs_signals(self): return self.test._rhs_signals() diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 8fdc919..5138de0 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -1,20 +1,17 @@ -from abc import ABCMeta +from typing import Tuple from collections import defaultdict, OrderedDict from functools import reduce import warnings -from .. import tracer -from .._utils import * -from .._unused import * -from ._ast import * -from ._ast import _StatementList -from ._cd import * +from .._utils import flatten, memoize +from .. import tracer, _unused +from . import _ast, _cd, _ir, _nir __all__ = ["UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment", "Instance"] -class UnusedElaboratable(UnusedMustUse): +class UnusedElaboratable(_unused.UnusedMustUse): # The warning is initially silenced. If everything that has been constructed remains unused, # it means the application likely crashed (with an exception, or in another way that does not # call `sys.excepthook`), and it's not necessary to show any warnings. @@ -22,7 +19,7 @@ class UnusedElaboratable(UnusedMustUse): _MustUse__silence = True -class Elaboratable(MustUse): +class Elaboratable(_unused.MustUse): _MustUse__warning = UnusedElaboratable @@ -64,7 +61,7 @@ class Fragment: obj = new_obj def __init__(self): - self.ports = SignalDict() + self.ports = _ast.SignalDict() self.drivers = OrderedDict() self.statements = {} self.domains = OrderedDict() @@ -89,7 +86,7 @@ class Fragment: def add_driver(self, signal, domain="comb"): assert isinstance(domain, str) if domain not in self.drivers: - self.drivers[domain] = SignalSet() + self.drivers[domain] = _ast.SignalSet() self.drivers[domain].add(signal) def iter_drivers(self): @@ -109,7 +106,7 @@ class Fragment: yield domain, signal def iter_signals(self): - signals = SignalSet() + signals = _ast.SignalSet() signals |= self.ports.keys() for domain, domain_signals in self.drivers.items(): if domain != "comb": @@ -122,7 +119,7 @@ class Fragment: def add_domains(self, *domains): for domain in flatten(domains): - assert isinstance(domain, ClockDomain) + assert isinstance(domain, _cd.ClockDomain) assert domain.name not in self.domains self.domains[domain.name] = domain @@ -131,9 +128,9 @@ class Fragment: def add_statements(self, domain, *stmts): assert isinstance(domain, str) - for stmt in Statement.cast(stmts): + for stmt in _ast.Statement.cast(stmts): stmt._MustUse__used = True - self.statements.setdefault(domain, _StatementList()).append(stmt) + self.statements.setdefault(domain, _ast._StatementList()).append(stmt) def add_subfragment(self, subfragment, name=None): assert isinstance(subfragment, Fragment) @@ -186,15 +183,15 @@ class Fragment: assert mode in ("silent", "warn", "error") from ._mem import MemoryInstance - driver_subfrags = SignalDict() + driver_subfrags = _ast.SignalDict() def add_subfrag(registry, entity, entry): # Because of missing domain insertion, at the point when this code runs, we have # a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to # the actual signals (because the signal itself can be driven as well); but leave # the unbound ones as it is, because there's no concrete signal for it yet anyway. - if isinstance(entity, ClockSignal) and entity.domain in self.domains: + if isinstance(entity, _ast.ClockSignal) and entity.domain in self.domains: entity = self.domains[entity.domain].clk - elif isinstance(entity, ResetSignal) and entity.domain in self.domains: + elif isinstance(entity, _ast.ResetSignal) and entity.domain in self.domains: entity = self.domains[entity.domain].rst if entity not in registry: @@ -265,7 +262,7 @@ class Fragment: return self._resolve_hierarchy_conflicts(hierarchy, mode) # Nothing was flattened, we're done! - return SignalSet(driver_subfrags.keys()) + return _ast.SignalSet(driver_subfrags.keys()) def _propagate_domains_up(self, hierarchy=("top",)): from ._xfrm import DomainRenamer @@ -296,18 +293,18 @@ class Fragment: if not all(names): names = sorted(f"" if n is None else f"'{n}'" for f, n, i in subfrags) - raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; " - "it is necessary to either rename subfragment domains " - "explicitly, or give names to subfragments" - .format(domain_name, ", ".join(names), ".".join(hierarchy))) + raise _cd.DomainError( + "Domain '{}' is defined by subfragments {} of fragment '{}'; it is necessary " + "to either rename subfragment domains explicitly, or give names to subfragments" + .format(domain_name, ", ".join(names), ".".join(hierarchy))) if len(names) != len(set(names)): names = sorted(f"#{i}" for f, n, i in subfrags) - raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', " - "some of which have identical names; it is necessary to either " - "rename subfragment domains explicitly, or give distinct names " - "to subfragments" - .format(domain_name, ", ".join(names), ".".join(hierarchy))) + raise _cd.DomainError( + "Domain '{}' is defined by subfragments {} of fragment '{}', some of which " + "have identical names; it is necessary to either rename subfragment domains " + "explicitly, or give distinct names to subfragments" + .format(domain_name, ", ".join(names), ".".join(hierarchy))) for subfrag, name, i in subfrags: domain_name_map = {domain_name: f"{name}_{domain_name}"} @@ -343,8 +340,8 @@ class Fragment: continue value = missing_domain(domain_name) if value is None: - raise DomainError(f"Domain '{domain_name}' is used but not defined") - if type(value) is ClockDomain: + raise _cd.DomainError(f"Domain '{domain_name}' is used but not defined") + if type(value) is _cd.ClockDomain: self.add_domains(value) # And expose ports on the newly added clock domain, since it is added directly # and there was no chance to add any logic driving it. @@ -353,7 +350,7 @@ class Fragment: new_fragment = Fragment.get(value, platform=platform) if domain_name not in new_fragment.domains: defined = new_fragment.domains.keys() - raise DomainError( + raise _cd.DomainError( "Fragment returned by missing domain callback does not define " "requested domain '{}' (defines {})." .format(domain_name, ", ".join(f"'{n}'" for n in defined))) @@ -463,12 +460,12 @@ class Fragment: parent = {self: None} level = {self: 0} - uses = SignalDict() - defs = SignalDict() - ios = SignalDict() + uses = _ast.SignalDict() + defs = _ast.SignalDict() + ios = _ast.SignalDict() self._prepare_use_def_graph(parent, level, uses, defs, ios, self) - ports = SignalSet(ports) + ports = _ast.SignalSet(ports) if all_undef_as_ports: for sig in uses: if sig in defs: @@ -530,7 +527,7 @@ class Fragment: else: self.add_ports(sig, dir="i") - def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)): + def prepare(self, ports=None, missing_domain=lambda name: _cd.ClockDomain(name)): from ._xfrm import DomainLowerer new_domains = self._propagate_domains(missing_domain) @@ -541,14 +538,14 @@ class Fragment: if not isinstance(ports, tuple) and not isinstance(ports, list): msg = "`ports` must be either a list or a tuple, not {!r}"\ .format(ports) - if isinstance(ports, Value): + if isinstance(ports, _ast.Value): msg += " (did you mean `ports=(,)`, rather than `ports=`?)" raise TypeError(msg) mapped_ports = [] # Lower late bound signals like ClockSignal() to ports. port_lowerer = DomainLowerer(fragment.domains) for port in ports: - if not isinstance(port, (Signal, ClockSignal, ResetSignal)): + if not isinstance(port, (_ast.Signal, _ast.ClockSignal, _ast.ResetSignal)): raise TypeError("Only signals may be added as ports, not {!r}" .format(port)) mapped_ports.append(port_lowerer.on_value(port)) @@ -573,7 +570,7 @@ class Fragment: may get a different name. """ - signal_names = SignalDict() + signal_names = _ast.SignalDict() assigned_names = set() def add_signal_name(signal): @@ -599,7 +596,7 @@ class Fragment: for statements in self.statements.values(): for statement in statements: for signal in statement._lhs_signals() | statement._rhs_signals(): - if not isinstance(signal, (ClockSignal, ResetSignal)): + if not isinstance(signal, (_ast.ClockSignal, _ast.ResetSignal)): add_signal_name(signal) return signal_names @@ -657,7 +654,7 @@ class Instance(Fragment): elif kind == "p": self.parameters[name] = value elif kind in ("i", "o", "io"): - self.named_ports[name] = (Value.cast(value), kind) + self.named_ports[name] = (_ast.Value.cast(value), kind) else: raise NameError("Instance argument {!r} should be a tuple (kind, name, value) " "where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\"" @@ -669,12 +666,763 @@ class Instance(Fragment): elif kw.startswith("p_"): self.parameters[kw[2:]] = arg elif kw.startswith("i_"): - self.named_ports[kw[2:]] = (Value.cast(arg), "i") + self.named_ports[kw[2:]] = (_ast.Value.cast(arg), "i") elif kw.startswith("o_"): - self.named_ports[kw[2:]] = (Value.cast(arg), "o") + self.named_ports[kw[2:]] = (_ast.Value.cast(arg), "o") elif kw.startswith("io_"): - self.named_ports[kw[3:]] = (Value.cast(arg), "io") + self.named_ports[kw[3:]] = (_ast.Value.cast(arg), "io") else: raise NameError("Instance keyword argument {}={!r} does not start with one of " "\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\"" .format(kw, arg)) + + +############################################################################################### >:3 + + +class NetlistDriver: + def __init__(self, module_idx: int, signal: _ast.Signal, + domain: '_cd.ClockDomain | None', *, src_loc): + self.module_idx = module_idx + self.signal = signal + self.domain = domain + self.src_loc = src_loc + self.assignments = [] + + def emit_value(self, builder): + if self.domain is None: + reset = _ast.Const(self.signal.reset, self.signal.width) + default, _signed = builder.emit_rhs(self.module_idx, reset) + else: + default = builder.emit_signal(self.signal) + if len(self.assignments) == 1: + assign, = self.assignments + if assign.cond == 1 and assign.start == 0 and len(assign.value) == len(default): + return assign.value + cell = _nir.AssignmentList(self.module_idx, default=default, assignments=self.assignments, + src_loc=self.signal.src_loc) + return builder.netlist.add_value_cell(len(default), cell) + + +class NetlistEmitter: + def __init__(self, netlist: _nir.Netlist, fragment_names: 'dict[_ir.Fragment, str]'): + self.netlist = netlist + self.fragment_names = fragment_names + self.drivers = _ast.SignalDict() + self.rhs_cache: dict[int, Tuple[_nir.Value, bool, _ast.Value]] = {} + + # Collected for driver conflict diagnostics only. + self.late_net_to_signal = {} + self.connect_src_loc = {} + + def emit_signal(self, signal) -> _nir.Value: + if signal in self.netlist.signals: + return self.netlist.signals[signal] + value = self.netlist.alloc_late_value(len(signal)) + self.netlist.signals[signal] = value + for bit, net in enumerate(value): + self.late_net_to_signal[net] = (signal, bit) + return value + + # Used for instance outputs and read port data, not used for actual assignments. + def emit_lhs(self, value: _ast.Value): + if isinstance(value, _ast.Signal): + return self.emit_signal(value) + elif isinstance(value, _ast.Cat): + result = [] + for part in value.parts: + result += self.emit_lhs(part) + return _nir.Value(result) + elif isinstance(value, _ast.Slice): + return self.emit_lhs(value.value)[value.start:value.stop] + elif isinstance(value, _ast.Operator): + assert value.operator in ('u', 's') + return self.emit_lhs(value.operands[0]) + else: + raise TypeError # :nocov: + + def extend(self, value: _nir.Value, signed: bool, width: int): + nets = list(value) + while len(nets) < width: + if signed: + nets.append(nets[-1]) + else: + nets.append(_nir.Net.from_const(0)) + return _nir.Value(nets) + + def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc): + op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc) + return self.netlist.add_value_cell(op.width, op) + + def unify_shapes_bitwise(self, + operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool): + if signed_a == signed_b: + width = max(len(operand_a), len(operand_b)) + elif signed_a: + width = max(len(operand_a), len(operand_b) + 1) + else: # signed_b + width = max(len(operand_a) + 1, len(operand_b)) + operand_a = self.extend(operand_a, signed_a, width) + operand_b = self.extend(operand_b, signed_b, width) + signed = signed_a or signed_b + return (operand_a, operand_b, signed) + + def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]: + """Emits a RHS value, returns a tuple of (value, is_signed)""" + try: + result, signed, value = self.rhs_cache[id(value)] + return result, signed + except KeyError: + pass + if isinstance(value, _ast.Const): + result = _nir.Value( + _nir.Net.from_const((value.value >> bit) & 1) + for bit in range(value.width) + ) + signed = value.signed + elif isinstance(value, _ast.Signal): + result = self.emit_signal(value) + signed = value.signed + elif isinstance(value, _ast.Operator): + if len(value.operands) == 1: + operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0]) + if value.operator == 's': + result = operand_a + signed = True + elif value.operator == 'u': + result = operand_a + signed = False + elif value.operator == '+': + result = operand_a + signed = signed_a + elif value.operator == '-': + operand_a = self.extend(operand_a, signed_a, len(operand_a) + 1) + result = self.emit_operator(module_idx, '-', operand_a, + src_loc=value.src_loc) + signed = True + elif value.operator == '~': + result = self.emit_operator(module_idx, '~', operand_a, + src_loc=value.src_loc) + signed = signed_a + elif value.operator in ('b', 'r|', 'r&', 'r^'): + result = self.emit_operator(module_idx, value.operator, operand_a, + src_loc=value.src_loc) + signed = False + else: + assert False # :nocov: + elif len(value.operands) == 2: + operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0]) + operand_b, signed_b = self.emit_rhs(module_idx, value.operands[1]) + if value.operator in ('|', '&', '^'): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + result = self.emit_operator(module_idx, value.operator, operand_a, operand_b, + src_loc=value.src_loc) + elif value.operator in ('+', '-'): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + width = len(operand_a) + 1 + operand_a = self.extend(operand_a, signed, width) + operand_b = self.extend(operand_b, signed, width) + result = self.emit_operator(module_idx, value.operator, operand_a, operand_b, + src_loc=value.src_loc) + if value.operator == '-': + signed = True + elif value.operator == '*': + width = len(operand_a) + len(operand_b) + operand_a = self.extend(operand_a, signed_a, width) + operand_b = self.extend(operand_b, signed_b, width) + result = self.emit_operator(module_idx, '*', operand_a, operand_b, + src_loc=value.src_loc) + signed = signed_a or signed_b + elif value.operator == '//': + width = len(operand_a) + signed_b + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + if len(operand_a) < width: + operand_a = self.extend(operand_a, signed, width) + operand_b = self.extend(operand_b, signed, width) + operator = 's//' if signed else 'u//' + result = _nir.Value( + self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc)[:width] + ) + elif value.operator == '%': + width = len(operand_b) + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + operator = 's%' if signed else 'u%' + result = _nir.Value( + self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc)[:width] + ) + signed = signed_b + elif value.operator == '<<': + operand_a = self.extend(operand_a, signed_a, + len(operand_a) + 2 ** len(operand_b) - 1) + result = self.emit_operator(module_idx, '<<', operand_a, operand_b, + src_loc=value.src_loc) + signed = signed_a + elif value.operator == '>>': + operator = 's>>' if signed_a else 'u>>' + result = self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc) + signed = signed_a + elif value.operator in ('==', '!='): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + result = self.emit_operator(module_idx, value.operator, operand_a, operand_b, + src_loc=value.src_loc) + signed = False + elif value.operator in ('<', '>', '<=', '>='): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + operator = ('s' if signed else 'u') + value.operator + result = self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc) + signed = False + else: + assert False # :nocov: + elif len(value.operands) == 3: + assert value.operator == 'm' + operand_s, signed_s = self.emit_rhs(module_idx, value.operands[0]) + operand_a, signed_a = self.emit_rhs(module_idx, value.operands[1]) + operand_b, signed_b = self.emit_rhs(module_idx, value.operands[2]) + if len(operand_s) != 1: + operand_s = self.emit_operator(module_idx, 'b', operand_s, + src_loc=value.src_loc) + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + result = self.emit_operator(module_idx, 'm', operand_s, operand_a, operand_b, + src_loc=value.src_loc) + else: + assert False # :nocov: + elif isinstance(value, _ast.Slice): + inner, _signed = self.emit_rhs(module_idx, value.value) + result = _nir.Value(inner[value.start:value.stop]) + signed = False + elif isinstance(value, _ast.Part): + inner, signed = self.emit_rhs(module_idx, value.value) + offset, _signed = self.emit_rhs(module_idx, value.offset) + cell = _nir.Part(module_idx, value=inner, value_signed=signed, width=value.width, + stride=value.stride, offset=offset, src_loc=value.src_loc) + result = self.netlist.add_value_cell(value.width, cell) + signed = False + elif isinstance(value, _ast.ArrayProxy): + elems = [self.emit_rhs(module_idx, elem) for elem in value.elems] + width = 0 + signed = False + for elem, elem_signed in elems: + if elem_signed: + if not signed: + width += 1 + signed = True + width = max(width, len(elem)) + elif signed: + width = max(width, len(elem) + 1) + else: + width = max(width, len(elem)) + elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems) + index, _signed = self.emit_rhs(module_idx, value.index) + cell = _nir.ArrayMux(module_idx, width=width, elems=elems, index=index, + src_loc=value.src_loc) + result = self.netlist.add_value_cell(width, cell) + elif isinstance(value, _ast.Cat): + nets = [] + for val in value.parts: + inner, _signed = self.emit_rhs(module_idx, val) + for net in inner: + nets.append(net) + result = _nir.Value(nets) + signed = False + elif isinstance(value, _ast.AnyValue): + result = self.netlist.add_value_cell(value.width, + _nir.AnyValue(module_idx, kind=value.kind.value, width=value.width, + src_loc=value.src_loc)) + signed = value.signed + elif isinstance(value, _ast.Initial): + result = self.netlist.add_value_cell(1, _nir.Initial(module_idx, src_loc=value.src_loc)) + signed = False + else: + assert False # :nocov: + assert value.shape().width == len(result), \ + f"Value {value!r} with shape {value.shape()!r} does not match " \ + f"result with width {len(result)}" + # Add the value itself to the cache to make sure `id(value)` remains allocated and pointing + # at `value`. This would be a weakref.WeakKeyDictionary if `value` was hashable. + self.rhs_cache[id(value)] = result, signed, value + return (result, signed) + + def connect(self, lhs: _nir.Value, rhs: _nir.Value, *, src_loc): + assert len(lhs) == len(rhs) + for left, right in zip(lhs, rhs): + if left in self.netlist.connections: + signal, bit = self.late_net_to_signal[left] + other_src_loc = self.connect_src_loc[left] + raise _ir.DriverConflict(f"Bit {bit} of signal {signal!r} has multiple drivers: " + f"{other_src_loc} and {src_loc}") + self.netlist.connections[left] = right + self.connect_src_loc[left] = src_loc + + def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str, + stmt: _ast.Statement, cond: _nir.Net): + if domain == "comb": + cd: _cd.ClockDomain | None = None + else: + cd = fragment.domains[domain] + if isinstance(stmt, _ast.Assign): + if isinstance(stmt.lhs, _ast.Signal): + signal = stmt.lhs + start = 0 + width = signal.width + elif isinstance(stmt.lhs, _ast.Slice): + signal = stmt.lhs.value + start = stmt.lhs.start + width = stmt.lhs.stop - stmt.lhs.start + else: + assert False # :nocov: + assert isinstance(signal, _ast.Signal) + if signal in self.drivers: + driver = self.drivers[signal] + if driver.domain is not cd: + raise _ir.DriverConflict( + f"Signal {signal} driven from domain {cd} at {stmt.src_loc} and domain " + f"{driver.domain} at {driver.src_loc}") + if driver.module_idx != module_idx: + mod_name = ".".join(self.netlist.modules[module_idx].name or ("",)) + other_mod_name = \ + ".".join(self.netlist.modules[driver.module_idx].name or ("",)) + raise _ir.DriverConflict( + f"Signal {signal} driven from module {mod_name} at {stmt.src_loc} and " + f"module {other_mod_name} at {driver.src_loc}") + else: + driver = NetlistDriver(module_idx, signal, domain=cd, src_loc=stmt.src_loc) + self.drivers[signal] = driver + rhs, signed = self.emit_rhs(module_idx, stmt.rhs) + if len(rhs) > width: + rhs = _nir.Value(rhs[:width]) + if len(rhs) < width: + rhs = self.extend(rhs, signed, width) + driver.assignments.append(_nir.Assignment(cond=cond, start=start, value=rhs, + src_loc=stmt.src_loc)) + elif isinstance(stmt, _ast.Property): + test, _signed = self.emit_rhs(module_idx, stmt.test) + if len(test) != 1: + test = self.emit_operator(module_idx, 'b', test, src_loc=stmt.src_loc) + test, = test + en_cell = _nir.AssignmentList(module_idx, + default=_nir.Value.zeros(), + assignments=[ + _nir.Assignment(cond=cond, start=0, value=_nir.Value.ones(), + src_loc=stmt.src_loc) + ], + src_loc=stmt.src_loc) + cond, = self.netlist.add_value_cell(1, en_cell) + if cd is None: + cell = _nir.AsyncProperty(module_idx, kind=stmt.kind.value, test=test, en=cond, + name=stmt.name, src_loc=stmt.src_loc) + else: + clk, = self.emit_signal(cd.clk) + cell = _nir.SyncProperty(module_idx, kind=stmt.kind.value, test=test, en=cond, + clk=clk, clk_edge=cd.clk_edge, name=stmt.name, + src_loc=stmt.src_loc) + self.netlist.add_cell(cell) + elif isinstance(stmt, _ast.Switch): + test, _signed = self.emit_rhs(module_idx, stmt.test) + conds = [] + for patterns in stmt.cases: + if patterns: + for pattern in patterns: + assert len(pattern) == len(test) + cell = _nir.Matches(module_idx, value=test, patterns=patterns, + src_loc=stmt.case_src_locs.get(patterns)) + net, = self.netlist.add_value_cell(1, cell) + conds.append(net) + else: + conds.append(_nir.Net.from_const(1)) + cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds), + src_loc=stmt.src_loc) + conds = self.netlist.add_value_cell(len(conds), cell) + for subcond, substmts in zip(conds, stmt.cases.values()): + for substmt in substmts: + self.emit_stmt(module_idx, fragment, domain, substmt, subcond) + else: + assert False # :nocov: + + def emit_tribuf(self, module_idx: int, instance: _ir.Instance): + pad = self.emit_lhs(instance.named_ports["Y"][0]) + o, _signed = self.emit_rhs(module_idx, instance.named_ports["A"][0]) + (oe,), _signed = self.emit_rhs(module_idx, instance.named_ports["EN"][0]) + assert len(pad) == len(o) + cell = _nir.IOBuffer(module_idx, pad=pad, o=o, oe=oe, src_loc=instance.src_loc) + self.netlist.add_cell(cell) + + def emit_memory(self, module_idx: int, fragment: '_mem.MemoryInstance', name: str): + cell = _nir.Memory(module_idx, + width=fragment._width, + depth=fragment._depth, + init=fragment._init, + name=name, + attributes=fragment._attrs, + src_loc=fragment._src_loc, + ) + return self.netlist.add_cell(cell) + + def emit_write_port(self, module_idx: int, fragment: '_mem.MemoryInstance', + port: '_mem.MemoryInstance._WritePort', memory: int): + data, _signed = self.emit_rhs(module_idx, port._data) + addr, _signed = self.emit_rhs(module_idx, port._addr) + en, _signed = self.emit_rhs(module_idx, port._en) + en = _nir.Value([en[bit // port._granularity] for bit in range(len(port._data))]) + cd = fragment.domains[port._domain] + clk, = self.emit_signal(cd.clk) + cell = _nir.SyncWritePort(module_idx, + memory=memory, + data=data, + addr=addr, + en=en, + clk=clk, + clk_edge=cd.clk_edge, + src_loc=port._data.src_loc, + ) + return self.netlist.add_cell(cell) + + def emit_read_port(self, module_idx: int, fragment: '_mem.MemoryInstance', + port: '_mem.MemoryInstance._ReadPort', memory: int, + write_ports: 'list[int]'): + addr, _signed = self.emit_rhs(module_idx, port._addr) + if port._domain == "comb": + cell = _nir.AsyncReadPort(module_idx, + memory=memory, + width=len(port._data), + addr=addr, + src_loc=port._data.src_loc, + ) + else: + (en,), _signed = self.emit_rhs(module_idx, port._en) + cd = fragment.domains[port._domain] + clk, = self.emit_signal(cd.clk) + cell = _nir.SyncReadPort(module_idx, + memory=memory, + width=len(port._data), + addr=addr, + en=en, + clk=clk, + clk_edge=cd.clk_edge, + transparent_for=tuple(write_ports[idx] for idx in port._transparency), + src_loc=port._data.src_loc, + ) + data = self.netlist.add_value_cell(len(port._data), cell) + self.connect(self.emit_lhs(port._data), data, src_loc=port._data.src_loc) + + def emit_instance(self, module_idx: int, instance: _ir.Instance, name: str): + ports_i = {} + ports_o = {} + ports_io = {} + outputs = [] + next_output_bit = 0 + for port_name, (port_conn, dir) in instance.named_ports.items(): + if dir == 'i': + ports_i[port_name], _signed = self.emit_rhs(module_idx, port_conn) + elif dir == 'o': + port_conn = self.emit_lhs(port_conn) + ports_o[port_name] = (next_output_bit, len(port_conn)) + outputs.append((next_output_bit, port_conn)) + next_output_bit += len(port_conn) + elif dir == 'io': + ports_io[port_name] = self.emit_lhs(port_conn) + else: + assert False # :nocov: + cell = _nir.Instance(module_idx, + type=instance.type, + name=name, + parameters=instance.parameters, + attributes=instance.attrs, + ports_i=ports_i, + ports_o=ports_o, + ports_io=ports_io, + src_loc=instance.src_loc, + ) + output_nets = self.netlist.add_value_cell(width=next_output_bit, cell=cell) + for start_bit, port_conn in outputs: + self.connect(port_conn, _nir.Value(output_nets[start_bit:start_bit + len(port_conn)]), + src_loc=instance.src_loc) + + def emit_top_ports(self, fragment: _ir.Fragment, signal_names: _ast.SignalDict): + next_input_bit = 2 # 0 and 1 are reserved for constants + top = self.netlist.top + for signal, dir in fragment.ports.items(): + assert signal not in self.netlist.signals + name = signal_names[signal] + if dir == 'i': + top.ports_i[name] = (next_input_bit, signal.width) + nets = _nir.Value( + _nir.Net.from_cell(0, bit) + for bit in range(next_input_bit, next_input_bit + signal.width) + ) + next_input_bit += signal.width + self.netlist.signals[signal] = nets + elif dir == 'o': + top.ports_o[name] = self.emit_signal(signal) + elif dir == 'io': + top.ports_io[name] = (next_input_bit, signal.width) + nets = _nir.Value( + _nir.Net.from_cell(0, bit) + for bit in range(next_input_bit, next_input_bit + signal.width) + ) + next_input_bit += signal.width + self.netlist.signals[signal] = nets + + def emit_drivers(self): + for driver in self.drivers.values(): + value = driver.emit_value(self) + if driver.domain is not None: + clk, = self.emit_signal(driver.domain.clk) + if driver.domain.rst is not None and driver.domain.async_reset: + arst, = self.emit_signal(driver.domain.rst) + else: + arst = _nir.Net.from_const(0) + cell = _nir.FlipFlop(driver.module_idx, + data=value, + init=driver.signal.reset, + clk=clk, + clk_edge=driver.domain.clk_edge, + arst=arst, + attributes=driver.signal.attrs, + src_loc=driver.signal.src_loc, + ) + value = self.netlist.add_value_cell(len(value), cell) + if driver.assignments: + src_loc = driver.assignments[0].src_loc + else: + src_loc = driver.signal.src_loc + self.connect(self.emit_signal(driver.signal), value, src_loc=src_loc) + + # Connect all undriven signal bits to their reset values. This can only happen for entirely + # undriven signals, or signals that are partially driven by instances. + for signal, value in self.netlist.signals.items(): + for bit, net in enumerate(value): + if net.is_late and net not in self.netlist.connections: + self.netlist.connections[net] = _nir.Net.from_const((signal.reset >> bit) & 1) + + def emit_fragment(self, fragment: _ir.Fragment, parent_module_idx: 'int | None'): + from . import _mem + + fragment_name = self.fragment_names[fragment] + if isinstance(fragment, _ir.Instance): + assert parent_module_idx is not None + if fragment.type == "$tribuf": + self.emit_tribuf(parent_module_idx, fragment) + else: + self.emit_instance(parent_module_idx, fragment, name=fragment_name[-1]) + elif isinstance(fragment, _mem.MemoryInstance): + assert parent_module_idx is not None + memory = self.emit_memory(parent_module_idx, fragment, name=fragment_name[-1]) + write_ports = [] + for port in fragment._write_ports: + write_ports.append(self.emit_write_port(parent_module_idx, fragment, port, memory)) + for port in fragment._read_ports: + self.emit_read_port(parent_module_idx, fragment, port, memory, write_ports) + elif type(fragment) is _ir.Fragment: + module_idx = self.netlist.add_module(parent_module_idx, fragment_name) + signal_names = fragment._assign_names_to_signals() + self.netlist.modules[module_idx].signal_names = signal_names + if parent_module_idx is None: + self.emit_top_ports(fragment, signal_names) + for signal in signal_names: + self.emit_signal(signal) + for domain, stmts in fragment.statements.items(): + for stmt in stmts: + self.emit_stmt(module_idx, fragment, domain, stmt, _nir.Net.from_const(1)) + for subfragment, _name in fragment.subfragments: + self.emit_fragment(subfragment, module_idx) + if parent_module_idx is None: + self.emit_drivers() + else: + assert False # :nocov: + + +def _emit_netlist(netlist: _nir.Netlist, fragment, hierarchy): + fragment_names = fragment._assign_names_to_fragments(hierarchy) + NetlistEmitter(netlist, fragment_names).emit_fragment(fragment, None) + + +def _compute_net_flows(netlist: _nir.Netlist): + # Computes the net flows for all modules of the netlist. + # + # The rules for net flows are as follows: + # + # - the modules that have a given net in their net_flow form a subtree of the hierarchy + # - INTERNAL is used in the root of the subtree and nowhere else + # - OUTPUT is used for modules that contain the definition of the net, or are on the + # path from the definition to the root + # - remaining modules have a flow of INPUT (unless the net is a top-level inout port, + # in which case it is INOUT) + # + # In other words, the tree looks something like this: + # + # - [no flow] <<< top + # - [no flow] + # - INTERNAL + # - INPUT << use + # - [no flow] + # - INPUT + # - INPUT << use + # - OUTPUT + # - INPUT << use + # - [no flow] + # - OUTPUT << def + # - INPUT + # - INPUT + # - [no flow] + # - [no flow] + # - [no flow] + # + # This function doesn't assign the INOUT flow — that is corrected later, in compute_ports. + lca = {} + + # Initialize by marking the definition point of every net. + for cell_idx, cell in enumerate(netlist.cells): + for net in cell.output_nets(cell_idx): + lca[net] = cell.module_idx + netlist.modules[cell.module_idx].net_flow[net] = _nir.ModuleNetFlow.INTERNAL + + # Marks a use of a net within a given module, and adjusts its netflows in all modules + # as required. + def use_net(net, use_module): + if net.is_const: + return + # If the net is already present in the current module, we're done. + if net in netlist.modules[use_module].net_flow: + return + modules = netlist.modules + # Otherwise, we need to route the net through the hierarchy from def_module + # to use_module. We do that by treating use_module and def_module as pointers + # and moving them up the hierarchy until they meet at the new LCA. + def_module = lca[net] + # While def_module deeper than use_module, go up with def_module. + while len(modules[def_module].name) > len(modules[use_module].name): + modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT + def_module = modules[def_module].parent + # While use_module deeper than def_module, go up with use_module. + # If use_module is below def_module in the hierarchy, we may hit + # another module which already uses this net before hitting def_module, + # so check for this case. + while len(modules[def_module].name) < len(modules[use_module].name): + if net in modules[use_module].net_flow: + return + modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT + use_module = modules[use_module].parent + # Now both pointers should be at the same depth within the hierarchy. + assert len(modules[def_module].name) == len(modules[use_module].name) + # Move both pointers up until they meet. + while def_module != use_module: + modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT + def_module = modules[def_module].parent + modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT + use_module = modules[use_module].parent + assert len(modules[def_module].name) == len(modules[use_module].name) + # And mark the new LCA. + modules[def_module].net_flow[net] = _nir.ModuleNetFlow.INTERNAL + lca[net] = def_module + + # Now mark all uses and flesh out the structure. + for cell in netlist.cells: + for net in cell.input_nets(): + use_net(net, cell.module_idx) + # TODO: ? + for module_idx, module in enumerate(netlist.modules): + for signal in module.signal_names: + for net in netlist.signals[signal]: + use_net(net, module_idx) + + +def _compute_ports(netlist: _nir.Netlist): + # Compute the indexes at which the outputs of a cell should be split to create a distinct port. + # These indexes are stored here as nets. + port_starts = set() + for start, _ in netlist.top.ports_i.values(): + port_starts.add(_nir.Net.from_cell(0, start)) + for start, width in netlist.top.ports_io.values(): + port_starts.add(_nir.Net.from_cell(0, start)) + for cell_idx, cell in enumerate(netlist.cells): + if isinstance(cell, _nir.Instance): + for start, _ in cell.ports_o.values(): + port_starts.add(_nir.Net.from_cell(cell_idx, start)) + + # Compute the set of all inout nets. Currently, a net has inout flow iff it is connected to + # a toplevel inout port. + inouts = set() + for start, width in netlist.top.ports_io.values(): + for idx in range(start, start + width): + inouts.add(_nir.Net.from_cell(0, idx)) + + for module in netlist.modules: + # Collect preferred names for ports. If a port exactly matches a signal, we reuse + # the signal name for the port. Otherwise, we synthesize a private name. + name_table = {} + for signal, name in module.signal_names.items(): + value = netlist.signals[signal] + if value not in name_table and not name.startswith('$'): + name_table[value] = name + + # Gather together "adjacent" nets with the same flow into ports. + visited = set() + for net in sorted(module.net_flow): + flow = module.net_flow[net] + if flow == _nir.ModuleNetFlow.INTERNAL: + continue + if flow == _nir.ModuleNetFlow.INPUT and net in inouts: + flow = module.net_flow[net] = _nir.ModuleNetFlow.INOUT + if net in visited: + continue + # We found a net that needs a port. Keep joining the next nets output by the same + # cell into the same port, if applicable, but stop at instance/top port boundaries. + nets = [net] + while True: + succ = _nir.Net.from_cell(net.cell, net.bit + 1) + if succ in port_starts: + break + if succ not in module.net_flow: + break + if module.net_flow[succ] != module.net_flow[net]: + break + net = succ + nets.append(net) + value = _nir.Value(nets) + # Joined as many nets as we could, now name and add the port. + if value in name_table: + name = name_table[value] + else: + name = f"port${value[0].cell}${value[0].bit}" + module.ports[name] = (value, flow) + visited.update(value) + + # The 0th cell and the 0th module correspond to the toplevel. Transfer the net flows from + # the toplevel cell (used for data flow) to the toplevel module (used to split netlist into + # modules in the backends). + top_module = netlist.modules[0] + for name, (start, width) in netlist.top.ports_i.items(): + top_module.ports[name] = ( + _nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)), + _nir.ModuleNetFlow.INPUT + ) + for name, (start, width) in netlist.top.ports_io.items(): + top_module.ports[name] = ( + _nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)), + _nir.ModuleNetFlow.INOUT + ) + for name, value in netlist.top.ports_o.items(): + top_module.ports[name] = (value, _nir.ModuleNetFlow.OUTPUT) + + +def build_netlist(fragment, *, name="top"): + from ._xfrm import AssignmentLegalizer + + fragment = AssignmentLegalizer()(fragment) + netlist = _nir.Netlist() + _emit_netlist(netlist, fragment, hierarchy=(name,)) + netlist.resolve_all_nets() + _compute_net_flows(netlist) + _compute_ports(netlist) + return netlist diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py new file mode 100644 index 0000000..f3c0c0f --- /dev/null +++ b/amaranth/hdl/_nir.py @@ -0,0 +1,1003 @@ +from typing import Iterable +import enum + +from ._ast import SignalDict + + +__all__ = [ + # Netlist core + "Net", "Value", "Netlist", "ModuleNetFlow", "Module", "Cell", "Top", + # Computation cells + "Operator", "Part", "ArrayMux", + # Decision tree cells + "Matches", "PriorityMatch", "Assignment", "AssignmentList", + # Storage cells + "FlipFlop", "Memory", "SyncWritePort", "AsyncReadPort", "SyncReadPort", + # Formal verification cells + "Initial", "AnyValue", "AsyncProperty", "SyncProperty", + # Foreign interface cells + "Instance", "IOBuffer", +] + + +class Net(int): + @classmethod + def from_cell(cls, cell: int, bit: int): + assert bit in range(1 << 16) + assert cell >= 0 + if cell == 0: + assert bit >= 2 + return cls((cell << 16) | bit) + + @classmethod + def from_const(cls, val: int): + assert val in (0, 1) + return cls(val) + + @classmethod + def from_late(cls, val: int): + assert val < 0 + return cls(val) + + @property + def is_const(self): + return self in (0, 1) + + @property + def const(self): + assert self in (0, 1) + return int(self) + + @property + def is_late(self): + return self < 0 + + @property + def is_cell(self): + return self >= 2 + + @property + def cell(self): + assert self >= 2 + return self >> 16 + + @property + def bit(self): + assert self >= 2 + return self & 0xffff + + @classmethod + def ensure(cls, value: 'Net'): + assert isinstance(value, cls) + return value + + +class Value(tuple): + def __new__(cls, nets: 'Net | Iterable[Net]' = ()): + if isinstance(nets, Net): + return super().__new__(cls, (nets,)) + return super().__new__(cls, (Net.ensure(net) for net in nets)) + + @classmethod + def zeros(cls, digits=1): + return cls(Net.from_const(0) for _ in range(digits)) + + @classmethod + def ones(cls, digits=1): + return cls(Net.from_const(1) for _ in range(digits)) + + +class Netlist: + """A fine netlist. Consists of: + + - a flat array of cells + - a dictionary of connections for late-bound nets + - a map of hierarchical names to nets + - a map of signals to nets + + The nets are virtual: a list of nets is not materialized anywhere in the netlist. + A net is a single bit wide and represented as a single int. The int is encoded as follows: + + - A negative number means a late-bound net. The net should be looked up in the ``connections`` + dictionary to find its driver. + - Non-negative numbers are cell outputs, and are split into bitfields as follows: + + - bits 0-15: output bit index within a cell (exact meaning is cell-dependent) + - bits 16-...: index of cell in ``netlist.cells`` + + Cell 0 is always ``Top``. The first two output bits of ``Top`` are considered to be constants + ``0`` and ``1``, which effectively means that net encoded as ``0`` is always a constant ``0`` and + net encoded as ``1`` is always a constant ``1``. + + Multi-bit values are represented as tuples of int. + + Attributes + ---------- + + cells : list of ``Cell`` + connections : dict of (negative) int to int + signals : dict of Signal to ``Value`` + """ + def __init__(self): + self.modules: list[Module] = [] + self.cells: list[Cell] = [Top()] + self.connections: dict[Net, Net] = {} + self.signals = SignalDict() + self.last_late_net = 0 + + def resolve_net(self, net: Net): + assert isinstance(net, Net) + while net.is_late: + net = self.connections[net] + return net + + def resolve_value(self, value: Value): + return Value(self.resolve_net(net) for net in value) + + def resolve_all_nets(self): + for cell in self.cells: + cell.resolve_nets(self) + for sig in self.signals: + self.signals[sig] = self.resolve_value(self.signals[sig]) + + def __str__(self): + def net_to_str(net): + net = self.resolve_net(net) + if net.is_const: + return f"{net.const}" + return f"{net.cell}.{net.bit}" + def val_to_str(val): + return "{" + " ".join(net_to_str(x) for x in val) + "}" + result = [] + for module_idx, module in enumerate(self.modules): + result.append(f"module {module_idx} [parent {module.parent}]: {' '.join(module.name)}") + if module.submodules: + result.append(f" submodules {' '.join(str(x) for x in module.submodules)} ") + result.append(f" cells {' '.join(str(x) for x in module.cells)} ") + for name, (val, flow) in module.ports.items(): + result.append(f" port {name} {val_to_str(val)}: {flow}") + for cell_idx, cell in enumerate(self.cells): + result.append(f"cell {cell_idx} [module {cell.module_idx}]: ") + if isinstance(cell, Top): + result.append("top") + for name, val in cell.ports_o.items(): + result.append(f" output {name}: {val_to_str(val)}") + for name, (start, num) in cell.ports_i.items(): + result.append(f" input {name}: 0.{start}..0.{start+num-1}") + for name, (start, num) in cell.ports_io.items(): + result.append(f" inout {name}: 0.{start}..0.{start+num-1}") + elif isinstance(cell, Matches): + result.append(f"matches {val_to_str(cell.value)}, {' | '.join(cell.patterns)}") + elif isinstance(cell, PriorityMatch): + result.append(f"priority_match {net_to_str(cell.en)}, {val_to_str(cell.inputs)}") + elif isinstance(cell, AssignmentList): + result.append(f"list {val_to_str(cell.default)}") + for assign in cell.assignments: + result.append(f" if {net_to_str(assign.cond)} start {assign.start} <- {val_to_str(assign.value)}") + elif isinstance(cell, Operator): + inputs = ", ".join(val_to_str(input) for input in cell.inputs) + result.append(f"{cell.operator} {inputs}") + elif isinstance(cell, Part): + result.append(f"part {val_to_str(cell.value)}, {val_to_str(cell.offset)}, {cell.width}, {cell.stride}, {cell.value_signed}") + elif isinstance(cell, ArrayMux): + result.append(f"array {cell.width}, {val_to_str(cell.index)}, {', '.join(val_to_str(elem) for elem in cell.elems)}") + elif isinstance(cell, FlipFlop): + result.append(f"ff {val_to_str(cell.data)} {cell.init} @{cell.clk_edge}edge {net_to_str(cell.clk)} {net_to_str(cell.arst)}") + for attr_name, attr_value in cell.attributes.items(): + result.append(f" attribute {attr_name} {attr_value}") + elif isinstance(cell, Memory): + result.append(f"memory {cell.name} {cell.width} {cell.depth} {cell.init}") + for attr_name, attr_value in cell.attributes.items(): + result.append(f" attribute {attr_name} {attr_value}") + elif isinstance(cell, SyncWritePort): + result.append(f"wrport {cell.memory} {val_to_str(cell.data)} {val_to_str(cell.addr)} {val_to_str(cell.en)} @{cell.clk_edge}edge {net_to_str(cell.clk)}") + elif isinstance(cell, AsyncReadPort): + result.append(f"rdport {cell.memory} {cell.width} {val_to_str(cell.addr)}") + elif isinstance(cell, SyncReadPort): + result.append(f"rdport {cell.memory} {cell.width} {val_to_str(cell.addr)} {net_to_str(cell.en)} @{cell.clk_edge}edge {net_to_str(cell.clk)}") + for port in cell.transparent_for: + result.append(f" transparent {port}") + elif isinstance(cell, Initial): + result.append("initial") + elif isinstance(cell, AnyValue): + result.append("{cell.kind} {cell.width}") + elif isinstance(cell, AsyncProperty): + result.append(f"{cell.kind} {cell.name!r} {net_to_str(cell.test)} {net_to_str(cell.en)}") + elif isinstance(cell, SyncProperty): + result.append(f"{cell.kind} {cell.name!r} {net_to_str(cell.test)} {net_to_str(cell.en)} @{cell.clk_edge}edge {net_to_str(cell.clk)}") + elif isinstance(cell, Instance): + result.append("instance {cell.type} {cell.name}") + for attr_name, attr_value in cell.attributes.items(): + result.append(f" attribute {attr_name} {attr_value}") + for attr_name, attr_value in cell.parameters.items(): + result.append(f" parameter {attr_name} {attr_value}") + for attr_name, (start, num) in cell.ports_o.items(): + result.append(f" output {attr_name} {cell_idx}.{start}..{cell_idx}.{start+num-1}") + for attr_name, attr_value in cell.ports_i.items(): + result.append(f" input {attr_name} {val_to_str(attr_value)}") + for attr_name, attr_value in cell.ports_io.items(): + result.append(f" inout {attr_name} {val_to_str(attr_value)}") + elif isinstance(cell, IOBuffer): + result.append(f"iob {val_to_str(cell.pad)} {val_to_str(cell.o)} {net_to_str(cell.oe)}") + else: + assert False # :nocov: + return "\n".join(result) + + def add_module(self, parent, name: str): + module_idx = len(self.modules) + self.modules.append(Module(parent, name)) + if module_idx == 0: + self.modules[0].cells.append(0) + if parent is not None: + self.modules[parent].submodules.append(module_idx) + return module_idx + + def add_cell(self, cell): + idx = len(self.cells) + self.cells.append(cell) + self.modules[cell.module_idx].cells.append(idx) + return idx + + def add_value_cell(self, width: int, cell): + cell_idx = self.add_cell(cell) + return Value(Net.from_cell(cell_idx, bit) for bit in range(width)) + + def alloc_late_value(self, width: int): + self.last_late_net -= width + return Value(Net.from_late(self.last_late_net + bit) for bit in range(width)) + + @property + def top(self): + top = self.cells[0] + assert isinstance(top, Top) + return top + + +class ModuleNetFlow(enum.Enum): + """Describes how a given Net flows into or out of a Module. + + The net can also be none of these (not present in the dictionary at all), + when it is not present in the module at all. + """ + + #: The net is present in the module (used in the module or needs + #: to be routed through it between its submodules), but is not + #: present outside its subtree and thus is not a port of this module. + INTERNAL = "internal" + + #: The net is present in the module, and is not driven from + #: the module or any of its submodules. It is thus an input + #: port of this module. + INPUT = "input" + + #: The net is present in the module, is driven from the module or + #: one of its submodules, and is also used outside of its subtree. + #: It is thus an output port of this module. + OUTPUT = "output" + + #: The net is a special top-level inout net that is used within + #: this module or its submodules. It is an inout port of this module. + INOUT = "inout" + + +class Module: + """A module within the netlist. + + Attributes + ---------- + + parent: index of parent module, or ``None`` for top module + name: a tuple of str, hierarchical name of this module (top has empty tuple) + submodules: a list of nested module indices + signal_names: a SignalDict from Signal to str, signal names visible in this module + net_flow: a dict from Net to NetFlow, describes how a net is used within this module + ports: a dict from port name to (Value, NetFlow) pair + cells: a list of cell indices that belong to this module + """ + def __init__(self, parent, name): + self.parent = parent + self.name = name + self.submodules = [] + self.signal_names = SignalDict() + self.net_flow = {} + self.ports = {} + self.cells = [] + + +class Cell: + """A base class for all cell types. + + Attributes + ---------- + + src_loc: str + module: int, index of the module this cell belongs to (within Netlist.modules) + """ + + def __init__(self, module_idx: int, *, src_loc): + self.module_idx = module_idx + self.src_loc = src_loc + + def input_nets(self): + raise NotImplementedError + + def output_nets(self, self_idx: int): + raise NotImplementedError + + def resolve_nets(self, netlist: Netlist): + raise NotImplementedError + + +class Top(Cell): + """A special cell type representing top-level ports. Must be present in the netlist exactly + once, at index 0. + + Top-level outputs are stored as a dict of names to their assigned values. + + Top-level inputs and inouts are effectively the output of this cell. They are both stored + as a dict of names to a (start bit index, width) tuple. Output bit indices 0 and 1 are reserved + for constant nets, so the lowest bit index that can be assigned to a port is 2. + + Top-level inouts are special and can only be used by inout ports of instances, or in the pad + value of an ``IoBuf`` cell. + + Attributes + ---------- + + ports_o: dict of str to Value + ports_i: dict of str to (int, int) + ports_io: dict of str to (int, int) + """ + def __init__(self): + super().__init__(module_idx=0, src_loc=None) + + self.ports_o = {} + self.ports_i = {} + self.ports_io = {} + + def input_nets(self): + nets = set() + for value in self.ports_o.values(): + nets |= set(value) + return nets + + def output_nets(self, self_idx: int): + nets = set() + for start, width in self.ports_i.values(): + for bit in range(start, start + width): + nets.add(Net.from_cell(self_idx, bit)) + for start, width in self.ports_io.values(): + for bit in range(start, start + width): + nets.add(Net.from_cell(self_idx, bit)) + return nets + + def resolve_nets(self, netlist: Netlist): + for port in self.ports_o: + self.ports_o[port] = netlist.resolve_value(self.ports_o[port]) + + +class Operator(Cell): + """Roughly corresponds to ``hdl.ast.Operator``. + + The available operators are roughly the same as in AST, with some changes: + + - '<', '>', '<=', '>=', '//', '%', '>>' have signed and unsigned variants that are selected + by prepending 'u' or 's' to operator name + - 's', 'u', and unary '+' are redundant and do not exist + - many operators restrict input widths to be the same as output width, + and/or to be the same as each other + + The unary operators are: + + - '-', '~': like AST, input same width as output + - 'b', 'r|', 'r&', 'r^': like AST, 1-bit output + + The binary operators are: + + - '+', '-', '*', '&', '^', '|', 'u//', 's//', 'u%', 's%': like AST, both inputs same width as output + - '<<', 'u>>', 's>>': like AST, first input same width as output + - '==', '!=', 'u<', 's<', 'u>', 's>', 'u<=', 's<=', 'u>=', 's>=': like AST, both inputs need to have + the same width, 1-bit output + + The ternary operators are: + + - 'm': like AST, first input needs to have width of 1, second and third operand need to have the same + width as output + + Attributes + ---------- + + operator: str, symbol of the operator (from the above list) + inputs: tuple of Value + """ + + def __init__(self, module_idx, *, operator: str, inputs, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.operator = operator + self.inputs = tuple(Value(input) for input in inputs) + + @property + def width(self): + if self.operator in ('~', '-', '+', '*', '&', '^', '|', 'u//', 's//', 'u%', 's%', '<<', 'u>>', 's>>'): + return len(self.inputs[0]) + elif self.operator in ('b', 'r&', 'r^', 'r|', '==', '!=', 'u<', 's<', 'u>', 's>', 'u<=', 's<=', 'u>=', 's>='): + return 1 + elif self.operator == 'm': + return len(self.inputs[1]) + else: + assert False # :nocov: + + def input_nets(self): + nets = set() + for value in self.inputs: + nets |= set(value) + return nets + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.inputs = tuple(netlist.resolve_value(val) for val in self.inputs) + + +class Part(Cell): + """Corresponds to ``hdl.ast.Part``. + + Attributes + ---------- + + value: Value, the data input + value_signed: bool + offset: Value, the offset input + width: int + stride: int + """ + def __init__(self, module_idx, *, value, value_signed, offset, width, stride, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert type(width) is int + assert type(stride) is int + + self.value = Value(value) + self.value_signed = value_signed + self.offset = Value(offset) + self.width = width + self.stride = stride + + def input_nets(self): + return set(self.value) | set(self.offset) + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.value = netlist.resolve_value(self.value) + self.offset = netlist.resolve_value(self.offset) + + +class ArrayMux(Cell): + """Corresponds to ``hdl.ast.ArrayProxy``. All values in the ``elems`` array need to have + the same width as the output. + + Attributes + ---------- + + width: int (width of output and all inputs) + elems: tuple of Value + index: Value + """ + def __init__(self, module_idx, *, width, elems, index, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.width = width + self.elems = tuple(Value(val) for val in elems) + self.index = Value(index) + + def input_nets(self): + nets = set(self.index) + for value in self.elems: + nets |= set(value) + return nets + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.elems = tuple(netlist.resolve_value(val) for val in self.elems) + self.index = netlist.resolve_value(self.index) + + +class Matches(Cell): + """A combinatorial cell performing a comparison like ``Value.matches`` + (or, equivalently, a case condition). + + Attributes + ---------- + + value: Value + patterns: tuple of str, each str contains '0', '1', '-' + """ + def __init__(self, module_idx, *, value, patterns, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.value = Value(value) + self.patterns = tuple(patterns) + + def input_nets(self): + return set(self.value) + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, 0)} + + def resolve_nets(self, netlist: Netlist): + self.value = netlist.resolve_value(self.value) + + +class PriorityMatch(Cell): + """Used to represent a single switch on the control plane of processes. + + The output is the same length as ``inputs``. If ``en`` is ``0``, the output + is all-0. Otherwise, output keeps the lowest-numbered ``1`` bit in the input + (if any) and masks all other bits to ``0``. + + Note: the RTLIL backend requires all bits of ``inputs`` to be driven + by a ``Match`` cell within the same module. + + Attributes + ---------- + en: Net + inputs: Value + """ + def __init__(self, module_idx, *, en, inputs, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.en = Net.ensure(en) + self.inputs = Value(inputs) + + def input_nets(self): + return set(self.inputs) | {self.en} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.inputs))} + + def resolve_nets(self, netlist: Netlist): + self.en = netlist.resolve_net(self.en) + self.inputs = netlist.resolve_value(self.inputs) + + +class Assignment: + """A single assignment in an ``AssignmentList``. + + The assignment is executed iff ``cond`` is true. When the assignment + is executed, ``len(value)`` bits starting at position `offset` are set + to the value ``value``, and the remaining bits are unchanged. + Assignments to out-of-bounds bit indices are ignored. + + Attributes + ---------- + + cond: Net + start: int + value: Value + src_loc: str + """ + def __init__(self, *, cond, start, value, src_loc): + assert isinstance(start, int) + self.cond = Net.ensure(cond) + self.start = start + self.value = Value(value) + self.src_loc = src_loc + + def resolve_nets(self, netlist: Netlist): + self.cond = netlist.resolve_net(self.cond) + self.value = netlist.resolve_value(self.value) + + +class AssignmentList(Cell): + """Used to represent a single assigned signal on the data plane of processes. + + The output of this cell is determined by starting with the ``default`` value, + then executing each assignment in sequence. + + Note: the RTLIL backend requires all ``cond`` inputs of assignments to be driven + by a ``PriorityMatch`` cell within the same module. + + Attributes + ---------- + default: Value + assignments: tuple of ``Assignment`` + """ + def __init__(self, module_idx, *, default, assignments, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assignments = tuple(assignments) + for assign in assignments: + assert isinstance(assign, Assignment) + + self.default = Value(default) + self.assignments: tuple[Assignment, ...] = assignments + + def input_nets(self): + nets = set(self.default) + for assign in self.assignments: + nets.add(assign.cond) + nets |= set(assign.value) + return nets + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.default))} + + def resolve_nets(self, netlist: Netlist): + for assign in self.assignments: + assign.resolve_nets(netlist) + self.default = netlist.resolve_value(self.default) + + +class FlipFlop(Cell): + """A flip-flop. ``data`` is the data input. ``init`` is the initial and async reset value. + ``clk`` and ``clk_edge`` work as in a ``ClockDomain``. ``arst`` is the async reset signal, + or ``0`` if async reset is not used. + + Attributes + ---------- + + data: Value + init: int + clk: Net + clk_edge: str, either 'pos' or 'neg' + arst: Net + attributes: dict from str to int, Const, or str + """ + def __init__(self, module_idx, *, data, init, clk, clk_edge, arst, attributes, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert clk_edge in ('pos', 'neg') + assert type(init) is int + + self.data = Value(data) + self.init = init + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + self.arst = Net.ensure(arst) + self.attributes = attributes + + def input_nets(self): + return set(self.data) | {self.clk, self.arst} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.data))} + + def resolve_nets(self, netlist: Netlist): + self.data = netlist.resolve_value(self.data) + self.clk = netlist.resolve_net(self.clk) + self.arst = netlist.resolve_net(self.arst) + + +class Memory(Cell): + """Corresponds to ``Memory``. ``init`` must have length equal to ``depth``. + Read and write ports are separate cells. + + Attributes + ---------- + + width: int + depth: int + init: tuple of int + name: str + attributes: dict from str to int, Const, or str + """ + def __init__(self, module_idx, *, width, depth, init, name, attributes, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.width = width + self.depth = depth + self.init = tuple(init) + self.name = name + self.attributes = attributes + + def input_nets(self): + return set() + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + pass + + +class SyncWritePort(Cell): + """A single write port of a memory. This cell has no output. + + Attributes + ---------- + + memory: cell index of ``Memory`` + data: Value + addr: Value + en: Value + clk: Net + clk_edge: str, either 'pos' or 'neg' + """ + def __init__(self, module_idx, memory, *, data, addr, en, clk, clk_edge, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert clk_edge in ('pos', 'neg') + self.memory = memory + self.data = Value(data) + self.addr = Value(addr) + self.en = Value(en) + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + + def input_nets(self): + return set(self.data) | set(self.addr) | set(self.en) | {self.clk} + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + self.data = netlist.resolve_value(self.data) + self.addr = netlist.resolve_value(self.addr) + self.en = netlist.resolve_value(self.en) + self.clk = netlist.resolve_net(self.clk) + + +class AsyncReadPort(Cell): + """A single asynchronous read port of a memory. + + Attributes + ---------- + + memory: cell index of ``Memory`` + width: int + addr: Value + """ + def __init__(self, module_idx, memory, *, width, addr, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.memory = memory + self.width = width + self.addr = Value(addr) + + def input_nets(self): + return set(self.addr) + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.addr = netlist.resolve_value(self.addr) + +class SyncReadPort(Cell): + """A single synchronous read port of a memory. The cell output is the data port. + ``transparent_for`` is the set of write ports (identified by cell index) that this + read port is transparent with. + + Attributes + ---------- + + memory: cell index of ``Memory`` + width: int + addr: Value + en: Net + clk: Net + clk_edge: str, either 'pos' or 'neg' + transparent_for: tuple of int + """ + def __init__(self, module_idx, memory, *, width, addr, en, clk, clk_edge, transparent_for, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert clk_edge in ('pos', 'neg') + self.memory = memory + self.width = width + self.addr = Value(addr) + self.en = Net.ensure(en) + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + self.transparent_for = tuple(transparent_for) + + def input_nets(self): + return set(self.addr) | {self.en, self.clk} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.addr = netlist.resolve_value(self.addr) + self.en = netlist.resolve_net(self.en) + self.clk = netlist.resolve_net(self.clk) + +class Initial(Cell): + """Corresponds to ``Initial`` value.""" + + def input_nets(self): + return set() + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, 0)} + + def resolve_nets(self, netlist: Netlist): + pass + + +class AnyValue(Cell): + """Corresponds to ``AnyConst`` or ``AnySeq``. ``kind`` must be either ``'anyconst'`` + or ``'anyseq'``. + + Attributes + ---------- + + kind: str, 'anyconst' or 'anyseq' + width: int + """ + def __init__(self, module_idx, *, kind, width, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert kind in ('anyconst', 'anyseq') + self.kind = kind + self.width = width + + def input_nets(self): + return set() + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + pass + + +class AsyncProperty(Cell): + """Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain._ + + Attributes + ---------- + + kind: str, either 'assert', 'assume', or 'cover' + test: Net + en: Net + name: str + """ + def __init__(self, module_idx, *, kind, test, en, name, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert kind in ('assert', 'assume', 'cover') + self.kind = kind + self.test = Net.ensure(test) + self.en = Net.ensure(en) + self.name = name + + def input_nets(self): + return {self.test, self.en} + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + self.test = netlist.resolve_net(self.test) + self.en = netlist.resolve_net(self.en) + + +class SyncProperty(Cell): + """Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain. + + Attributes + ---------- + + kind: str, either 'assert', 'assume', or 'cover' + test: Net + en: Net + clk: Net + clk_edge: str, either 'pos' or 'neg' + name: str + """ + + def __init__(self, module_idx, *, kind, test, en, clk, clk_edge, name, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert kind in ('assert', 'assume', 'cover') + assert clk_edge in ('pos', 'neg') + self.kind = kind + self.test = Net.ensure(test) + self.en = Net.ensure(en) + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + self.name = name + + def input_nets(self): + return {self.test, self.en, self.clk} + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + self.test = netlist.resolve_net(self.test) + self.en = netlist.resolve_net(self.en) + self.clk = netlist.resolve_net(self.clk) + + +class Instance(Cell): + """Corresponds to ``Instance``. ``type``, ``parameters`` and ``attributes`` work the same as in + ``Instance``. Input and inout ports are represented as a dict of port names to values. + Inout ports must be connected to nets corresponding to an IO port of the ``Top`` cell. + + Output ports are represented as a dict of port names to (start bit index, width) describing + their position in the virtual "output" of this cell. + + Attributes + ---------- + + type: str + name: str + parameters: dict of str to Const, int, or str + attributes: dict of str to Const, int, or str + ports_i: dict of str to Value + ports_o: dict of str to pair of int (index start, width) + ports_io: dict of str to Value + """ + + def __init__(self, module_idx, *, type, name, parameters, attributes, ports_i, ports_o, ports_io, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.type = type + self.name = name + self.parameters = parameters + self.attributes = attributes + self.ports_i = {name: Value(val) for name, val in ports_i.items()} + self.ports_o = ports_o + self.ports_io = {name: Value(val) for name, val in ports_io.items()} + + def input_nets(self): + nets = set() + for val in self.ports_i.values(): + nets |= set(val) + for val in self.ports_io.values(): + nets |= set(val) + return nets + + def output_nets(self, self_idx: int): + nets = set() + for start, width in self.ports_o.values(): + for bit in range(start, start + width): + nets.add(Net.from_cell(self_idx, bit)) + return nets + + def resolve_nets(self, netlist: Netlist): + for port in self.ports_i: + self.ports_i[port] = netlist.resolve_value(self.ports_i[port]) + for port in self.ports_io: + self.ports_io[port] = netlist.resolve_value(self.ports_io[port]) + + +class IOBuffer(Cell): + """An IO buffer cell. ``pad`` must be connected to nets corresponding to an IO port + of the ``Top`` cell. This cell does two things: + + - a tristate buffer is inserted driving ``pad`` based on ``o`` and ``oe`` nets (output buffer) + - the value of ``pad`` is sampled and made available as output of this cell (input buffer) + + Attributes + ---------- + + pad: Value + o: Value + oe: Net + """ + def __init__(self, module_idx, *, pad, o, oe, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.pad = Value(pad) + self.o = Value(o) + self.oe = Net.ensure(oe) + + def input_nets(self): + return set(self.pad) | set(self.o) | {self.oe} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.pad))} + + def resolve_nets(self, netlist: Netlist): + self.pad = netlist.resolve_value(self.pad) + self.o = netlist.resolve_value(self.o) + self.oe = netlist.resolve_net(self.oe) diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index a8cf574..ec3c057 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -16,7 +16,6 @@ __all__ = ["ValueVisitor", "ValueTransformer", "FragmentTransformer", "TransformedElaboratable", "DomainCollector", "DomainRenamer", "DomainLowerer", - "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", "ResetInserter", "EnableInserter", "AssignmentLegalizer"] @@ -195,7 +194,7 @@ class StatementTransformer(StatementVisitor): return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs)) def on_Property(self, stmt): - return Property(stmt.kind, self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) + return Property(stmt.kind, self.on_value(stmt.test), name=stmt.name) def on_Switch(self, stmt): cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) @@ -533,97 +532,6 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer) return new_fragment -class SwitchCleaner(StatementVisitor): - def on_ignore(self, stmt): - return stmt - - on_Assign = on_ignore - on_Property = on_ignore - - def on_Switch(self, stmt): - cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) - if any(len(s) for s in cases.values()): - return Switch(stmt.test, cases) - - def on_statements(self, stmts): - stmts = flatten(self.on_statement(stmt) for stmt in stmts) - return _StatementList(stmt for stmt in stmts if stmt is not None) - - -class LHSGroupAnalyzer(StatementVisitor): - def __init__(self): - self.signals = SignalDict() - self.unions = OrderedDict() - - def find(self, signal): - if signal not in self.signals: - self.signals[signal] = len(self.signals) - group = self.signals[signal] - while group in self.unions: - group = self.unions[group] - self.signals[signal] = group - return group - - def unify(self, root, *leaves): - root_group = self.find(root) - for leaf in leaves: - leaf_group = self.find(leaf) - if root_group == leaf_group: - continue - self.unions[leaf_group] = root_group - - def groups(self): - groups = OrderedDict() - for signal in self.signals: - group = self.find(signal) - if group not in groups: - groups[group] = SignalSet() - groups[group].add(signal) - return groups - - def on_Assign(self, stmt): - lhs_signals = stmt._lhs_signals() - if lhs_signals: - self.unify(*stmt._lhs_signals()) - - def on_Property(self, stmt): - lhs_signals = stmt._lhs_signals() - if lhs_signals: - self.unify(*stmt._lhs_signals()) - - def on_Switch(self, stmt): - for case_stmts in stmt.cases.values(): - self.on_statements(case_stmts) - - def on_statements(self, stmts): - assert not isinstance(stmts, str) - for stmt in stmts: - self.on_statement(stmt) - - def __call__(self, stmts): - self.on_statements(stmts) - return self.groups() - - -class LHSGroupFilter(SwitchCleaner): - def __init__(self, signals): - self.signals = signals - - def on_Assign(self, stmt): - # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together - # on LHS are a part of the same group, so it is sufficient to check any of them. - lhs_signals = stmt.lhs._lhs_signals() - if lhs_signals: - any_lhs_signal = next(iter(lhs_signals)) - if any_lhs_signal in self.signals: - return stmt - - def on_Property(self, stmt): - any_lhs_signal = next(iter(stmt._lhs_signals())) - if any_lhs_signal in self.signals: - return stmt - - class _ControlInserter(FragmentTransformer): def __init__(self, controls): self.src_loc = None @@ -655,10 +563,23 @@ class ResetInserter(_ControlInserter): fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc)) +class _PropertyEnableInserter(StatementTransformer): + def __init__(self, en): + self.en = en + + def on_Property(self, stmt): + return Switch( + self.en, + {1: [stmt]}, + src_loc=stmt.src_loc, + ) + + class EnableInserter(_ControlInserter): def _insert_control(self, fragment, domain, signals): stmts = [s.eq(s) for s in signals] fragment.add_statements(domain, Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc)) + fragment.statements[domain] = _PropertyEnableInserter(self.controls[domain])(fragment.statements[domain]) def on_fragment(self, fragment): new_fragment = super().on_fragment(fragment) diff --git a/amaranth/hdl/xfrm.py b/amaranth/hdl/xfrm.py index 2b96db6..b490f00 100644 --- a/amaranth/hdl/xfrm.py +++ b/amaranth/hdl/xfrm.py @@ -9,7 +9,6 @@ __all__ = ["ValueVisitor", "ValueTransformer", "FragmentTransformer", "TransformedElaboratable", "DomainCollector", "DomainRenamer", "DomainLowerer", - "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", "ResetInserter", "EnableInserter"] diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 94c50ec..56cc2ec 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -244,136 +244,6 @@ class DomainLowererTestCase(FHDLTestCase): DomainLowerer()(f) -class SwitchCleanerTestCase(FHDLTestCase): - def test_clean(self): - a = Signal() - b = Signal() - c = Signal() - stmts = [ - Switch(a, { - 1: a.eq(0), - 0: [ - b.eq(1), - Switch(b, {1: [ - Switch(a|b, {}) - ]}) - ] - }) - ] - - self.assertRepr(SwitchCleaner()(stmts), """ - ( - (switch (sig a) - (case 1 - (eq (sig a) (const 1'd0))) - (case 0 - (eq (sig b) (const 1'd1))) - ) - ) - """) - - -class LHSGroupAnalyzerTestCase(FHDLTestCase): - def test_no_group_unrelated(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - b.eq(0), - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a,)), - SignalSet((b,)), - ]) - - def test_group_related(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - Cat(a, b).eq(0), - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a, b)), - ]) - - def test_no_loops(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - Cat(a, b).eq(0), - Cat(a, b).eq(0), - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a, b)), - ]) - - def test_switch(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - Switch(a, { - 1: b.eq(0), - }) - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a,)), - SignalSet((b,)), - ]) - - def test_lhs_empty(self): - stmts = [ - Cat().eq(0) - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - ]) - - -class LHSGroupFilterTestCase(FHDLTestCase): - def test_filter(self): - a = Signal() - b = Signal() - c = Signal() - stmts = [ - Switch(a, { - 1: a.eq(0), - 0: [ - b.eq(1), - Switch(b, {1: []}) - ] - }) - ] - - self.assertRepr(LHSGroupFilter(SignalSet((a,)))(stmts), """ - ( - (switch (sig a) - (case 1 - (eq (sig a) (const 1'd0))) - (case 0 ) - ) - ) - """) - - def test_lhs_empty(self): - stmts = [ - Cat().eq(0) - ] - - self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()") - - class ResetInserterTestCase(FHDLTestCase): def setUp(self): self.s1 = Signal()