diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 0879453..a442238 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -1,13 +1,9 @@ +from typing import Iterable import io -from collections import OrderedDict -from contextlib import contextmanager -import warnings -import re -from .._utils import flatten from ..utils import bits_for -from ..hdl import _ast, _ir, _mem, _xfrm, _repr from ..lib import wiring +from ..hdl import _repr, _ast, _ir, _nir __all__ = ["convert", "convert_fragment"] @@ -89,11 +85,6 @@ class _BufferedBuilder: self._buffer.write(fmt.format(*args, **kwargs)) -class _ProxiedBuilder: - def _append(self, *args, **kwargs): - self.rtlil._append(*args, **kwargs) - - class _AttrBuilder: def __init__(self, emit_src, *args, **kwargs): super().__init__(*args, **kwargs) @@ -183,10 +174,7 @@ class _ModuleBuilder(_AttrBuilder, _BufferedBuilder, _Namer): self._append(" parameter \\{} {}\n", param, _const(value)) for port, wire in ports.items(): - # By convention, Yosys ports named $\d+ are positional. Amaranth does not support - # connecting cell ports by position. See amaranth-lang/amaranth#733. - assert not re.match(r"^\$\d+$", port) - self._append(" connect {} {}\n", port, wire) + self._append(" connect \\{} {}\n", port, wire) self._append(" end\n") return name @@ -216,11 +204,14 @@ class _ProcessBuilder(_AttrBuilder, _BufferedBuilder): return _CaseBuilder(self, indent=2) -class _CaseBuilder(_ProxiedBuilder): +class _CaseBuilder: def __init__(self, rtlil, indent): self.rtlil = rtlil self.indent = indent + def _append(self, *args, **kwargs): + self.rtlil._append(*args, **kwargs) + def __enter__(self): return self @@ -234,7 +225,7 @@ class _CaseBuilder(_ProxiedBuilder): return _SwitchBuilder(self.rtlil, cond, attrs, src, self.indent) -class _SwitchBuilder(_AttrBuilder, _ProxiedBuilder): +class _SwitchBuilder(_AttrBuilder): def __init__(self, rtlil, cond, attrs, src, indent): super().__init__(emit_src=rtlil.emit_src) self.rtlil = rtlil @@ -243,6 +234,9 @@ class _SwitchBuilder(_AttrBuilder, _ProxiedBuilder): self.src = src self.indent = indent + def _append(self, *args, **kwargs): + self.rtlil._append(*args, **kwargs) + def __enter__(self): self._attributes(self.attrs, src=self.src, indent=self.indent) self._append("{}switch {}\n", " " * self.indent, self.cond) @@ -268,777 +262,738 @@ def _src(src_loc): return f"{file}:{line}" -class _LegalizeValue(Exception): - def __init__(self, value, branches, src_loc): - self.value = value - self.branches = list(branches) - self.src_loc = src_loc - - -class _ValueCompilerState: - def __init__(self, rtlil): - self.rtlil = rtlil - self.wires = _ast.SignalDict() - self.driven = _ast.SignalDict() - self.ports = _ast.SignalDict() - self.anys = _ast.ValueDict() - - self.expansions = _ast.ValueDict() - - def add_driven(self, signal, sync): - self.driven[signal] = sync - - def add_port(self, signal, kind): - assert kind in ("i", "o", "io") - if kind == "i": - kind = "input" - elif kind == "o": - kind = "output" - elif kind == "io": - kind = "inout" - self.ports[signal] = (len(self.ports), kind) - - def resolve(self, signal, prefix=None): - if len(signal) == 0: - return "{ }", "{ }" - - if signal in self.wires: - return self.wires[signal] - - if signal in self.ports: - port_id, port_kind = self.ports[signal] - else: - port_id = port_kind = None - if prefix is not None: - wire_name = f"{prefix}_{signal.name}" - else: - wire_name = signal.name - - is_sync_driven = signal in self.driven and self.driven[signal] - - attrs = dict(signal.attrs) - for repr in signal._value_repr: - if repr.path == () and isinstance(repr.format, _repr.FormatEnum): - enum = repr.format.enum - attrs["enum_base_type"] = enum.__name__ - for value in enum: - attrs["enum_value_{:0{}b}".format(value.value, signal.width)] = value.name - - # For every signal in the sync domain, assign \sig's initial value (using the \init reg - # attribute) to the reset value. - if is_sync_driven: - attrs["init"] = _ast.Const(signal.reset, signal.width) - - wire_curr = self.rtlil.wire(width=signal.width, name=wire_name, - port_id=port_id, port_kind=port_kind, - attrs=attrs, src=_src(signal.src_loc)) - if is_sync_driven: - wire_next = self.rtlil.wire(width=signal.width, name=wire_curr + "$next", - src=_src(signal.src_loc)) - else: - wire_next = None - self.wires[signal] = (wire_curr, wire_next) - - return wire_curr, wire_next - - def resolve_curr(self, signal, prefix=None): - wire_curr, wire_next = self.resolve(signal, prefix) - return wire_curr - - def expand(self, value): - if not self.expansions: - return value - return self.expansions.get(value, value) - - @contextmanager - def expand_to(self, value, expansion): - try: - assert value not in self.expansions - self.expansions[value] = expansion - yield - finally: - del self.expansions[value] - - -class _ValueCompiler(_xfrm.ValueVisitor): - def __init__(self, state): - self.s = state - - def on_unknown(self, value): - if value is None: - return None - else: - super().on_unknown(value) - - def on_ClockSignal(self, value): - raise NotImplementedError # :nocov: - - def on_ResetSignal(self, value): - raise NotImplementedError # :nocov: - - def on_Cat(self, value): - return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.parts]))) - - def _prepare_value_for_Slice(self, value): - raise NotImplementedError # :nocov: - - def on_Slice(self, value): - if value.start == 0 and value.stop == len(value.value): - return self(value.value) - - sigspec = self._prepare_value_for_Slice(value.value) - if value.start == value.stop: - return "{}" - elif value.start + 1 == value.stop: - return f"{sigspec} [{value.start}]" - else: - return f"{sigspec} [{value.stop - 1}:{value.start}]" - - def on_ArrayProxy(self, value): - index = self.s.expand(value.index) - if isinstance(index, _ast.Const): - if index.value < len(value.elems): - elem = value.elems[index.value] - else: - elem = value.elems[-1] - return self.match_shape(elem, value.shape()) - else: - max_index = 1 << len(value.index) - max_elem = len(value.elems) - raise _LegalizeValue(value.index, range(min(max_index, max_elem)), value.src_loc) - - -class _RHSValueCompiler(_ValueCompiler): - operator_map = { - (1, "~"): "$not", - (1, "-"): "$neg", - (1, "b"): "$reduce_bool", - (1, "r|"): "$reduce_or", - (1, "r&"): "$reduce_and", - (1, "r^"): "$reduce_xor", - (2, "+"): "$add", - (2, "-"): "$sub", - (2, "*"): "$mul", - (2, "//"): "$divfloor", - (2, "%"): "$modfloor", - (2, "**"): "$pow", - (2, "<<"): "$sshl", - (2, ">>"): "$sshr", - (2, "&"): "$and", - (2, "^"): "$xor", - (2, "|"): "$or", - (2, "=="): "$eq", - (2, "!="): "$ne", - (2, "<"): "$lt", - (2, "<="): "$le", - (2, ">"): "$gt", - (2, ">="): "$ge", - (3, "m"): "$mux", - } - - def on_value(self, value): - return super().on_value(self.s.expand(value)) - - def on_Const(self, value): - return _const(value) - - def on_AnyValue(self, value): - if value in self.s.anys: - return self.s.anys[value] - - res_shape = value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$" + value.kind.value, ports={ - "\\Y": res, - }, params={ - "WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - self.s.anys[value] = res - return res - - def on_Initial(self, value): - res = self.s.rtlil.wire(width=1, src=_src(value.src_loc)) - self.s.rtlil.cell("$initstate", ports={ - "\\Y": res, - }, src=_src(value.src_loc)) - return res - - def on_Signal(self, value): - wire_curr, wire_next = self.s.resolve(value) - return wire_curr - - def on_Operator_unary(self, value): - arg, = value.operands - if value.operator in ("u", "s"): - # These operators don't change the bit pattern, only its interpretation. - return self(arg) - - arg_shape, res_shape = arg.shape(), value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell(self.operator_map[(1, value.operator)], ports={ - "\\A": self(arg), - "\\Y": res, - }, params={ - "A_SIGNED": arg_shape.signed, - "A_WIDTH": arg_shape.width, - "Y_WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - return res - - def match_shape(self, value, new_shape): - if isinstance(value, _ast.Const): - return self(_ast.Const(value.value, new_shape)) - - value_shape = value.shape() - if new_shape.width <= value_shape.width: - return self(_ast.Slice(value, 0, new_shape.width)) - - res = self.s.rtlil.wire(width=new_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$pos", ports={ - "\\A": self(value), - "\\Y": res, - }, params={ - "A_SIGNED": value_shape.signed, - "A_WIDTH": value_shape.width, - "Y_WIDTH": new_shape.width, - }, src=_src(value.src_loc)) - return res - - def on_Operator_binary(self, value): - lhs, rhs = value.operands - lhs_shape, rhs_shape, res_shape = lhs.shape(), rhs.shape(), value.shape() - if lhs_shape.signed == rhs_shape.signed or value.operator in ("<<", ">>", "**"): - lhs_wire = self(lhs) - rhs_wire = self(rhs) - else: - lhs_shape = rhs_shape = _ast.signed(max(lhs_shape.width + rhs_shape.signed, - rhs_shape.width + lhs_shape.signed)) - lhs_wire = self.match_shape(lhs, lhs_shape) - rhs_wire = self.match_shape(rhs, rhs_shape) - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell(self.operator_map[(2, value.operator)], ports={ - "\\A": lhs_wire, - "\\B": rhs_wire, - "\\Y": res, - }, params={ - "A_SIGNED": lhs_shape.signed, - "A_WIDTH": lhs_shape.width, - "B_SIGNED": rhs_shape.signed, - "B_WIDTH": rhs_shape.width, - "Y_WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - if value.operator in ("//", "%"): - # RTLIL leaves division by zero undefined, but we require it to return zero. - divmod_res = res - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$mux", ports={ - "\\A": divmod_res, - "\\B": self(_ast.Const(0, res_shape)), - "\\S": self(rhs == 0), - "\\Y": res, - }, params={ - "WIDTH": res_shape.width - }, src=_src(value.src_loc)) - return res - - def on_Operator_mux(self, value): - sel, val1, val0 = value.operands - if len(sel) != 1: - sel = sel.bool() - res_shape = value.shape() - val1_wire = self.match_shape(val1, res_shape) - val0_wire = self.match_shape(val0, res_shape) - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$mux", ports={ - "\\A": val0_wire, - "\\B": val1_wire, - "\\S": self(sel), - "\\Y": res, - }, params={ - "WIDTH": res_shape.width - }, src=_src(value.src_loc)) - return res - - def on_Operator(self, value): - if len(value.operands) == 1: - return self.on_Operator_unary(value) - elif len(value.operands) == 2: - return self.on_Operator_binary(value) - elif len(value.operands) == 3: - assert value.operator == "m" - return self.on_Operator_mux(value) - else: - raise TypeError # :nocov: - - def _prepare_value_for_Slice(self, value): - if isinstance(value, (_ast.Signal, _ast.Slice, _ast.Cat)): - sigspec = self(value) - else: - sigspec = self.s.rtlil.wire(len(value), src=_src(value.src_loc)) - self.s.rtlil.connect(sigspec, self(value)) - return sigspec - - def on_Part(self, value): - lhs, rhs = value.value, value.offset - if value.stride != 1: - rhs *= value.stride - lhs_shape, rhs_shape, res_shape = lhs.shape(), rhs.shape(), value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - # Note: Verilog's x[o+:w] construct produces a $shiftx cell, not a $shift cell. - # However, Amaranth's semantics defines the out-of-range bits to be zero, so it is correct - # to use a $shift cell here instead, even though it produces less idiomatic Verilog. - self.s.rtlil.cell("$shift", ports={ - "\\A": self(lhs), - "\\B": self(rhs), - "\\Y": res, - }, params={ - "A_SIGNED": lhs_shape.signed, - "A_WIDTH": lhs_shape.width, - "B_SIGNED": rhs_shape.signed, - "B_WIDTH": rhs_shape.width, - "Y_WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - return res - - -class _LHSValueCompiler(_ValueCompiler): - def on_Const(self, value): - raise TypeError # :nocov: - - def on_AnyValue(self, value): - raise TypeError # :nocov: - - def on_Initial(self, value): - raise TypeError # :nocov: - - def on_Operator(self, value): - if value.operator in ("u", "s"): - # These operators are transparent on the LHS. - arg, = value.operands - return self(arg) - - raise TypeError # :nocov: - - def match_shape(self, value, new_shape): - value_shape = value.shape() - if new_shape.width == value_shape.width: - return self(value) - elif new_shape.width < value_shape.width: - return self(_ast.Slice(value, 0, new_shape.width)) - else: # new_shape.width > value_shape.width - dummy_bits = new_shape.width - value_shape.width - dummy_wire = self.s.rtlil.wire(dummy_bits) - return f"{{ {dummy_wire} {self(value)} }}" - - def on_Signal(self, value): - if value not in self.s.driven: - raise ValueError(f"No LHS wire for non-driven signal {value!r}") - wire_curr, wire_next = self.s.resolve(value) - return wire_next or wire_curr - - def _prepare_value_for_Slice(self, value): - assert isinstance(value, (_ast.Signal, _ast.Slice, _ast.Cat, _ast.Part)) - return self(value) - - def on_Part(self, value): - offset = self.s.expand(value.offset) - if isinstance(offset, _ast.Const): - start = offset.value * value.stride - stop = start + value.width - slice = self(_ast.Slice(value.value, start, min(len(value.value), stop))) - if len(value.value) >= stop: - return slice - else: - dummy_wire = self.s.rtlil.wire(stop - len(value.value)) - return f"{{ {dummy_wire} {slice} }}" - else: - # Only so many possible parts. The amount of branches is exponential; if value.offset - # is large (e.g. 32-bit wide), trying to naively legalize it is likely to exhaust - # system resources. - max_branches = len(value.value) // value.stride + 1 - raise _LegalizeValue(value.offset, - range(1 << len(value.offset))[:max_branches], - value.src_loc) - - -class _StatementCompiler(_xfrm.StatementVisitor): - def __init__(self, state, rhs_compiler, lhs_compiler): - self.state = state - self.rhs_compiler = rhs_compiler - self.lhs_compiler = lhs_compiler - - self._domain = None - self._case = None - self._test_cache = {} - self._has_rhs = False - self._wrap_assign = False - - @contextmanager - def case(self, switch, values, attrs={}, src=""): - try: - old_case = self._case - with switch.case(*values, attrs=attrs, src=src) as self._case: - yield - finally: - self._case = old_case - - def _check_rhs(self, value): - if self._has_rhs or next(iter(value._rhs_signals()), None) is not None: - self._has_rhs = True - - def on_Assign(self, stmt): - self._check_rhs(stmt.rhs) - - lhs_shape, rhs_shape = stmt.lhs.shape(), stmt.rhs.shape() - if lhs_shape.width == rhs_shape.width: - rhs_sigspec = self.rhs_compiler(stmt.rhs) - else: - # In RTLIL, LHS and RHS of assignment must have exactly same width. - rhs_sigspec = self.rhs_compiler.match_shape(stmt.rhs, lhs_shape) - if self._wrap_assign: - # In RTLIL, all assigns are logically sequenced before all switches, even if they are - # interleaved in the source. In Amaranth, the source ordering is used. To handle this - # mismatch, we wrap all assigns following a switch in a dummy switch. - with self._case.switch("{ }") as wrap_switch: - with wrap_switch.case() as wrap_case: - wrap_case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec) - else: - self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec) - - def on_Property(self, stmt): - self(stmt._check.eq(stmt.test)) - self(stmt._en.eq(1)) - - en_wire = self.rhs_compiler(stmt._en) - check_wire = self.rhs_compiler(stmt._check) - self.state.rtlil.cell("$" + stmt.kind.value, ports={ - "\\A": check_wire, - "\\EN": en_wire, - }, src=_src(stmt.src_loc), name=stmt.name) - - def on_Switch(self, stmt): - self._check_rhs(stmt.test) - - if not self.state.expansions: - # We repeatedly translate the same switches over and over (see the LHSGroupAnalyzer - # related code below), and translating the switch test only once helps readability. - if stmt not in self._test_cache: - self._test_cache[stmt] = self.rhs_compiler(stmt.test) - test_sigspec = self._test_cache[stmt] - else: - # However, if the switch test contains an illegal value, then it may not be cached - # (since the illegal value will be repeatedly replaced with different constants), so - # don't cache anything in that case. - test_sigspec = self.rhs_compiler(stmt.test) - - with self._case.switch(test_sigspec, src=_src(stmt.src_loc)) as switch: - for values, stmts in stmt.cases.items(): - case_attrs = {} - case_src = None - if values in stmt.case_src_locs: - case_src = _src(stmt.case_src_locs[values]) - if isinstance(stmt.test, _ast.Signal) and stmt.test.decoder: - decoded_values = [] - for value in values: - if "-" in value: - decoded_values.append("") - else: - decoded_values.append(stmt.test.decoder(int(value, 2))) - case_attrs["amaranth.decoding"] = "|".join(decoded_values) - with self.case(switch, values, attrs=case_attrs, src=case_src): - self._wrap_assign = False - self.on_statements(stmts) - self._wrap_assign = True - - def on_statement(self, stmt): - try: - super().on_statement(stmt) - except _LegalizeValue as legalize: - with self._case.switch(self.rhs_compiler(legalize.value), - src=_src(legalize.src_loc)) as switch: - shape = legalize.value.shape() - tests = ["{:0{}b}".format(v, shape.width) for v in legalize.branches] - if tests: - tests[-1] = "-" * shape.width - for branch, test in zip(legalize.branches, tests): - with self.case(switch, (test,)): - self._wrap_assign = False - branch_value = _ast.Const(branch, shape) - with self.state.expand_to(legalize.value, branch_value): - self.on_statement(stmt) - self._wrap_assign = True - - def on_statements(self, stmts): - for stmt in stmts: - self.on_statement(stmt) - - -def _convert_fragment(builder, fragment, name_map, hierarchy): - if isinstance(fragment, _ir.Instance): - port_map = OrderedDict() - for port_name, (value, dir) in fragment.named_ports.items(): - port_map[f"\\{port_name}"] = value - - params = OrderedDict(fragment.parameters) - - if fragment.type[0] == "$": - return fragment.type, port_map, params - else: - return f"\\{fragment.type}", port_map, params - - if isinstance(fragment, _mem.MemoryInstance): - init = "".join(format(_ast.Const(elem, _ast.unsigned(fragment._width)).value, f"0{fragment._width}b") for elem in reversed(fragment._init)) - init = _ast.Const(int(init or "0", 2), fragment._depth * fragment._width) - rd_clk = [] - rd_clk_enable = 0 - rd_clk_polarity = 0 - rd_transparency_mask = 0 - for index, port in enumerate(fragment._read_ports): - if port._domain != "comb": - cd = fragment.domains[port._domain] - rd_clk.append(cd.clk) - if cd.clk_edge == "pos": - rd_clk_polarity |= 1 << index - rd_clk_enable |= 1 << index - for write_index in port._transparency: - rd_transparency_mask |= 1 << (index * len(fragment._write_ports) + write_index) - else: - rd_clk.append(_ast.Const(0, 1)) - wr_clk = [] - wr_clk_enable = 0 - wr_clk_polarity = 0 - for index, port in enumerate(fragment._write_ports): - cd = fragment.domains[port._domain] - wr_clk.append(cd.clk) - wr_clk_enable |= 1 << index - if cd.clk_edge == "pos": - wr_clk_polarity |= 1 << index - params = { - "MEMID": builder._make_name(hierarchy[-1], local=False), - "SIZE": fragment._depth, - "OFFSET": 0, - "ABITS": _ast.Shape.cast(range(fragment._depth)).width, - "WIDTH": fragment._width, - "INIT": init, - "RD_PORTS": len(fragment._read_ports), - "RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment._read_ports))), - "RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment._read_ports))), - "RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment._read_ports) * len(fragment._write_ports))), - "RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment._read_ports) * len(fragment._write_ports))), - "RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._read_ports))), - "RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment._read_ports))), - "RD_ARST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), - "RD_SRST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), - "RD_INIT_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), - "WR_PORTS": len(fragment._write_ports), - "WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment._write_ports))), - "WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment._write_ports))), - "WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment._write_ports) * len(fragment._write_ports))), - "WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._write_ports))), - } - def make_en(port): - if len(port._data) == 0: - return _ast.Const(0, 0) - granularity = len(port._data) // len(port._en) - return _ast.Cat(en_bit.replicate(granularity) for en_bit in port._en) - port_map = { - "\\RD_CLK": _ast.Cat(rd_clk), - "\\RD_EN": _ast.Cat(port._en for port in fragment._read_ports), - "\\RD_ARST": _ast.Const(0, len(fragment._read_ports)), - "\\RD_SRST": _ast.Const(0, len(fragment._read_ports)), - "\\RD_ADDR": _ast.Cat(port._addr for port in fragment._read_ports), - "\\RD_DATA": _ast.Cat(port._data for port in fragment._read_ports), - "\\WR_CLK": _ast.Cat(wr_clk), - "\\WR_EN": _ast.Cat(make_en(port) for port in fragment._write_ports), - "\\WR_ADDR": _ast.Cat(port._addr for port in fragment._write_ports), - "\\WR_DATA": _ast.Cat(port._data for port in fragment._write_ports), - } - return "$mem_v2", port_map, params - - module_name = ".".join(name or "anonymous" for name in hierarchy) - module_attrs = OrderedDict() - if len(hierarchy) == 1: - module_attrs["top"] = 1 - - with builder.module(module_name, attrs=module_attrs) as module: - compiler_state = _ValueCompilerState(module) - rhs_compiler = _RHSValueCompiler(compiler_state) - lhs_compiler = _LHSValueCompiler(compiler_state) - stmt_compiler = _StatementCompiler(compiler_state, rhs_compiler, lhs_compiler) - - # Register all signals driven in the current fragment. This must be done first, as it - # affects further codegen; e.g. whether \sig$next signals will be generated and used. - for domain, statements in fragment.statements.items(): - for signal in statements._lhs_signals(): - compiler_state.add_driven(signal, sync=domain != "comb") - - # Transform all signals used as ports in the current fragment eagerly and outside of - # any hierarchy, to make sure they get sensible (non-prefixed) names. - for signal in fragment.ports: - compiler_state.add_port(signal, fragment.ports[signal]) - compiler_state.resolve_curr(signal) - - # Transform all clocks clocks and resets eagerly and outside of any hierarchy, to make - # sure they get sensible (non-prefixed) names. This does not affect semantics. - for domain, _ in fragment.iter_sync(): - cd = fragment.domains[domain] - compiler_state.resolve_curr(cd.clk) - if cd.rst is not None: - compiler_state.resolve_curr(cd.rst) - - # Transform all subfragments to their respective cells. Transforming signals connected - # to their ports into wires eagerly makes sure they get sensible (prefixed with submodule - # name) names. - for subfragment, sub_name in fragment.subfragments: - if not (subfragment.ports or subfragment.statements or subfragment.subfragments or - isinstance(subfragment, (_ir.Instance, _mem.MemoryInstance))): - # If the fragment is completely empty, skip translating it, otherwise synthesis - # tools (including Yosys and Vivado) will treat it as a black box when it is - # loaded after conversion to Verilog. - continue - - if sub_name is None: - sub_name = module.anonymous() - - sub_type, sub_port_map, sub_params = \ - _convert_fragment(builder, subfragment, name_map, - hierarchy=hierarchy + (sub_name,)) - - sub_ports = OrderedDict() - for port, value in sub_port_map.items(): - if not isinstance(subfragment, (_ir.Instance, _mem.MemoryInstance)): - for signal in value._rhs_signals(): - compiler_state.resolve_curr(signal, prefix=sub_name) - if len(value) > 0 or sub_type == "$mem_v2": - sub_ports[port] = rhs_compiler(value) - - if isinstance(subfragment, _ir.Instance): - src = _src(subfragment.src_loc) - elif isinstance(subfragment, _mem.MemoryInstance): - src = _src(subfragment._src_loc) - else: - src = "" - - module.cell(sub_type, name=sub_name, ports=sub_ports, params=sub_params, - attrs=subfragment.attrs, src=src) - - # If we emit all of our combinatorial logic into a single RTLIL process, Verilog - # simulators will break horribly, because Yosys write_verilog transforms RTLIL processes - # into always @* blocks with blocking assignment, and that does not create delta cycles. +class MemoryInfo: + def __init__(self, memid): + self.memid = memid + self.num_write_ports = 0 + self.write_port_ids = {} + + +class ModuleEmitter: + def __init__(self, builder, netlist, module, name_map, empty_checker): + self.builder = builder + self.netlist = netlist + self.module = module + self.name_map = name_map + self.empty_checker = empty_checker + + # Internal state of the emitter. This conceptually consists of three parts: + # (1) memory information; + # (2) name and attribute preferences for wires corresponding to signals; + # (3) mapping of Amaranth netlist entities to RTLIL netlist entities. + # Value names are preferences: they are candidate names for values that may or may not get + # used for cell outputs. Attributes are mandatory: they are always emitted, but can be + # squashed if several signals end up aliasing the same driven wire. + self.memories = {} # cell idx -> MemoryInfo + self.value_names = {} # value -> signal or port name + self.value_attrs = {} # value -> dict + self.sigport_wires = {} # signal or port name -> (wire, value) + self.driven_sigports = set() # set of signal or port name + self.nets = {} # net -> (wire name, bit idx) + self.cell_wires = {} # cell idx -> wire name + self.instance_wires = {} # (cell idx, output name) -> wire name + + def emit(self): + self.collect_memory_info() + self.assign_value_names() + self.collect_init_attrs() + self.emit_signal_wires() + self.emit_port_wires() + self.emit_cell_wires() + self.emit_submodule_wires() + self.emit_connects() + self.emit_submodules() + self.emit_cells() + + def collect_memory_info(self): + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.Memory): + self.memories[cell_idx] = MemoryInfo( + self.builder.memory(cell.width, cell.depth, name=cell.name, + attrs=cell.attributes, src=_src(cell.src_loc))) + + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.SyncWritePort): + memory_info = self.memories[cell.memory] + memory_info.write_port_ids[cell_idx] = memory_info.num_write_ports + memory_info.num_write_ports += 1 + + def assign_value_names(self): + for signal, name in self.module.signal_names.items(): + value = self.netlist.signals[signal] + if value not in self.value_names: + self.value_names[value] = name + + def collect_init_attrs(self): + # Flip-flops are special in Yosys; the initial value is stored not as a cell parameter but + # as an attribute of a wire connected to the output of the flip-flop. The claimed benefit + # of this arrangement is that fine cells, which cannot have parameters (so that certain + # backends, like BLIF, which cannot represent parameters--or attributes--can be used to + # emit these cells), then do not need to have 3x more variants (one for initialized to 0, + # one for 1, one for X). # - # Therefore, we translate the fragment as many times as there are independent groups - # of signals (a group is a transitive closure of signals that appear together on LHS), - # splitting them into many RTLIL (and thus Verilog) processes. - for domain, statements in fragment.statements.items(): - lhs_grouper = _xfrm.LHSGroupAnalyzer() - lhs_grouper.on_statements(statements) + # At the time of writing, 2024-02-11, Yosys has 125 (one hundred twenty five) fine FF cells, + # which are generated by a Python script because they have gotten completely out of hand + # long ago and no one could keep track of them manually. This list features such beauties + # as $_DFFSRE_PPPN_ and its other 7 cousins. + # + # These are real cells, used by real Yosys developers! Look at what they have done for us, + # with all the subtly unsynthesizable Verilog we sent them and all of the incompatibilities + # with vendor toolchains we reported! + # + # Nothing is fine about these cells. The decision to have `init` as a wire attribute is + # quite possibly the single worst design decision in Yosys, and not having to dealing with + # that bullshit again is enough of a reason to implement an FPGA toolchain from scratch. + # + # Just have 375 fine cells, bro. Trust me bro. You will certainly not regret having 375 + # fine cells in your toolchain. Or at least you will be able to process netlists without + # having to special-case this one godforsaken attribute every time you look at a wire. + # + # -- @whitequark + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.FlipFlop): + width = len(cell.data) + attrs = {"init": _ast.Const(cell.init, width), **cell.attributes} + value = _nir.Value(_nir.Net.from_cell(cell_idx, bit) for bit in range(width)) + self.value_attrs[value] = attrs - for group, group_signals in lhs_grouper.groups().items(): - lhs_group_filter = _xfrm.LHSGroupFilter(group_signals) - group_stmts = lhs_group_filter(statements) + def emit_signal_wires(self): + for signal, name in self.module.signal_names.items(): + value = self.netlist.signals[signal] - with module.process(name=f"$group_{group}") as process: - with process.case() as case: - # For every signal in comb domain, assign \sig$next to the reset value. - # For every signal in sync domains, assign \sig$next to the current - # value (\sig). - for signal in group_signals: - if domain == "comb": - prev_value = _ast.Const(signal.reset, signal.width) + # One of: (1) empty and created here, (2) `init` filled in by `collect_init_attrs`, + # (3) populated by some other signal aliasing the same nets. In the last case, we will + # glue attributes for these signals together, but synthesizers (including Yosys, when + # the design is flattened) will do that anyway, so it doesn't matter. + attrs = self.value_attrs.setdefault(value, {}) + attrs.update(signal.attrs) + + for repr in signal._value_repr: + if repr.path == () and isinstance(repr.format, _repr.FormatEnum): + enum = repr.format.enum + attrs["enum_base_type"] = enum.__name__ + for enum_value in enum: + attrs["enum_value_{:0{}b}".format(enum_value.value, signal.width)] = enum_value.name + + if name in self.module.ports: + port_value, _flow = self.module.ports[name] + assert value == port_value + self.name_map[signal] = (*self.module.name, f"\\{name}") + else: + wire = self.builder.wire(width=signal.width, name=name, attrs=attrs, + src=_src(signal.src_loc)) + self.sigport_wires[name] = (wire, value) + self.name_map[signal] = (*self.module.name, wire) + + def emit_port_wires(self): + for port_id, (name, (value, flow)) in enumerate(self.module.ports.items()): + wire = self.builder.wire(width=len(value), port_id=port_id, port_kind=flow.value, + name=name, attrs=self.value_attrs.get(value, {})) + self.sigport_wires[name] = (wire, value) + if flow == _nir.ModuleNetFlow.OUTPUT: + continue + # If we just emitted an input or inout port, it is driving the value. + self.driven_sigports.add(name) + for bit, net in enumerate(value): + self.nets[net] = (wire, bit) + + def emit_driven_wire(self, value): + # Emits a wire for a value, in preparation for driving it. + if value in self.value_names: + # If there is a signal or port matching this value, reuse its wire as the canonical + # wire of the nets involved. + name = self.value_names[value] + wire, named_value = self.sigport_wires[name] + assert value == named_value, \ + f"Inconsistent values {value!r}, {named_value!r} for wire {name!r}" + self.driven_sigports.add(name) + else: + # Otherwise, make an anonymous wire. + wire = self.builder.wire(len(value), attrs=self.value_attrs.get(value, {})) + for bit, net in enumerate(value): + self.nets[net] = (wire, bit) + return wire + + def emit_cell_wires(self): + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.Top): + continue + elif isinstance(cell, _nir.Instance): + for name, (start, width) in cell.ports_o.items(): + nets = [_nir.Net.from_cell(cell_idx, start + bit) for bit in range(width)] + wire = self.emit_driven_wire(_nir.Value(nets)) + self.instance_wires[cell_idx, name] = wire + continue # Instances use one wire per output, not per cell. + elif isinstance(cell, (_nir.PriorityMatch, _nir.Matches)): + continue # Inlined into assignment lists. + elif isinstance(cell, (_nir.SyncProperty, _nir.AsyncProperty, _nir.Memory, + _nir.SyncWritePort)): + continue # No outputs. + elif isinstance(cell, _nir.AssignmentList): + width = len(cell.default) + elif isinstance(cell, (_nir.Operator, _nir.Part, _nir.ArrayMux, _nir.AnyValue, + _nir.SyncReadPort, _nir.AsyncReadPort)): + width = cell.width + elif isinstance(cell, _nir.FlipFlop): + width = len(cell.data) + elif isinstance(cell, _nir.Initial): + width = 1 + elif isinstance(cell, _nir.IOBuffer): + width = len(cell.pad) + else: + assert False # :nocov: + # Single output cell connected to a wire. + nets = [_nir.Net.from_cell(cell_idx, bit) for bit in range(width)] + wire = self.emit_driven_wire(_nir.Value(nets)) + self.cell_wires[cell_idx] = wire + + def emit_submodule_wires(self): + for submodule_idx in self.module.submodules: + submodule = self.netlist.modules[submodule_idx] + for _name, (value, flow) in submodule.ports.items(): + if flow == _nir.ModuleNetFlow.OUTPUT: + self.emit_driven_wire(value) + + def sigspec(self, *parts: '_nir.Net | Iterable[_nir.Net]'): + value = _nir.Value() + for part in parts: + value += _nir.Value(part) + + chunks = [] + begin_pos = 0 + while begin_pos < len(value): + end_pos = begin_pos + if value[begin_pos].is_const: + while end_pos < len(value) and value[end_pos].is_const: + end_pos += 1 + width = end_pos - begin_pos + bits = "".join(str(net.const) for net in value[begin_pos:end_pos]) + chunks.append(f"{width}'{bits[::-1]}") + else: + wire, start_bit = self.nets[value[begin_pos]] + bit = start_bit + while (end_pos < len(value) and + not value[end_pos].is_const and + self.nets[value[end_pos]] == (wire, bit)): + end_pos += 1 + bit += 1 + width = end_pos - begin_pos + if width == 1: + chunks.append(f"{wire} [{start_bit}]") + else: + chunks.append(f"{wire} [{start_bit + width - 1}:{start_bit}]") + begin_pos = end_pos + + if len(chunks) == 1: + return chunks[0] + return "{ " + " ".join(reversed(chunks)) + " }" + + def emit_connects(self): + for name, (wire, value) in self.sigport_wires.items(): + if name not in self.driven_sigports: + self.builder.connect(wire, self.sigspec(value)) + + def emit_submodules(self): + for submodule_idx in self.module.submodules: + submodule = self.netlist.modules[submodule_idx] + if not self.empty_checker.is_empty(submodule_idx): + dotted_name = ".".join(submodule.name) + self.builder.cell(f"\\{dotted_name}", submodule.name[-1], ports={ + name: self.sigspec(value) + for name, (value, _flow) in submodule.ports.items() + }) + + def emit_assignment_list(self, cell_idx, cell): + def emit_assignments(case, cond): + # Emits assignments from the assignment list into the given case. + # ``cond`` is the net which is the condition for ``case`` being active. + # Returns once it hits an assignment whose condition is not nested within ``cond``, + # letting parent invocation take care of the remaining assignments. + nonlocal pos + + emitted_switch = False + while pos < len(cell.assignments): + assign = cell.assignments[pos] + if assign.cond == cond and not emitted_switch: + # Not nested, and we didn't emit a switch yet, so emit the assignment. + case.assign(self.sigspec(lhs[assign.start:assign.start + len(assign.value)]), + self.sigspec(assign.value)) + pos += 1 + elif assign.cond == cond: + # Not nested, but we emitted a subswitch. Wrap the assignments in a dummy + # switch. This is necessary because Yosys executes all assignments before all + # subswitches (but allows you to mix asssignments and switches in RTLIL, for + # maximum confusion). + with case.switch("{ }") as switch: + with switch.case("") as subcase: + while pos < len(cell.assignments): + assign = cell.assignments[pos] + if assign.cond == cond: + subcase.assign(self.sigspec(lhs[assign.start:assign.start + + len(assign.value)]), + self.sigspec(assign.value)) + pos += 1 + else: + break + else: + # Condition doesn't match this case's condition — either we encountered + # a nested condition, or we should break out. Try to find out exactly + # how we are nested. + search_cond = assign.cond + while True: + if search_cond == cond: + # We have found the PriorityMatch cell that we should enter. + break + if search_cond == _nir.Net.from_const(1): + # If this isn't nested condition, go back to parent invocation. + return + # Grab the PriorityMatch cell that is on the next level of nesting. + priority_cell_idx = search_cond.cell + priority_cell = self.netlist.cells[priority_cell_idx] + assert isinstance(priority_cell, _nir.PriorityMatch) + search_cond = priority_cell.en + # We assume that: + # 1. PriorityMatch inputs can only be Match cell outputs, or constant 1. + # 2. All Match cells driving a given PriorityMatch cell test the same value. + # Grab the tested value from a random Match cell. + test = _nir.Value() + for net in priority_cell.inputs: + if net != _nir.Net.from_const(1): + matches_cell = self.netlist.cells[net.cell] + assert isinstance(matches_cell, _nir.Matches) + test = matches_cell.value + break + # Now emit cases for all PriorityMatch inputs, in sequence. Consume as many + # assignments as possible along the way. + with case.switch(self.sigspec(test)) as switch: + for bit, net in enumerate(priority_cell.inputs): + subcond = _nir.Net.from_cell(priority_cell_idx, bit) + if net == _nir.Net.from_const(1): + patterns = () else: - prev_value = signal - case.assign(lhs_compiler(signal), rhs_compiler(prev_value)) + # Validate the above assumptions. + matches_cell = self.netlist.cells[net.cell] + assert isinstance(matches_cell, _nir.Matches) + assert test == matches_cell.value + patterns = matches_cell.patterns + with switch.case(*patterns) as subcase: + emit_assignments(subcase, subcond) + emitted_switch = True - # Convert statements into decision trees. - stmt_compiler._domain = domain - stmt_compiler._case = case - stmt_compiler._has_rhs = False - stmt_compiler._wrap_assign = False - stmt_compiler(group_stmts) + lhs = _nir.Value(_nir.Net.from_cell(cell_idx, bit) for bit in range(len(cell.default))) + with self.builder.process(src=_src(cell.src_loc)) as proc: + with proc.case() as root_case: + root_case.assign(self.sigspec(lhs), self.sigspec(cell.default)) - # For every driven signal in the sync domain, create a flop of appropriate type. Which type - # is appropriate depends on the domain: for domains with sync reset, it is a $dff, for - # domains with async reset it is an $adff. The latter is directly provided with the reset - # value as a parameter to the cell, which is directly assigned during reset. - for domain, signal in fragment.iter_sync(): - cd = fragment.domains[domain] + pos = 0 # nonlocally used in `emit_assignments` + emit_assignments(root_case, _nir.Net.from_const(1)) + assert pos == len(cell.assignments) - wire_clk = compiler_state.resolve_curr(cd.clk) - wire_rst = compiler_state.resolve_curr(cd.rst) if cd.rst is not None else None - wire_curr, wire_next = compiler_state.resolve(signal) - - if not cd.async_reset: - # For sync reset flops, the reset value comes from logic inserted by - # `hdl.xfrm.DomainLowerer`. - module.cell("$dff", ports={ - "\\CLK": wire_clk, - "\\D": wire_next, - "\\Q": wire_curr + def emit_operator(self, cell_idx, cell): + UNARY_OPERATORS = { + "-": "$neg", + "~": "$not", + "b": "$reduce_bool", + "r|": "$reduce_or", + "r&": "$reduce_and", + "r^": "$reduce_xor", + } + BINARY_OPERATORS = { + # A_SIGNED, B_SIGNED + "+": ("$add", False, False), + "-": ("$sub", False, False), + "*": ("$mul", False, False), + "u//": ("$divfloor", False, False), + "s//": ("$divfloor", True, True), + "u%": ("$modfloor", False, False), + "s%": ("$modfloor", True, True), + "<<": ("$shl", False, False), + "u>>": ("$shr", False, False), + "s>>": ("$sshr", True, False), + "&": ("$and", False, False), + "|": ("$or", False, False), + "^": ("$xor", False, False), + "==": ("$eq", False, False), + "!=": ("$ne", False, False), + "u<": ("$lt", False, False), + "u>": ("$gt", False, False), + "u<=": ("$le", False, False), + "u>=": ("$ge", False, False), + "s<": ("$lt", True, True), + "s>": ("$gt", True, True), + "s<=": ("$le", True, True), + "s>=": ("$ge", True, True), + } + if len(cell.inputs) == 1: + cell_type = UNARY_OPERATORS[cell.operator] + operand, = cell.inputs + self.builder.cell(cell_type, ports={ + "A": self.sigspec(operand), + "Y": self.cell_wires[cell_idx] + }, params={ + "A_SIGNED": False, + "A_WIDTH": len(operand), + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) + elif len(cell.inputs) == 2: + cell_type, a_signed, b_signed = BINARY_OPERATORS[cell.operator] + operand_a, operand_b = cell.inputs + if cell.operator in ("u//", "s//", "u%", "s%"): + result = self.builder.wire(cell.width) + self.builder.cell(cell_type, ports={ + "A": self.sigspec(operand_a), + "B": self.sigspec(operand_b), + "Y": result, }, params={ - "CLK_POLARITY": int(cd.clk_edge == "pos"), - "WIDTH": signal.width - }) + "A_SIGNED": a_signed, + "B_SIGNED": b_signed, + "A_WIDTH": len(operand_a), + "B_WIDTH": len(operand_b), + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) + nonzero = self.builder.wire(1) + self.builder.cell("$reduce_bool", ports={ + "A": self.sigspec(operand_b), + "Y": nonzero, + }, params={ + "A_SIGNED": False, + "A_WIDTH": len(operand_b), + "Y_WIDTH": 1, + }, src=_src(cell.src_loc)) + self.builder.cell("$mux", ports={ + "S": nonzero, + "A": self.sigspec(_nir.Value.zeros(cell.width)), + "B": result, + "Y": self.cell_wires[cell_idx] + }, params={ + "WIDTH": cell.width, + }, src=_src(cell.src_loc)) else: - # For async reset flops, the reset value is provided directly to the cell. - module.cell("$adff", ports={ - "\\ARST": wire_rst, - "\\CLK": wire_clk, - "\\D": wire_next, - "\\Q": wire_curr + self.builder.cell(cell_type, ports={ + "A": self.sigspec(operand_a), + "B": self.sigspec(operand_b), + "Y": self.cell_wires[cell_idx], }, params={ - "ARST_POLARITY": _ast.Const(1), - "ARST_VALUE": _ast.Const(signal.reset, signal.width), - "CLK_POLARITY": int(cd.clk_edge == "pos"), - "WIDTH": signal.width - }) + "A_SIGNED": a_signed, + "B_SIGNED": b_signed, + "A_WIDTH": len(operand_a), + "B_WIDTH": len(operand_b), + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) + else: + assert cell.operator == "m" + condition, if_true, if_false = cell.inputs + self.builder.cell("$mux", ports={ + "S": self.sigspec(condition), + "A": self.sigspec(if_false), + "B": self.sigspec(if_true), + "Y": self.cell_wires[cell_idx] + }, params={ + "WIDTH": cell.width, + }, src=_src(cell.src_loc)) - # Any signals that are used but neither driven nor connected to an input port always - # assume their reset values. We need to assign the reset value explicitly, since only - # driven sync signals are handled by the logic above. - # - # Because this assignment is done at a late stage, a single Signal object can get assigned - # many times, once in each module it is used. This is a deliberate decision; the possible - # alternatives are to add ports for undriven signals (which requires choosing one module - # to drive it to reset value arbitrarily) or to replace them with their reset value (which - # removes valuable source location information). - driven = _ast.SignalSet() - for domain, statements in fragment.statements.items(): - driven.update(statements._lhs_signals()) - driven.update(fragment.iter_ports(dir="i")) - driven.update(fragment.iter_ports(dir="io")) - for subfragment, sub_name in fragment.subfragments: - driven.update(subfragment.iter_ports(dir="o")) - driven.update(subfragment.iter_ports(dir="io")) + def emit_part(self, cell_idx, cell): + if cell.stride == 1: + offset = self.sigspec(cell.offset) + offset_width = len(cell.offset) + else: + stride = _ast.Const(cell.stride) + offset_width = len(cell.offset) + stride.width + offset = self.builder.wire(offset_width) + self.builder.cell("$mul", ports={ + "A": self.sigspec(cell.offset), + "B": _const(stride), + "Y": offset, + }, params={ + "A_SIGNED": False, + "B_SIGNED": False, + "A_WIDTH": len(cell.offset), + "B_WIDTH": stride.width, + "Y_WIDTH": offset_width, + }, src=_src(cell.src_loc)) + self.builder.cell("$shift", ports={ + "A": self.sigspec(cell.value), + "B": offset, + "Y": self.cell_wires[cell_idx], + }, params={ + "A_SIGNED": cell.value_signed, + "B_SIGNED": False, + "A_WIDTH": len(cell.value), + "B_WIDTH": offset_width, + "Y_WIDTH": cell.width, + }, src=_src(cell.src_loc)) - for wire in compiler_state.wires: - if wire in driven: - continue - wire_curr, _ = compiler_state.wires[wire] - module.connect(wire_curr, rhs_compiler(_ast.Const(wire.reset, wire.width))) + def emit_array_mux(self, cell_idx, cell): + wire = self.cell_wires[cell_idx] + with self.builder.process(src=_src(cell.src_loc)) as proc: + with proc.case() as root_case: + with root_case.switch(self.sigspec(cell.index)) as switch: + for index, elem in enumerate(cell.elems): + if len(cell.index) > 0: + pattern = "{:0{}b}".format(index, len(cell.index)) + else: + pattern = "" + with switch.case(pattern) as case: + case.assign(wire, self.sigspec(elem)) + with switch.case() as case: + case.assign(wire, self.sigspec(cell.elems[0])) - # Collect the names we've given to our ports in RTLIL, and correlate these with the signals - # represented by these ports. If we are a submodule, this will be necessary to create a cell - # for us in the parent module. - port_map = OrderedDict() - for signal in fragment.ports: - port_map[compiler_state.resolve_curr(signal)] = signal + def emit_flip_flop(self, cell_idx, cell): + ports = { + "D": self.sigspec(cell.data), + "CLK": self.sigspec(cell.clk), + "Q": self.cell_wires[cell_idx] + } + params = { + "WIDTH": len(cell.data), + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge] + } + if cell.arst == _nir.Net.from_const(0): + cell_type = "$dff" + else: + cell_type = "$adff" + ports["ARST"] = self.sigspec(cell.arst) + params["ARST_POLARITY"] = True + params["ARST_VALUE"] = _ast.Const(cell.init, len(cell.data)) + self.builder.cell(cell_type, ports=ports, params=params, src=_src(cell.src_loc)) - # Finally, collect the names we've given to each wire in RTLIL, and provide these to - # the caller, to allow manipulating them in the toolchain. - for signal in compiler_state.wires: - wire_name = compiler_state.resolve_curr(signal) - if wire_name.startswith("\\"): - wire_name = wire_name[1:] - name_map[signal] = hierarchy + (wire_name,) + def emit_io_buffer(self, cell_idx, cell): + self.builder.cell("$tribuf", ports={ + "Y": self.sigspec(cell.pad), + "A": self.sigspec(cell.o), + "EN": self.sigspec(cell.oe), + }, params={ + "WIDTH": len(cell.pad), + }, src=_src(cell.src_loc)) + self.builder.connect(self.cell_wires[cell_idx], self.sigspec(cell.pad)) - return module.name, port_map, {} + def emit_memory(self, cell_idx, cell): + memory_info = self.memories[cell_idx] + self.builder.cell("$meminit_v2", ports={ + "ADDR": self.sigspec(), + "DATA": self.sigspec( + _nir.Net.from_const((row >> bit) & 1) + for row in cell.init + for bit in range(cell.width) + ), + "EN": self.sigspec(_nir.Value.ones(cell.width)), + }, params={ + "MEMID": memory_info.memid, + "ABITS": 0, + "WIDTH": cell.width, + "WORDS": cell.depth, + "PRIORITY": 0, + }, src=_src(cell.src_loc)) + + def emit_write_port(self, cell_idx, cell): + memory_info = self.memories[cell.memory] + ports = { + "ADDR": self.sigspec(cell.addr), + "DATA": self.sigspec(cell.data), + "EN": self.sigspec(cell.en), + "CLK": self.sigspec(cell.clk), + } + params = { + "MEMID": memory_info.memid, + "ABITS": len(cell.addr), + "WIDTH": len(cell.data), + "CLK_ENABLE": True, + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge], + "PORTID": memory_info.write_port_ids[cell_idx], + "PRIORITY_MASK": 0, + } + self.builder.cell(f"$memwr_v2", ports=ports, params=params, src=_src(cell.src_loc)) + + def emit_read_port(self, cell_idx, cell): + memory_info = self.memories[cell.memory] + ports = { + "ADDR": self.sigspec(cell.addr), + "DATA": self.cell_wires[cell_idx], + "ARST": self.sigspec(_nir.Net.from_const(0)), + "SRST": self.sigspec(_nir.Net.from_const(0)), + } + if isinstance(cell, _nir.AsyncReadPort): + transparency_mask = 0 + if isinstance(cell, _nir.SyncReadPort): + transparency_mask = sum( + 1 << memory_info.write_port_ids[write_port_cell_index] + for write_port_cell_index in cell.transparent_for + ) + params = { + "MEMID": memory_info.memid, + "ABITS": len(cell.addr), + "WIDTH": cell.width, + "TRANSPARENCY_MASK": _ast.Const(transparency_mask, memory_info.num_write_ports), + "COLLISION_X_MASK": _ast.Const(0, memory_info.num_write_ports), + "ARST_VALUE": _ast.Const(0, cell.width), + "SRST_VALUE": _ast.Const(0, cell.width), + "INIT_VALUE": _ast.Const(0, cell.width), + "CE_OVER_SRST": False, + } + if isinstance(cell, _nir.AsyncReadPort): + ports.update({ + "EN": self.sigspec(_nir.Net.from_const(1)), + "CLK": self.sigspec(_nir.Net.from_const(0)), + }) + params.update({ + "CLK_ENABLE": False, + "CLK_POLARITY": True, + }) + if isinstance(cell, _nir.SyncReadPort): + ports.update({ + "EN": self.sigspec(cell.en), + "CLK": self.sigspec(cell.clk), + }) + params.update({ + "CLK_ENABLE": True, + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge], + }) + self.builder.cell(f"$memrd_v2", ports=ports, params=params, src=_src(cell.src_loc)) + + def emit_property(self, cell_idx, cell): + if isinstance(cell, _nir.AsyncProperty): + ports = { + "A": self.sigspec(cell.test), + "EN": self.sigspec(cell.en), + } + if isinstance(cell, _nir.SyncProperty): + test = self.builder.wire(1, attrs={"init": _ast.Const(0, 1)}) + en = self.builder.wire(1, attrs={"init": _ast.Const(0, 1)}) + for (d, q) in [ + (cell.test, test), + (cell.en, en), + ]: + ports = { + "D": self.sigspec(d), + "Q": q, + "CLK": self.sigspec(cell.clk), + } + params = { + "WIDTH": 1, + "CLK_POLARITY": { + "pos": True, + "neg": False, + }[cell.clk_edge], + } + self.builder.cell(f"$dff", ports=ports, params=params, src=_src(cell.src_loc)) + ports = { + "A": test, + "EN": en, + } + self.builder.cell(f"${cell.kind}", name=cell.name, ports=ports, src=_src(cell.src_loc)) + + def emit_any_value(self, cell_idx, cell): + self.builder.cell(f"${cell.kind}", ports={ + "Y": self.cell_wires[cell_idx], + }, params={ + "WIDTH": cell.width, + }, src=_src(cell.src_loc)) + + def emit_initial(self, cell_idx, cell): + self.builder.cell("$initstate", ports={ + "Y": self.cell_wires[cell_idx], + }, src=_src(cell.src_loc)) + + def emit_instance(self, cell_idx, cell): + ports = {} + for name, nets in cell.ports_i.items(): + ports[name] = self.sigspec(nets) + for name in cell.ports_o: + ports[name] = self.instance_wires[cell_idx, name] + for name, nets in cell.ports_io.items(): + ports[name] = self.sigspec(nets) + if cell.type.startswith("$"): + type = cell.type + else: + type = "\\" + cell.type + self.builder.cell(type, cell.name, ports=ports, params=cell.parameters, + attrs=cell.attributes, src=_src(cell.src_loc)) + + def emit_cells(self): + for cell_idx in self.module.cells: + cell = self.netlist.cells[cell_idx] + if isinstance(cell, _nir.Top): + pass + elif isinstance(cell, _nir.Matches): + pass # Matches is only referenced from PriorityMatch cells and inlined there + elif isinstance(cell, _nir.PriorityMatch): + pass # PriorityMatch is only referenced from AssignmentList cells and inlined there + elif isinstance(cell, _nir.AssignmentList): + self.emit_assignment_list(cell_idx, cell) + elif isinstance(cell, _nir.Operator): + self.emit_operator(cell_idx, cell) + elif isinstance(cell, _nir.Part): + self.emit_part(cell_idx, cell) + elif isinstance(cell, _nir.ArrayMux): + self.emit_array_mux(cell_idx, cell) + elif isinstance(cell, _nir.FlipFlop): + self.emit_flip_flop(cell_idx, cell) + elif isinstance(cell, _nir.IOBuffer): + self.emit_io_buffer(cell_idx, cell) + elif isinstance(cell, _nir.Memory): + self.emit_memory(cell_idx, cell) + elif isinstance(cell, _nir.SyncWritePort): + self.emit_write_port(cell_idx, cell) + elif isinstance(cell, (_nir.AsyncReadPort, _nir.SyncReadPort)): + self.emit_read_port(cell_idx, cell) + elif isinstance(cell, (_nir.AsyncProperty, _nir.SyncProperty)): + self.emit_property(cell_idx, cell) + elif isinstance(cell, _nir.AnyValue): + self.emit_any_value(cell_idx, cell) + elif isinstance(cell, _nir.Initial): + self.emit_initial(cell_idx, cell) + elif isinstance(cell, _nir.Instance): + self.emit_instance(cell_idx, cell) + else: + assert False # :nocov: + + +# Empty modules are interpreted by some toolchains (Yosys, Xilinx, ...) as black boxes, and +# must not be emitted. +class EmptyModuleChecker: + def __init__(self, netlist): + self.netlist = netlist + self.empty = set() + self.check(0) + + def check(self, module_idx): + is_empty = not self.netlist.modules[module_idx].cells + for submodule in self.netlist.modules[module_idx].submodules: + is_empty &= self.check(submodule) + if is_empty: + self.empty.add(module_idx) + return is_empty + + def is_empty(self, module_idx): + return module_idx in self.empty def convert_fragment(fragment, name="top", *, emit_src=True): assert isinstance(fragment, _ir.Fragment) - builder = _Builder(emit_src=emit_src) name_map = _ast.SignalDict() - _convert_fragment(builder, fragment, name_map, hierarchy=(name,)) + netlist = _ir.build_netlist(fragment, name=name) + empty_checker = EmptyModuleChecker(netlist) + builder = _Builder(emit_src=emit_src) + for module_idx, module in enumerate(netlist.modules): + if empty_checker.is_empty(module_idx): + continue + attrs = {"top": 1} if module_idx == 0 else {} + with builder.module(".".join(module.name), attrs=attrs) as module_builder: + ModuleEmitter(module_builder, netlist, module, name_map, + empty_checker=empty_checker).emit() return str(builder), name_map @@ -1047,7 +1002,7 @@ def convert(elaboratable, name="top", platform=None, *, ports=None, emit_src=Tru hasattr(elaboratable, "signature") and isinstance(elaboratable.signature, wiring.Signature)): ports = [] - for path, member, value in elaboratable.signature.flatten(elaboratable): + for _path, _member, value in elaboratable.signature.flatten(elaboratable): if isinstance(value, _ast.ValueCastable): value = value.as_value() if isinstance(value, _ast.Value): @@ -1055,5 +1010,5 @@ def convert(elaboratable, name="top", platform=None, *, ports=None, emit_src=Tru elif ports is None: raise TypeError("The `convert()` function requires a `ports=` argument") fragment = _ir.Fragment.get(elaboratable, platform).prepare(ports=ports, **kwargs) - il_text, name_map = convert_fragment(fragment, name, emit_src=emit_src) + il_text, _name_map = convert_fragment(fragment, name, emit_src=emit_src) return il_text diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 25c557b..7abe4e3 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -1770,25 +1770,17 @@ class Property(Statement, MustUse): Assume = "assume" Cover = "cover" - def __init__(self, kind, test, *, _check=None, _en=None, name=None, src_loc_at=0): + def __init__(self, kind, test, *, name=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) self.kind = self.Kind(kind) self.test = Value.cast(test) - self._check = _check - self._en = _en self.name = name if not isinstance(self.name, str) and self.name is not None: raise TypeError("Property name must be a string or None, not {!r}" .format(self.name)) - if self._check is None: - self._check = Signal(reset_less=True, name=f"${self.kind.value}$check") - self._check.src_loc = self.src_loc - if _en is None: - self._en = Signal(reset_less=True, name=f"${self.kind.value}$en") - self._en.src_loc = self.src_loc def _lhs_signals(self): - return SignalSet((self._en, self._check)) + return set() def _rhs_signals(self): return self.test._rhs_signals() diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 8fdc919..5138de0 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -1,20 +1,17 @@ -from abc import ABCMeta +from typing import Tuple from collections import defaultdict, OrderedDict from functools import reduce import warnings -from .. import tracer -from .._utils import * -from .._unused import * -from ._ast import * -from ._ast import _StatementList -from ._cd import * +from .._utils import flatten, memoize +from .. import tracer, _unused +from . import _ast, _cd, _ir, _nir __all__ = ["UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment", "Instance"] -class UnusedElaboratable(UnusedMustUse): +class UnusedElaboratable(_unused.UnusedMustUse): # The warning is initially silenced. If everything that has been constructed remains unused, # it means the application likely crashed (with an exception, or in another way that does not # call `sys.excepthook`), and it's not necessary to show any warnings. @@ -22,7 +19,7 @@ class UnusedElaboratable(UnusedMustUse): _MustUse__silence = True -class Elaboratable(MustUse): +class Elaboratable(_unused.MustUse): _MustUse__warning = UnusedElaboratable @@ -64,7 +61,7 @@ class Fragment: obj = new_obj def __init__(self): - self.ports = SignalDict() + self.ports = _ast.SignalDict() self.drivers = OrderedDict() self.statements = {} self.domains = OrderedDict() @@ -89,7 +86,7 @@ class Fragment: def add_driver(self, signal, domain="comb"): assert isinstance(domain, str) if domain not in self.drivers: - self.drivers[domain] = SignalSet() + self.drivers[domain] = _ast.SignalSet() self.drivers[domain].add(signal) def iter_drivers(self): @@ -109,7 +106,7 @@ class Fragment: yield domain, signal def iter_signals(self): - signals = SignalSet() + signals = _ast.SignalSet() signals |= self.ports.keys() for domain, domain_signals in self.drivers.items(): if domain != "comb": @@ -122,7 +119,7 @@ class Fragment: def add_domains(self, *domains): for domain in flatten(domains): - assert isinstance(domain, ClockDomain) + assert isinstance(domain, _cd.ClockDomain) assert domain.name not in self.domains self.domains[domain.name] = domain @@ -131,9 +128,9 @@ class Fragment: def add_statements(self, domain, *stmts): assert isinstance(domain, str) - for stmt in Statement.cast(stmts): + for stmt in _ast.Statement.cast(stmts): stmt._MustUse__used = True - self.statements.setdefault(domain, _StatementList()).append(stmt) + self.statements.setdefault(domain, _ast._StatementList()).append(stmt) def add_subfragment(self, subfragment, name=None): assert isinstance(subfragment, Fragment) @@ -186,15 +183,15 @@ class Fragment: assert mode in ("silent", "warn", "error") from ._mem import MemoryInstance - driver_subfrags = SignalDict() + driver_subfrags = _ast.SignalDict() def add_subfrag(registry, entity, entry): # Because of missing domain insertion, at the point when this code runs, we have # a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to # the actual signals (because the signal itself can be driven as well); but leave # the unbound ones as it is, because there's no concrete signal for it yet anyway. - if isinstance(entity, ClockSignal) and entity.domain in self.domains: + if isinstance(entity, _ast.ClockSignal) and entity.domain in self.domains: entity = self.domains[entity.domain].clk - elif isinstance(entity, ResetSignal) and entity.domain in self.domains: + elif isinstance(entity, _ast.ResetSignal) and entity.domain in self.domains: entity = self.domains[entity.domain].rst if entity not in registry: @@ -265,7 +262,7 @@ class Fragment: return self._resolve_hierarchy_conflicts(hierarchy, mode) # Nothing was flattened, we're done! - return SignalSet(driver_subfrags.keys()) + return _ast.SignalSet(driver_subfrags.keys()) def _propagate_domains_up(self, hierarchy=("top",)): from ._xfrm import DomainRenamer @@ -296,18 +293,18 @@ class Fragment: if not all(names): names = sorted(f"" if n is None else f"'{n}'" for f, n, i in subfrags) - raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; " - "it is necessary to either rename subfragment domains " - "explicitly, or give names to subfragments" - .format(domain_name, ", ".join(names), ".".join(hierarchy))) + raise _cd.DomainError( + "Domain '{}' is defined by subfragments {} of fragment '{}'; it is necessary " + "to either rename subfragment domains explicitly, or give names to subfragments" + .format(domain_name, ", ".join(names), ".".join(hierarchy))) if len(names) != len(set(names)): names = sorted(f"#{i}" for f, n, i in subfrags) - raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', " - "some of which have identical names; it is necessary to either " - "rename subfragment domains explicitly, or give distinct names " - "to subfragments" - .format(domain_name, ", ".join(names), ".".join(hierarchy))) + raise _cd.DomainError( + "Domain '{}' is defined by subfragments {} of fragment '{}', some of which " + "have identical names; it is necessary to either rename subfragment domains " + "explicitly, or give distinct names to subfragments" + .format(domain_name, ", ".join(names), ".".join(hierarchy))) for subfrag, name, i in subfrags: domain_name_map = {domain_name: f"{name}_{domain_name}"} @@ -343,8 +340,8 @@ class Fragment: continue value = missing_domain(domain_name) if value is None: - raise DomainError(f"Domain '{domain_name}' is used but not defined") - if type(value) is ClockDomain: + raise _cd.DomainError(f"Domain '{domain_name}' is used but not defined") + if type(value) is _cd.ClockDomain: self.add_domains(value) # And expose ports on the newly added clock domain, since it is added directly # and there was no chance to add any logic driving it. @@ -353,7 +350,7 @@ class Fragment: new_fragment = Fragment.get(value, platform=platform) if domain_name not in new_fragment.domains: defined = new_fragment.domains.keys() - raise DomainError( + raise _cd.DomainError( "Fragment returned by missing domain callback does not define " "requested domain '{}' (defines {})." .format(domain_name, ", ".join(f"'{n}'" for n in defined))) @@ -463,12 +460,12 @@ class Fragment: parent = {self: None} level = {self: 0} - uses = SignalDict() - defs = SignalDict() - ios = SignalDict() + uses = _ast.SignalDict() + defs = _ast.SignalDict() + ios = _ast.SignalDict() self._prepare_use_def_graph(parent, level, uses, defs, ios, self) - ports = SignalSet(ports) + ports = _ast.SignalSet(ports) if all_undef_as_ports: for sig in uses: if sig in defs: @@ -530,7 +527,7 @@ class Fragment: else: self.add_ports(sig, dir="i") - def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)): + def prepare(self, ports=None, missing_domain=lambda name: _cd.ClockDomain(name)): from ._xfrm import DomainLowerer new_domains = self._propagate_domains(missing_domain) @@ -541,14 +538,14 @@ class Fragment: if not isinstance(ports, tuple) and not isinstance(ports, list): msg = "`ports` must be either a list or a tuple, not {!r}"\ .format(ports) - if isinstance(ports, Value): + if isinstance(ports, _ast.Value): msg += " (did you mean `ports=(,)`, rather than `ports=`?)" raise TypeError(msg) mapped_ports = [] # Lower late bound signals like ClockSignal() to ports. port_lowerer = DomainLowerer(fragment.domains) for port in ports: - if not isinstance(port, (Signal, ClockSignal, ResetSignal)): + if not isinstance(port, (_ast.Signal, _ast.ClockSignal, _ast.ResetSignal)): raise TypeError("Only signals may be added as ports, not {!r}" .format(port)) mapped_ports.append(port_lowerer.on_value(port)) @@ -573,7 +570,7 @@ class Fragment: may get a different name. """ - signal_names = SignalDict() + signal_names = _ast.SignalDict() assigned_names = set() def add_signal_name(signal): @@ -599,7 +596,7 @@ class Fragment: for statements in self.statements.values(): for statement in statements: for signal in statement._lhs_signals() | statement._rhs_signals(): - if not isinstance(signal, (ClockSignal, ResetSignal)): + if not isinstance(signal, (_ast.ClockSignal, _ast.ResetSignal)): add_signal_name(signal) return signal_names @@ -657,7 +654,7 @@ class Instance(Fragment): elif kind == "p": self.parameters[name] = value elif kind in ("i", "o", "io"): - self.named_ports[name] = (Value.cast(value), kind) + self.named_ports[name] = (_ast.Value.cast(value), kind) else: raise NameError("Instance argument {!r} should be a tuple (kind, name, value) " "where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\"" @@ -669,12 +666,763 @@ class Instance(Fragment): elif kw.startswith("p_"): self.parameters[kw[2:]] = arg elif kw.startswith("i_"): - self.named_ports[kw[2:]] = (Value.cast(arg), "i") + self.named_ports[kw[2:]] = (_ast.Value.cast(arg), "i") elif kw.startswith("o_"): - self.named_ports[kw[2:]] = (Value.cast(arg), "o") + self.named_ports[kw[2:]] = (_ast.Value.cast(arg), "o") elif kw.startswith("io_"): - self.named_ports[kw[3:]] = (Value.cast(arg), "io") + self.named_ports[kw[3:]] = (_ast.Value.cast(arg), "io") else: raise NameError("Instance keyword argument {}={!r} does not start with one of " "\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\"" .format(kw, arg)) + + +############################################################################################### >:3 + + +class NetlistDriver: + def __init__(self, module_idx: int, signal: _ast.Signal, + domain: '_cd.ClockDomain | None', *, src_loc): + self.module_idx = module_idx + self.signal = signal + self.domain = domain + self.src_loc = src_loc + self.assignments = [] + + def emit_value(self, builder): + if self.domain is None: + reset = _ast.Const(self.signal.reset, self.signal.width) + default, _signed = builder.emit_rhs(self.module_idx, reset) + else: + default = builder.emit_signal(self.signal) + if len(self.assignments) == 1: + assign, = self.assignments + if assign.cond == 1 and assign.start == 0 and len(assign.value) == len(default): + return assign.value + cell = _nir.AssignmentList(self.module_idx, default=default, assignments=self.assignments, + src_loc=self.signal.src_loc) + return builder.netlist.add_value_cell(len(default), cell) + + +class NetlistEmitter: + def __init__(self, netlist: _nir.Netlist, fragment_names: 'dict[_ir.Fragment, str]'): + self.netlist = netlist + self.fragment_names = fragment_names + self.drivers = _ast.SignalDict() + self.rhs_cache: dict[int, Tuple[_nir.Value, bool, _ast.Value]] = {} + + # Collected for driver conflict diagnostics only. + self.late_net_to_signal = {} + self.connect_src_loc = {} + + def emit_signal(self, signal) -> _nir.Value: + if signal in self.netlist.signals: + return self.netlist.signals[signal] + value = self.netlist.alloc_late_value(len(signal)) + self.netlist.signals[signal] = value + for bit, net in enumerate(value): + self.late_net_to_signal[net] = (signal, bit) + return value + + # Used for instance outputs and read port data, not used for actual assignments. + def emit_lhs(self, value: _ast.Value): + if isinstance(value, _ast.Signal): + return self.emit_signal(value) + elif isinstance(value, _ast.Cat): + result = [] + for part in value.parts: + result += self.emit_lhs(part) + return _nir.Value(result) + elif isinstance(value, _ast.Slice): + return self.emit_lhs(value.value)[value.start:value.stop] + elif isinstance(value, _ast.Operator): + assert value.operator in ('u', 's') + return self.emit_lhs(value.operands[0]) + else: + raise TypeError # :nocov: + + def extend(self, value: _nir.Value, signed: bool, width: int): + nets = list(value) + while len(nets) < width: + if signed: + nets.append(nets[-1]) + else: + nets.append(_nir.Net.from_const(0)) + return _nir.Value(nets) + + def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc): + op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc) + return self.netlist.add_value_cell(op.width, op) + + def unify_shapes_bitwise(self, + operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool): + if signed_a == signed_b: + width = max(len(operand_a), len(operand_b)) + elif signed_a: + width = max(len(operand_a), len(operand_b) + 1) + else: # signed_b + width = max(len(operand_a) + 1, len(operand_b)) + operand_a = self.extend(operand_a, signed_a, width) + operand_b = self.extend(operand_b, signed_b, width) + signed = signed_a or signed_b + return (operand_a, operand_b, signed) + + def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]: + """Emits a RHS value, returns a tuple of (value, is_signed)""" + try: + result, signed, value = self.rhs_cache[id(value)] + return result, signed + except KeyError: + pass + if isinstance(value, _ast.Const): + result = _nir.Value( + _nir.Net.from_const((value.value >> bit) & 1) + for bit in range(value.width) + ) + signed = value.signed + elif isinstance(value, _ast.Signal): + result = self.emit_signal(value) + signed = value.signed + elif isinstance(value, _ast.Operator): + if len(value.operands) == 1: + operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0]) + if value.operator == 's': + result = operand_a + signed = True + elif value.operator == 'u': + result = operand_a + signed = False + elif value.operator == '+': + result = operand_a + signed = signed_a + elif value.operator == '-': + operand_a = self.extend(operand_a, signed_a, len(operand_a) + 1) + result = self.emit_operator(module_idx, '-', operand_a, + src_loc=value.src_loc) + signed = True + elif value.operator == '~': + result = self.emit_operator(module_idx, '~', operand_a, + src_loc=value.src_loc) + signed = signed_a + elif value.operator in ('b', 'r|', 'r&', 'r^'): + result = self.emit_operator(module_idx, value.operator, operand_a, + src_loc=value.src_loc) + signed = False + else: + assert False # :nocov: + elif len(value.operands) == 2: + operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0]) + operand_b, signed_b = self.emit_rhs(module_idx, value.operands[1]) + if value.operator in ('|', '&', '^'): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + result = self.emit_operator(module_idx, value.operator, operand_a, operand_b, + src_loc=value.src_loc) + elif value.operator in ('+', '-'): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + width = len(operand_a) + 1 + operand_a = self.extend(operand_a, signed, width) + operand_b = self.extend(operand_b, signed, width) + result = self.emit_operator(module_idx, value.operator, operand_a, operand_b, + src_loc=value.src_loc) + if value.operator == '-': + signed = True + elif value.operator == '*': + width = len(operand_a) + len(operand_b) + operand_a = self.extend(operand_a, signed_a, width) + operand_b = self.extend(operand_b, signed_b, width) + result = self.emit_operator(module_idx, '*', operand_a, operand_b, + src_loc=value.src_loc) + signed = signed_a or signed_b + elif value.operator == '//': + width = len(operand_a) + signed_b + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + if len(operand_a) < width: + operand_a = self.extend(operand_a, signed, width) + operand_b = self.extend(operand_b, signed, width) + operator = 's//' if signed else 'u//' + result = _nir.Value( + self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc)[:width] + ) + elif value.operator == '%': + width = len(operand_b) + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + operator = 's%' if signed else 'u%' + result = _nir.Value( + self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc)[:width] + ) + signed = signed_b + elif value.operator == '<<': + operand_a = self.extend(operand_a, signed_a, + len(operand_a) + 2 ** len(operand_b) - 1) + result = self.emit_operator(module_idx, '<<', operand_a, operand_b, + src_loc=value.src_loc) + signed = signed_a + elif value.operator == '>>': + operator = 's>>' if signed_a else 'u>>' + result = self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc) + signed = signed_a + elif value.operator in ('==', '!='): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + result = self.emit_operator(module_idx, value.operator, operand_a, operand_b, + src_loc=value.src_loc) + signed = False + elif value.operator in ('<', '>', '<=', '>='): + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + operator = ('s' if signed else 'u') + value.operator + result = self.emit_operator(module_idx, operator, operand_a, operand_b, + src_loc=value.src_loc) + signed = False + else: + assert False # :nocov: + elif len(value.operands) == 3: + assert value.operator == 'm' + operand_s, signed_s = self.emit_rhs(module_idx, value.operands[0]) + operand_a, signed_a = self.emit_rhs(module_idx, value.operands[1]) + operand_b, signed_b = self.emit_rhs(module_idx, value.operands[2]) + if len(operand_s) != 1: + operand_s = self.emit_operator(module_idx, 'b', operand_s, + src_loc=value.src_loc) + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + result = self.emit_operator(module_idx, 'm', operand_s, operand_a, operand_b, + src_loc=value.src_loc) + else: + assert False # :nocov: + elif isinstance(value, _ast.Slice): + inner, _signed = self.emit_rhs(module_idx, value.value) + result = _nir.Value(inner[value.start:value.stop]) + signed = False + elif isinstance(value, _ast.Part): + inner, signed = self.emit_rhs(module_idx, value.value) + offset, _signed = self.emit_rhs(module_idx, value.offset) + cell = _nir.Part(module_idx, value=inner, value_signed=signed, width=value.width, + stride=value.stride, offset=offset, src_loc=value.src_loc) + result = self.netlist.add_value_cell(value.width, cell) + signed = False + elif isinstance(value, _ast.ArrayProxy): + elems = [self.emit_rhs(module_idx, elem) for elem in value.elems] + width = 0 + signed = False + for elem, elem_signed in elems: + if elem_signed: + if not signed: + width += 1 + signed = True + width = max(width, len(elem)) + elif signed: + width = max(width, len(elem) + 1) + else: + width = max(width, len(elem)) + elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems) + index, _signed = self.emit_rhs(module_idx, value.index) + cell = _nir.ArrayMux(module_idx, width=width, elems=elems, index=index, + src_loc=value.src_loc) + result = self.netlist.add_value_cell(width, cell) + elif isinstance(value, _ast.Cat): + nets = [] + for val in value.parts: + inner, _signed = self.emit_rhs(module_idx, val) + for net in inner: + nets.append(net) + result = _nir.Value(nets) + signed = False + elif isinstance(value, _ast.AnyValue): + result = self.netlist.add_value_cell(value.width, + _nir.AnyValue(module_idx, kind=value.kind.value, width=value.width, + src_loc=value.src_loc)) + signed = value.signed + elif isinstance(value, _ast.Initial): + result = self.netlist.add_value_cell(1, _nir.Initial(module_idx, src_loc=value.src_loc)) + signed = False + else: + assert False # :nocov: + assert value.shape().width == len(result), \ + f"Value {value!r} with shape {value.shape()!r} does not match " \ + f"result with width {len(result)}" + # Add the value itself to the cache to make sure `id(value)` remains allocated and pointing + # at `value`. This would be a weakref.WeakKeyDictionary if `value` was hashable. + self.rhs_cache[id(value)] = result, signed, value + return (result, signed) + + def connect(self, lhs: _nir.Value, rhs: _nir.Value, *, src_loc): + assert len(lhs) == len(rhs) + for left, right in zip(lhs, rhs): + if left in self.netlist.connections: + signal, bit = self.late_net_to_signal[left] + other_src_loc = self.connect_src_loc[left] + raise _ir.DriverConflict(f"Bit {bit} of signal {signal!r} has multiple drivers: " + f"{other_src_loc} and {src_loc}") + self.netlist.connections[left] = right + self.connect_src_loc[left] = src_loc + + def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str, + stmt: _ast.Statement, cond: _nir.Net): + if domain == "comb": + cd: _cd.ClockDomain | None = None + else: + cd = fragment.domains[domain] + if isinstance(stmt, _ast.Assign): + if isinstance(stmt.lhs, _ast.Signal): + signal = stmt.lhs + start = 0 + width = signal.width + elif isinstance(stmt.lhs, _ast.Slice): + signal = stmt.lhs.value + start = stmt.lhs.start + width = stmt.lhs.stop - stmt.lhs.start + else: + assert False # :nocov: + assert isinstance(signal, _ast.Signal) + if signal in self.drivers: + driver = self.drivers[signal] + if driver.domain is not cd: + raise _ir.DriverConflict( + f"Signal {signal} driven from domain {cd} at {stmt.src_loc} and domain " + f"{driver.domain} at {driver.src_loc}") + if driver.module_idx != module_idx: + mod_name = ".".join(self.netlist.modules[module_idx].name or ("",)) + other_mod_name = \ + ".".join(self.netlist.modules[driver.module_idx].name or ("",)) + raise _ir.DriverConflict( + f"Signal {signal} driven from module {mod_name} at {stmt.src_loc} and " + f"module {other_mod_name} at {driver.src_loc}") + else: + driver = NetlistDriver(module_idx, signal, domain=cd, src_loc=stmt.src_loc) + self.drivers[signal] = driver + rhs, signed = self.emit_rhs(module_idx, stmt.rhs) + if len(rhs) > width: + rhs = _nir.Value(rhs[:width]) + if len(rhs) < width: + rhs = self.extend(rhs, signed, width) + driver.assignments.append(_nir.Assignment(cond=cond, start=start, value=rhs, + src_loc=stmt.src_loc)) + elif isinstance(stmt, _ast.Property): + test, _signed = self.emit_rhs(module_idx, stmt.test) + if len(test) != 1: + test = self.emit_operator(module_idx, 'b', test, src_loc=stmt.src_loc) + test, = test + en_cell = _nir.AssignmentList(module_idx, + default=_nir.Value.zeros(), + assignments=[ + _nir.Assignment(cond=cond, start=0, value=_nir.Value.ones(), + src_loc=stmt.src_loc) + ], + src_loc=stmt.src_loc) + cond, = self.netlist.add_value_cell(1, en_cell) + if cd is None: + cell = _nir.AsyncProperty(module_idx, kind=stmt.kind.value, test=test, en=cond, + name=stmt.name, src_loc=stmt.src_loc) + else: + clk, = self.emit_signal(cd.clk) + cell = _nir.SyncProperty(module_idx, kind=stmt.kind.value, test=test, en=cond, + clk=clk, clk_edge=cd.clk_edge, name=stmt.name, + src_loc=stmt.src_loc) + self.netlist.add_cell(cell) + elif isinstance(stmt, _ast.Switch): + test, _signed = self.emit_rhs(module_idx, stmt.test) + conds = [] + for patterns in stmt.cases: + if patterns: + for pattern in patterns: + assert len(pattern) == len(test) + cell = _nir.Matches(module_idx, value=test, patterns=patterns, + src_loc=stmt.case_src_locs.get(patterns)) + net, = self.netlist.add_value_cell(1, cell) + conds.append(net) + else: + conds.append(_nir.Net.from_const(1)) + cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds), + src_loc=stmt.src_loc) + conds = self.netlist.add_value_cell(len(conds), cell) + for subcond, substmts in zip(conds, stmt.cases.values()): + for substmt in substmts: + self.emit_stmt(module_idx, fragment, domain, substmt, subcond) + else: + assert False # :nocov: + + def emit_tribuf(self, module_idx: int, instance: _ir.Instance): + pad = self.emit_lhs(instance.named_ports["Y"][0]) + o, _signed = self.emit_rhs(module_idx, instance.named_ports["A"][0]) + (oe,), _signed = self.emit_rhs(module_idx, instance.named_ports["EN"][0]) + assert len(pad) == len(o) + cell = _nir.IOBuffer(module_idx, pad=pad, o=o, oe=oe, src_loc=instance.src_loc) + self.netlist.add_cell(cell) + + def emit_memory(self, module_idx: int, fragment: '_mem.MemoryInstance', name: str): + cell = _nir.Memory(module_idx, + width=fragment._width, + depth=fragment._depth, + init=fragment._init, + name=name, + attributes=fragment._attrs, + src_loc=fragment._src_loc, + ) + return self.netlist.add_cell(cell) + + def emit_write_port(self, module_idx: int, fragment: '_mem.MemoryInstance', + port: '_mem.MemoryInstance._WritePort', memory: int): + data, _signed = self.emit_rhs(module_idx, port._data) + addr, _signed = self.emit_rhs(module_idx, port._addr) + en, _signed = self.emit_rhs(module_idx, port._en) + en = _nir.Value([en[bit // port._granularity] for bit in range(len(port._data))]) + cd = fragment.domains[port._domain] + clk, = self.emit_signal(cd.clk) + cell = _nir.SyncWritePort(module_idx, + memory=memory, + data=data, + addr=addr, + en=en, + clk=clk, + clk_edge=cd.clk_edge, + src_loc=port._data.src_loc, + ) + return self.netlist.add_cell(cell) + + def emit_read_port(self, module_idx: int, fragment: '_mem.MemoryInstance', + port: '_mem.MemoryInstance._ReadPort', memory: int, + write_ports: 'list[int]'): + addr, _signed = self.emit_rhs(module_idx, port._addr) + if port._domain == "comb": + cell = _nir.AsyncReadPort(module_idx, + memory=memory, + width=len(port._data), + addr=addr, + src_loc=port._data.src_loc, + ) + else: + (en,), _signed = self.emit_rhs(module_idx, port._en) + cd = fragment.domains[port._domain] + clk, = self.emit_signal(cd.clk) + cell = _nir.SyncReadPort(module_idx, + memory=memory, + width=len(port._data), + addr=addr, + en=en, + clk=clk, + clk_edge=cd.clk_edge, + transparent_for=tuple(write_ports[idx] for idx in port._transparency), + src_loc=port._data.src_loc, + ) + data = self.netlist.add_value_cell(len(port._data), cell) + self.connect(self.emit_lhs(port._data), data, src_loc=port._data.src_loc) + + def emit_instance(self, module_idx: int, instance: _ir.Instance, name: str): + ports_i = {} + ports_o = {} + ports_io = {} + outputs = [] + next_output_bit = 0 + for port_name, (port_conn, dir) in instance.named_ports.items(): + if dir == 'i': + ports_i[port_name], _signed = self.emit_rhs(module_idx, port_conn) + elif dir == 'o': + port_conn = self.emit_lhs(port_conn) + ports_o[port_name] = (next_output_bit, len(port_conn)) + outputs.append((next_output_bit, port_conn)) + next_output_bit += len(port_conn) + elif dir == 'io': + ports_io[port_name] = self.emit_lhs(port_conn) + else: + assert False # :nocov: + cell = _nir.Instance(module_idx, + type=instance.type, + name=name, + parameters=instance.parameters, + attributes=instance.attrs, + ports_i=ports_i, + ports_o=ports_o, + ports_io=ports_io, + src_loc=instance.src_loc, + ) + output_nets = self.netlist.add_value_cell(width=next_output_bit, cell=cell) + for start_bit, port_conn in outputs: + self.connect(port_conn, _nir.Value(output_nets[start_bit:start_bit + len(port_conn)]), + src_loc=instance.src_loc) + + def emit_top_ports(self, fragment: _ir.Fragment, signal_names: _ast.SignalDict): + next_input_bit = 2 # 0 and 1 are reserved for constants + top = self.netlist.top + for signal, dir in fragment.ports.items(): + assert signal not in self.netlist.signals + name = signal_names[signal] + if dir == 'i': + top.ports_i[name] = (next_input_bit, signal.width) + nets = _nir.Value( + _nir.Net.from_cell(0, bit) + for bit in range(next_input_bit, next_input_bit + signal.width) + ) + next_input_bit += signal.width + self.netlist.signals[signal] = nets + elif dir == 'o': + top.ports_o[name] = self.emit_signal(signal) + elif dir == 'io': + top.ports_io[name] = (next_input_bit, signal.width) + nets = _nir.Value( + _nir.Net.from_cell(0, bit) + for bit in range(next_input_bit, next_input_bit + signal.width) + ) + next_input_bit += signal.width + self.netlist.signals[signal] = nets + + def emit_drivers(self): + for driver in self.drivers.values(): + value = driver.emit_value(self) + if driver.domain is not None: + clk, = self.emit_signal(driver.domain.clk) + if driver.domain.rst is not None and driver.domain.async_reset: + arst, = self.emit_signal(driver.domain.rst) + else: + arst = _nir.Net.from_const(0) + cell = _nir.FlipFlop(driver.module_idx, + data=value, + init=driver.signal.reset, + clk=clk, + clk_edge=driver.domain.clk_edge, + arst=arst, + attributes=driver.signal.attrs, + src_loc=driver.signal.src_loc, + ) + value = self.netlist.add_value_cell(len(value), cell) + if driver.assignments: + src_loc = driver.assignments[0].src_loc + else: + src_loc = driver.signal.src_loc + self.connect(self.emit_signal(driver.signal), value, src_loc=src_loc) + + # Connect all undriven signal bits to their reset values. This can only happen for entirely + # undriven signals, or signals that are partially driven by instances. + for signal, value in self.netlist.signals.items(): + for bit, net in enumerate(value): + if net.is_late and net not in self.netlist.connections: + self.netlist.connections[net] = _nir.Net.from_const((signal.reset >> bit) & 1) + + def emit_fragment(self, fragment: _ir.Fragment, parent_module_idx: 'int | None'): + from . import _mem + + fragment_name = self.fragment_names[fragment] + if isinstance(fragment, _ir.Instance): + assert parent_module_idx is not None + if fragment.type == "$tribuf": + self.emit_tribuf(parent_module_idx, fragment) + else: + self.emit_instance(parent_module_idx, fragment, name=fragment_name[-1]) + elif isinstance(fragment, _mem.MemoryInstance): + assert parent_module_idx is not None + memory = self.emit_memory(parent_module_idx, fragment, name=fragment_name[-1]) + write_ports = [] + for port in fragment._write_ports: + write_ports.append(self.emit_write_port(parent_module_idx, fragment, port, memory)) + for port in fragment._read_ports: + self.emit_read_port(parent_module_idx, fragment, port, memory, write_ports) + elif type(fragment) is _ir.Fragment: + module_idx = self.netlist.add_module(parent_module_idx, fragment_name) + signal_names = fragment._assign_names_to_signals() + self.netlist.modules[module_idx].signal_names = signal_names + if parent_module_idx is None: + self.emit_top_ports(fragment, signal_names) + for signal in signal_names: + self.emit_signal(signal) + for domain, stmts in fragment.statements.items(): + for stmt in stmts: + self.emit_stmt(module_idx, fragment, domain, stmt, _nir.Net.from_const(1)) + for subfragment, _name in fragment.subfragments: + self.emit_fragment(subfragment, module_idx) + if parent_module_idx is None: + self.emit_drivers() + else: + assert False # :nocov: + + +def _emit_netlist(netlist: _nir.Netlist, fragment, hierarchy): + fragment_names = fragment._assign_names_to_fragments(hierarchy) + NetlistEmitter(netlist, fragment_names).emit_fragment(fragment, None) + + +def _compute_net_flows(netlist: _nir.Netlist): + # Computes the net flows for all modules of the netlist. + # + # The rules for net flows are as follows: + # + # - the modules that have a given net in their net_flow form a subtree of the hierarchy + # - INTERNAL is used in the root of the subtree and nowhere else + # - OUTPUT is used for modules that contain the definition of the net, or are on the + # path from the definition to the root + # - remaining modules have a flow of INPUT (unless the net is a top-level inout port, + # in which case it is INOUT) + # + # In other words, the tree looks something like this: + # + # - [no flow] <<< top + # - [no flow] + # - INTERNAL + # - INPUT << use + # - [no flow] + # - INPUT + # - INPUT << use + # - OUTPUT + # - INPUT << use + # - [no flow] + # - OUTPUT << def + # - INPUT + # - INPUT + # - [no flow] + # - [no flow] + # - [no flow] + # + # This function doesn't assign the INOUT flow — that is corrected later, in compute_ports. + lca = {} + + # Initialize by marking the definition point of every net. + for cell_idx, cell in enumerate(netlist.cells): + for net in cell.output_nets(cell_idx): + lca[net] = cell.module_idx + netlist.modules[cell.module_idx].net_flow[net] = _nir.ModuleNetFlow.INTERNAL + + # Marks a use of a net within a given module, and adjusts its netflows in all modules + # as required. + def use_net(net, use_module): + if net.is_const: + return + # If the net is already present in the current module, we're done. + if net in netlist.modules[use_module].net_flow: + return + modules = netlist.modules + # Otherwise, we need to route the net through the hierarchy from def_module + # to use_module. We do that by treating use_module and def_module as pointers + # and moving them up the hierarchy until they meet at the new LCA. + def_module = lca[net] + # While def_module deeper than use_module, go up with def_module. + while len(modules[def_module].name) > len(modules[use_module].name): + modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT + def_module = modules[def_module].parent + # While use_module deeper than def_module, go up with use_module. + # If use_module is below def_module in the hierarchy, we may hit + # another module which already uses this net before hitting def_module, + # so check for this case. + while len(modules[def_module].name) < len(modules[use_module].name): + if net in modules[use_module].net_flow: + return + modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT + use_module = modules[use_module].parent + # Now both pointers should be at the same depth within the hierarchy. + assert len(modules[def_module].name) == len(modules[use_module].name) + # Move both pointers up until they meet. + while def_module != use_module: + modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT + def_module = modules[def_module].parent + modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT + use_module = modules[use_module].parent + assert len(modules[def_module].name) == len(modules[use_module].name) + # And mark the new LCA. + modules[def_module].net_flow[net] = _nir.ModuleNetFlow.INTERNAL + lca[net] = def_module + + # Now mark all uses and flesh out the structure. + for cell in netlist.cells: + for net in cell.input_nets(): + use_net(net, cell.module_idx) + # TODO: ? + for module_idx, module in enumerate(netlist.modules): + for signal in module.signal_names: + for net in netlist.signals[signal]: + use_net(net, module_idx) + + +def _compute_ports(netlist: _nir.Netlist): + # Compute the indexes at which the outputs of a cell should be split to create a distinct port. + # These indexes are stored here as nets. + port_starts = set() + for start, _ in netlist.top.ports_i.values(): + port_starts.add(_nir.Net.from_cell(0, start)) + for start, width in netlist.top.ports_io.values(): + port_starts.add(_nir.Net.from_cell(0, start)) + for cell_idx, cell in enumerate(netlist.cells): + if isinstance(cell, _nir.Instance): + for start, _ in cell.ports_o.values(): + port_starts.add(_nir.Net.from_cell(cell_idx, start)) + + # Compute the set of all inout nets. Currently, a net has inout flow iff it is connected to + # a toplevel inout port. + inouts = set() + for start, width in netlist.top.ports_io.values(): + for idx in range(start, start + width): + inouts.add(_nir.Net.from_cell(0, idx)) + + for module in netlist.modules: + # Collect preferred names for ports. If a port exactly matches a signal, we reuse + # the signal name for the port. Otherwise, we synthesize a private name. + name_table = {} + for signal, name in module.signal_names.items(): + value = netlist.signals[signal] + if value not in name_table and not name.startswith('$'): + name_table[value] = name + + # Gather together "adjacent" nets with the same flow into ports. + visited = set() + for net in sorted(module.net_flow): + flow = module.net_flow[net] + if flow == _nir.ModuleNetFlow.INTERNAL: + continue + if flow == _nir.ModuleNetFlow.INPUT and net in inouts: + flow = module.net_flow[net] = _nir.ModuleNetFlow.INOUT + if net in visited: + continue + # We found a net that needs a port. Keep joining the next nets output by the same + # cell into the same port, if applicable, but stop at instance/top port boundaries. + nets = [net] + while True: + succ = _nir.Net.from_cell(net.cell, net.bit + 1) + if succ in port_starts: + break + if succ not in module.net_flow: + break + if module.net_flow[succ] != module.net_flow[net]: + break + net = succ + nets.append(net) + value = _nir.Value(nets) + # Joined as many nets as we could, now name and add the port. + if value in name_table: + name = name_table[value] + else: + name = f"port${value[0].cell}${value[0].bit}" + module.ports[name] = (value, flow) + visited.update(value) + + # The 0th cell and the 0th module correspond to the toplevel. Transfer the net flows from + # the toplevel cell (used for data flow) to the toplevel module (used to split netlist into + # modules in the backends). + top_module = netlist.modules[0] + for name, (start, width) in netlist.top.ports_i.items(): + top_module.ports[name] = ( + _nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)), + _nir.ModuleNetFlow.INPUT + ) + for name, (start, width) in netlist.top.ports_io.items(): + top_module.ports[name] = ( + _nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)), + _nir.ModuleNetFlow.INOUT + ) + for name, value in netlist.top.ports_o.items(): + top_module.ports[name] = (value, _nir.ModuleNetFlow.OUTPUT) + + +def build_netlist(fragment, *, name="top"): + from ._xfrm import AssignmentLegalizer + + fragment = AssignmentLegalizer()(fragment) + netlist = _nir.Netlist() + _emit_netlist(netlist, fragment, hierarchy=(name,)) + netlist.resolve_all_nets() + _compute_net_flows(netlist) + _compute_ports(netlist) + return netlist diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py new file mode 100644 index 0000000..f3c0c0f --- /dev/null +++ b/amaranth/hdl/_nir.py @@ -0,0 +1,1003 @@ +from typing import Iterable +import enum + +from ._ast import SignalDict + + +__all__ = [ + # Netlist core + "Net", "Value", "Netlist", "ModuleNetFlow", "Module", "Cell", "Top", + # Computation cells + "Operator", "Part", "ArrayMux", + # Decision tree cells + "Matches", "PriorityMatch", "Assignment", "AssignmentList", + # Storage cells + "FlipFlop", "Memory", "SyncWritePort", "AsyncReadPort", "SyncReadPort", + # Formal verification cells + "Initial", "AnyValue", "AsyncProperty", "SyncProperty", + # Foreign interface cells + "Instance", "IOBuffer", +] + + +class Net(int): + @classmethod + def from_cell(cls, cell: int, bit: int): + assert bit in range(1 << 16) + assert cell >= 0 + if cell == 0: + assert bit >= 2 + return cls((cell << 16) | bit) + + @classmethod + def from_const(cls, val: int): + assert val in (0, 1) + return cls(val) + + @classmethod + def from_late(cls, val: int): + assert val < 0 + return cls(val) + + @property + def is_const(self): + return self in (0, 1) + + @property + def const(self): + assert self in (0, 1) + return int(self) + + @property + def is_late(self): + return self < 0 + + @property + def is_cell(self): + return self >= 2 + + @property + def cell(self): + assert self >= 2 + return self >> 16 + + @property + def bit(self): + assert self >= 2 + return self & 0xffff + + @classmethod + def ensure(cls, value: 'Net'): + assert isinstance(value, cls) + return value + + +class Value(tuple): + def __new__(cls, nets: 'Net | Iterable[Net]' = ()): + if isinstance(nets, Net): + return super().__new__(cls, (nets,)) + return super().__new__(cls, (Net.ensure(net) for net in nets)) + + @classmethod + def zeros(cls, digits=1): + return cls(Net.from_const(0) for _ in range(digits)) + + @classmethod + def ones(cls, digits=1): + return cls(Net.from_const(1) for _ in range(digits)) + + +class Netlist: + """A fine netlist. Consists of: + + - a flat array of cells + - a dictionary of connections for late-bound nets + - a map of hierarchical names to nets + - a map of signals to nets + + The nets are virtual: a list of nets is not materialized anywhere in the netlist. + A net is a single bit wide and represented as a single int. The int is encoded as follows: + + - A negative number means a late-bound net. The net should be looked up in the ``connections`` + dictionary to find its driver. + - Non-negative numbers are cell outputs, and are split into bitfields as follows: + + - bits 0-15: output bit index within a cell (exact meaning is cell-dependent) + - bits 16-...: index of cell in ``netlist.cells`` + + Cell 0 is always ``Top``. The first two output bits of ``Top`` are considered to be constants + ``0`` and ``1``, which effectively means that net encoded as ``0`` is always a constant ``0`` and + net encoded as ``1`` is always a constant ``1``. + + Multi-bit values are represented as tuples of int. + + Attributes + ---------- + + cells : list of ``Cell`` + connections : dict of (negative) int to int + signals : dict of Signal to ``Value`` + """ + def __init__(self): + self.modules: list[Module] = [] + self.cells: list[Cell] = [Top()] + self.connections: dict[Net, Net] = {} + self.signals = SignalDict() + self.last_late_net = 0 + + def resolve_net(self, net: Net): + assert isinstance(net, Net) + while net.is_late: + net = self.connections[net] + return net + + def resolve_value(self, value: Value): + return Value(self.resolve_net(net) for net in value) + + def resolve_all_nets(self): + for cell in self.cells: + cell.resolve_nets(self) + for sig in self.signals: + self.signals[sig] = self.resolve_value(self.signals[sig]) + + def __str__(self): + def net_to_str(net): + net = self.resolve_net(net) + if net.is_const: + return f"{net.const}" + return f"{net.cell}.{net.bit}" + def val_to_str(val): + return "{" + " ".join(net_to_str(x) for x in val) + "}" + result = [] + for module_idx, module in enumerate(self.modules): + result.append(f"module {module_idx} [parent {module.parent}]: {' '.join(module.name)}") + if module.submodules: + result.append(f" submodules {' '.join(str(x) for x in module.submodules)} ") + result.append(f" cells {' '.join(str(x) for x in module.cells)} ") + for name, (val, flow) in module.ports.items(): + result.append(f" port {name} {val_to_str(val)}: {flow}") + for cell_idx, cell in enumerate(self.cells): + result.append(f"cell {cell_idx} [module {cell.module_idx}]: ") + if isinstance(cell, Top): + result.append("top") + for name, val in cell.ports_o.items(): + result.append(f" output {name}: {val_to_str(val)}") + for name, (start, num) in cell.ports_i.items(): + result.append(f" input {name}: 0.{start}..0.{start+num-1}") + for name, (start, num) in cell.ports_io.items(): + result.append(f" inout {name}: 0.{start}..0.{start+num-1}") + elif isinstance(cell, Matches): + result.append(f"matches {val_to_str(cell.value)}, {' | '.join(cell.patterns)}") + elif isinstance(cell, PriorityMatch): + result.append(f"priority_match {net_to_str(cell.en)}, {val_to_str(cell.inputs)}") + elif isinstance(cell, AssignmentList): + result.append(f"list {val_to_str(cell.default)}") + for assign in cell.assignments: + result.append(f" if {net_to_str(assign.cond)} start {assign.start} <- {val_to_str(assign.value)}") + elif isinstance(cell, Operator): + inputs = ", ".join(val_to_str(input) for input in cell.inputs) + result.append(f"{cell.operator} {inputs}") + elif isinstance(cell, Part): + result.append(f"part {val_to_str(cell.value)}, {val_to_str(cell.offset)}, {cell.width}, {cell.stride}, {cell.value_signed}") + elif isinstance(cell, ArrayMux): + result.append(f"array {cell.width}, {val_to_str(cell.index)}, {', '.join(val_to_str(elem) for elem in cell.elems)}") + elif isinstance(cell, FlipFlop): + result.append(f"ff {val_to_str(cell.data)} {cell.init} @{cell.clk_edge}edge {net_to_str(cell.clk)} {net_to_str(cell.arst)}") + for attr_name, attr_value in cell.attributes.items(): + result.append(f" attribute {attr_name} {attr_value}") + elif isinstance(cell, Memory): + result.append(f"memory {cell.name} {cell.width} {cell.depth} {cell.init}") + for attr_name, attr_value in cell.attributes.items(): + result.append(f" attribute {attr_name} {attr_value}") + elif isinstance(cell, SyncWritePort): + result.append(f"wrport {cell.memory} {val_to_str(cell.data)} {val_to_str(cell.addr)} {val_to_str(cell.en)} @{cell.clk_edge}edge {net_to_str(cell.clk)}") + elif isinstance(cell, AsyncReadPort): + result.append(f"rdport {cell.memory} {cell.width} {val_to_str(cell.addr)}") + elif isinstance(cell, SyncReadPort): + result.append(f"rdport {cell.memory} {cell.width} {val_to_str(cell.addr)} {net_to_str(cell.en)} @{cell.clk_edge}edge {net_to_str(cell.clk)}") + for port in cell.transparent_for: + result.append(f" transparent {port}") + elif isinstance(cell, Initial): + result.append("initial") + elif isinstance(cell, AnyValue): + result.append("{cell.kind} {cell.width}") + elif isinstance(cell, AsyncProperty): + result.append(f"{cell.kind} {cell.name!r} {net_to_str(cell.test)} {net_to_str(cell.en)}") + elif isinstance(cell, SyncProperty): + result.append(f"{cell.kind} {cell.name!r} {net_to_str(cell.test)} {net_to_str(cell.en)} @{cell.clk_edge}edge {net_to_str(cell.clk)}") + elif isinstance(cell, Instance): + result.append("instance {cell.type} {cell.name}") + for attr_name, attr_value in cell.attributes.items(): + result.append(f" attribute {attr_name} {attr_value}") + for attr_name, attr_value in cell.parameters.items(): + result.append(f" parameter {attr_name} {attr_value}") + for attr_name, (start, num) in cell.ports_o.items(): + result.append(f" output {attr_name} {cell_idx}.{start}..{cell_idx}.{start+num-1}") + for attr_name, attr_value in cell.ports_i.items(): + result.append(f" input {attr_name} {val_to_str(attr_value)}") + for attr_name, attr_value in cell.ports_io.items(): + result.append(f" inout {attr_name} {val_to_str(attr_value)}") + elif isinstance(cell, IOBuffer): + result.append(f"iob {val_to_str(cell.pad)} {val_to_str(cell.o)} {net_to_str(cell.oe)}") + else: + assert False # :nocov: + return "\n".join(result) + + def add_module(self, parent, name: str): + module_idx = len(self.modules) + self.modules.append(Module(parent, name)) + if module_idx == 0: + self.modules[0].cells.append(0) + if parent is not None: + self.modules[parent].submodules.append(module_idx) + return module_idx + + def add_cell(self, cell): + idx = len(self.cells) + self.cells.append(cell) + self.modules[cell.module_idx].cells.append(idx) + return idx + + def add_value_cell(self, width: int, cell): + cell_idx = self.add_cell(cell) + return Value(Net.from_cell(cell_idx, bit) for bit in range(width)) + + def alloc_late_value(self, width: int): + self.last_late_net -= width + return Value(Net.from_late(self.last_late_net + bit) for bit in range(width)) + + @property + def top(self): + top = self.cells[0] + assert isinstance(top, Top) + return top + + +class ModuleNetFlow(enum.Enum): + """Describes how a given Net flows into or out of a Module. + + The net can also be none of these (not present in the dictionary at all), + when it is not present in the module at all. + """ + + #: The net is present in the module (used in the module or needs + #: to be routed through it between its submodules), but is not + #: present outside its subtree and thus is not a port of this module. + INTERNAL = "internal" + + #: The net is present in the module, and is not driven from + #: the module or any of its submodules. It is thus an input + #: port of this module. + INPUT = "input" + + #: The net is present in the module, is driven from the module or + #: one of its submodules, and is also used outside of its subtree. + #: It is thus an output port of this module. + OUTPUT = "output" + + #: The net is a special top-level inout net that is used within + #: this module or its submodules. It is an inout port of this module. + INOUT = "inout" + + +class Module: + """A module within the netlist. + + Attributes + ---------- + + parent: index of parent module, or ``None`` for top module + name: a tuple of str, hierarchical name of this module (top has empty tuple) + submodules: a list of nested module indices + signal_names: a SignalDict from Signal to str, signal names visible in this module + net_flow: a dict from Net to NetFlow, describes how a net is used within this module + ports: a dict from port name to (Value, NetFlow) pair + cells: a list of cell indices that belong to this module + """ + def __init__(self, parent, name): + self.parent = parent + self.name = name + self.submodules = [] + self.signal_names = SignalDict() + self.net_flow = {} + self.ports = {} + self.cells = [] + + +class Cell: + """A base class for all cell types. + + Attributes + ---------- + + src_loc: str + module: int, index of the module this cell belongs to (within Netlist.modules) + """ + + def __init__(self, module_idx: int, *, src_loc): + self.module_idx = module_idx + self.src_loc = src_loc + + def input_nets(self): + raise NotImplementedError + + def output_nets(self, self_idx: int): + raise NotImplementedError + + def resolve_nets(self, netlist: Netlist): + raise NotImplementedError + + +class Top(Cell): + """A special cell type representing top-level ports. Must be present in the netlist exactly + once, at index 0. + + Top-level outputs are stored as a dict of names to their assigned values. + + Top-level inputs and inouts are effectively the output of this cell. They are both stored + as a dict of names to a (start bit index, width) tuple. Output bit indices 0 and 1 are reserved + for constant nets, so the lowest bit index that can be assigned to a port is 2. + + Top-level inouts are special and can only be used by inout ports of instances, or in the pad + value of an ``IoBuf`` cell. + + Attributes + ---------- + + ports_o: dict of str to Value + ports_i: dict of str to (int, int) + ports_io: dict of str to (int, int) + """ + def __init__(self): + super().__init__(module_idx=0, src_loc=None) + + self.ports_o = {} + self.ports_i = {} + self.ports_io = {} + + def input_nets(self): + nets = set() + for value in self.ports_o.values(): + nets |= set(value) + return nets + + def output_nets(self, self_idx: int): + nets = set() + for start, width in self.ports_i.values(): + for bit in range(start, start + width): + nets.add(Net.from_cell(self_idx, bit)) + for start, width in self.ports_io.values(): + for bit in range(start, start + width): + nets.add(Net.from_cell(self_idx, bit)) + return nets + + def resolve_nets(self, netlist: Netlist): + for port in self.ports_o: + self.ports_o[port] = netlist.resolve_value(self.ports_o[port]) + + +class Operator(Cell): + """Roughly corresponds to ``hdl.ast.Operator``. + + The available operators are roughly the same as in AST, with some changes: + + - '<', '>', '<=', '>=', '//', '%', '>>' have signed and unsigned variants that are selected + by prepending 'u' or 's' to operator name + - 's', 'u', and unary '+' are redundant and do not exist + - many operators restrict input widths to be the same as output width, + and/or to be the same as each other + + The unary operators are: + + - '-', '~': like AST, input same width as output + - 'b', 'r|', 'r&', 'r^': like AST, 1-bit output + + The binary operators are: + + - '+', '-', '*', '&', '^', '|', 'u//', 's//', 'u%', 's%': like AST, both inputs same width as output + - '<<', 'u>>', 's>>': like AST, first input same width as output + - '==', '!=', 'u<', 's<', 'u>', 's>', 'u<=', 's<=', 'u>=', 's>=': like AST, both inputs need to have + the same width, 1-bit output + + The ternary operators are: + + - 'm': like AST, first input needs to have width of 1, second and third operand need to have the same + width as output + + Attributes + ---------- + + operator: str, symbol of the operator (from the above list) + inputs: tuple of Value + """ + + def __init__(self, module_idx, *, operator: str, inputs, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.operator = operator + self.inputs = tuple(Value(input) for input in inputs) + + @property + def width(self): + if self.operator in ('~', '-', '+', '*', '&', '^', '|', 'u//', 's//', 'u%', 's%', '<<', 'u>>', 's>>'): + return len(self.inputs[0]) + elif self.operator in ('b', 'r&', 'r^', 'r|', '==', '!=', 'u<', 's<', 'u>', 's>', 'u<=', 's<=', 'u>=', 's>='): + return 1 + elif self.operator == 'm': + return len(self.inputs[1]) + else: + assert False # :nocov: + + def input_nets(self): + nets = set() + for value in self.inputs: + nets |= set(value) + return nets + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.inputs = tuple(netlist.resolve_value(val) for val in self.inputs) + + +class Part(Cell): + """Corresponds to ``hdl.ast.Part``. + + Attributes + ---------- + + value: Value, the data input + value_signed: bool + offset: Value, the offset input + width: int + stride: int + """ + def __init__(self, module_idx, *, value, value_signed, offset, width, stride, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert type(width) is int + assert type(stride) is int + + self.value = Value(value) + self.value_signed = value_signed + self.offset = Value(offset) + self.width = width + self.stride = stride + + def input_nets(self): + return set(self.value) | set(self.offset) + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.value = netlist.resolve_value(self.value) + self.offset = netlist.resolve_value(self.offset) + + +class ArrayMux(Cell): + """Corresponds to ``hdl.ast.ArrayProxy``. All values in the ``elems`` array need to have + the same width as the output. + + Attributes + ---------- + + width: int (width of output and all inputs) + elems: tuple of Value + index: Value + """ + def __init__(self, module_idx, *, width, elems, index, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.width = width + self.elems = tuple(Value(val) for val in elems) + self.index = Value(index) + + def input_nets(self): + nets = set(self.index) + for value in self.elems: + nets |= set(value) + return nets + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.elems = tuple(netlist.resolve_value(val) for val in self.elems) + self.index = netlist.resolve_value(self.index) + + +class Matches(Cell): + """A combinatorial cell performing a comparison like ``Value.matches`` + (or, equivalently, a case condition). + + Attributes + ---------- + + value: Value + patterns: tuple of str, each str contains '0', '1', '-' + """ + def __init__(self, module_idx, *, value, patterns, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.value = Value(value) + self.patterns = tuple(patterns) + + def input_nets(self): + return set(self.value) + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, 0)} + + def resolve_nets(self, netlist: Netlist): + self.value = netlist.resolve_value(self.value) + + +class PriorityMatch(Cell): + """Used to represent a single switch on the control plane of processes. + + The output is the same length as ``inputs``. If ``en`` is ``0``, the output + is all-0. Otherwise, output keeps the lowest-numbered ``1`` bit in the input + (if any) and masks all other bits to ``0``. + + Note: the RTLIL backend requires all bits of ``inputs`` to be driven + by a ``Match`` cell within the same module. + + Attributes + ---------- + en: Net + inputs: Value + """ + def __init__(self, module_idx, *, en, inputs, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.en = Net.ensure(en) + self.inputs = Value(inputs) + + def input_nets(self): + return set(self.inputs) | {self.en} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.inputs))} + + def resolve_nets(self, netlist: Netlist): + self.en = netlist.resolve_net(self.en) + self.inputs = netlist.resolve_value(self.inputs) + + +class Assignment: + """A single assignment in an ``AssignmentList``. + + The assignment is executed iff ``cond`` is true. When the assignment + is executed, ``len(value)`` bits starting at position `offset` are set + to the value ``value``, and the remaining bits are unchanged. + Assignments to out-of-bounds bit indices are ignored. + + Attributes + ---------- + + cond: Net + start: int + value: Value + src_loc: str + """ + def __init__(self, *, cond, start, value, src_loc): + assert isinstance(start, int) + self.cond = Net.ensure(cond) + self.start = start + self.value = Value(value) + self.src_loc = src_loc + + def resolve_nets(self, netlist: Netlist): + self.cond = netlist.resolve_net(self.cond) + self.value = netlist.resolve_value(self.value) + + +class AssignmentList(Cell): + """Used to represent a single assigned signal on the data plane of processes. + + The output of this cell is determined by starting with the ``default`` value, + then executing each assignment in sequence. + + Note: the RTLIL backend requires all ``cond`` inputs of assignments to be driven + by a ``PriorityMatch`` cell within the same module. + + Attributes + ---------- + default: Value + assignments: tuple of ``Assignment`` + """ + def __init__(self, module_idx, *, default, assignments, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assignments = tuple(assignments) + for assign in assignments: + assert isinstance(assign, Assignment) + + self.default = Value(default) + self.assignments: tuple[Assignment, ...] = assignments + + def input_nets(self): + nets = set(self.default) + for assign in self.assignments: + nets.add(assign.cond) + nets |= set(assign.value) + return nets + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.default))} + + def resolve_nets(self, netlist: Netlist): + for assign in self.assignments: + assign.resolve_nets(netlist) + self.default = netlist.resolve_value(self.default) + + +class FlipFlop(Cell): + """A flip-flop. ``data`` is the data input. ``init`` is the initial and async reset value. + ``clk`` and ``clk_edge`` work as in a ``ClockDomain``. ``arst`` is the async reset signal, + or ``0`` if async reset is not used. + + Attributes + ---------- + + data: Value + init: int + clk: Net + clk_edge: str, either 'pos' or 'neg' + arst: Net + attributes: dict from str to int, Const, or str + """ + def __init__(self, module_idx, *, data, init, clk, clk_edge, arst, attributes, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert clk_edge in ('pos', 'neg') + assert type(init) is int + + self.data = Value(data) + self.init = init + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + self.arst = Net.ensure(arst) + self.attributes = attributes + + def input_nets(self): + return set(self.data) | {self.clk, self.arst} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.data))} + + def resolve_nets(self, netlist: Netlist): + self.data = netlist.resolve_value(self.data) + self.clk = netlist.resolve_net(self.clk) + self.arst = netlist.resolve_net(self.arst) + + +class Memory(Cell): + """Corresponds to ``Memory``. ``init`` must have length equal to ``depth``. + Read and write ports are separate cells. + + Attributes + ---------- + + width: int + depth: int + init: tuple of int + name: str + attributes: dict from str to int, Const, or str + """ + def __init__(self, module_idx, *, width, depth, init, name, attributes, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.width = width + self.depth = depth + self.init = tuple(init) + self.name = name + self.attributes = attributes + + def input_nets(self): + return set() + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + pass + + +class SyncWritePort(Cell): + """A single write port of a memory. This cell has no output. + + Attributes + ---------- + + memory: cell index of ``Memory`` + data: Value + addr: Value + en: Value + clk: Net + clk_edge: str, either 'pos' or 'neg' + """ + def __init__(self, module_idx, memory, *, data, addr, en, clk, clk_edge, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert clk_edge in ('pos', 'neg') + self.memory = memory + self.data = Value(data) + self.addr = Value(addr) + self.en = Value(en) + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + + def input_nets(self): + return set(self.data) | set(self.addr) | set(self.en) | {self.clk} + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + self.data = netlist.resolve_value(self.data) + self.addr = netlist.resolve_value(self.addr) + self.en = netlist.resolve_value(self.en) + self.clk = netlist.resolve_net(self.clk) + + +class AsyncReadPort(Cell): + """A single asynchronous read port of a memory. + + Attributes + ---------- + + memory: cell index of ``Memory`` + width: int + addr: Value + """ + def __init__(self, module_idx, memory, *, width, addr, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.memory = memory + self.width = width + self.addr = Value(addr) + + def input_nets(self): + return set(self.addr) + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.addr = netlist.resolve_value(self.addr) + +class SyncReadPort(Cell): + """A single synchronous read port of a memory. The cell output is the data port. + ``transparent_for`` is the set of write ports (identified by cell index) that this + read port is transparent with. + + Attributes + ---------- + + memory: cell index of ``Memory`` + width: int + addr: Value + en: Net + clk: Net + clk_edge: str, either 'pos' or 'neg' + transparent_for: tuple of int + """ + def __init__(self, module_idx, memory, *, width, addr, en, clk, clk_edge, transparent_for, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert clk_edge in ('pos', 'neg') + self.memory = memory + self.width = width + self.addr = Value(addr) + self.en = Net.ensure(en) + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + self.transparent_for = tuple(transparent_for) + + def input_nets(self): + return set(self.addr) | {self.en, self.clk} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + self.addr = netlist.resolve_value(self.addr) + self.en = netlist.resolve_net(self.en) + self.clk = netlist.resolve_net(self.clk) + +class Initial(Cell): + """Corresponds to ``Initial`` value.""" + + def input_nets(self): + return set() + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, 0)} + + def resolve_nets(self, netlist: Netlist): + pass + + +class AnyValue(Cell): + """Corresponds to ``AnyConst`` or ``AnySeq``. ``kind`` must be either ``'anyconst'`` + or ``'anyseq'``. + + Attributes + ---------- + + kind: str, 'anyconst' or 'anyseq' + width: int + """ + def __init__(self, module_idx, *, kind, width, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert kind in ('anyconst', 'anyseq') + self.kind = kind + self.width = width + + def input_nets(self): + return set() + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(self.width)} + + def resolve_nets(self, netlist: Netlist): + pass + + +class AsyncProperty(Cell): + """Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain._ + + Attributes + ---------- + + kind: str, either 'assert', 'assume', or 'cover' + test: Net + en: Net + name: str + """ + def __init__(self, module_idx, *, kind, test, en, name, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert kind in ('assert', 'assume', 'cover') + self.kind = kind + self.test = Net.ensure(test) + self.en = Net.ensure(en) + self.name = name + + def input_nets(self): + return {self.test, self.en} + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + self.test = netlist.resolve_net(self.test) + self.en = netlist.resolve_net(self.en) + + +class SyncProperty(Cell): + """Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain. + + Attributes + ---------- + + kind: str, either 'assert', 'assume', or 'cover' + test: Net + en: Net + clk: Net + clk_edge: str, either 'pos' or 'neg' + name: str + """ + + def __init__(self, module_idx, *, kind, test, en, clk, clk_edge, name, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + assert kind in ('assert', 'assume', 'cover') + assert clk_edge in ('pos', 'neg') + self.kind = kind + self.test = Net.ensure(test) + self.en = Net.ensure(en) + self.clk = Net.ensure(clk) + self.clk_edge = clk_edge + self.name = name + + def input_nets(self): + return {self.test, self.en, self.clk} + + def output_nets(self, self_idx: int): + return set() + + def resolve_nets(self, netlist: Netlist): + self.test = netlist.resolve_net(self.test) + self.en = netlist.resolve_net(self.en) + self.clk = netlist.resolve_net(self.clk) + + +class Instance(Cell): + """Corresponds to ``Instance``. ``type``, ``parameters`` and ``attributes`` work the same as in + ``Instance``. Input and inout ports are represented as a dict of port names to values. + Inout ports must be connected to nets corresponding to an IO port of the ``Top`` cell. + + Output ports are represented as a dict of port names to (start bit index, width) describing + their position in the virtual "output" of this cell. + + Attributes + ---------- + + type: str + name: str + parameters: dict of str to Const, int, or str + attributes: dict of str to Const, int, or str + ports_i: dict of str to Value + ports_o: dict of str to pair of int (index start, width) + ports_io: dict of str to Value + """ + + def __init__(self, module_idx, *, type, name, parameters, attributes, ports_i, ports_o, ports_io, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.type = type + self.name = name + self.parameters = parameters + self.attributes = attributes + self.ports_i = {name: Value(val) for name, val in ports_i.items()} + self.ports_o = ports_o + self.ports_io = {name: Value(val) for name, val in ports_io.items()} + + def input_nets(self): + nets = set() + for val in self.ports_i.values(): + nets |= set(val) + for val in self.ports_io.values(): + nets |= set(val) + return nets + + def output_nets(self, self_idx: int): + nets = set() + for start, width in self.ports_o.values(): + for bit in range(start, start + width): + nets.add(Net.from_cell(self_idx, bit)) + return nets + + def resolve_nets(self, netlist: Netlist): + for port in self.ports_i: + self.ports_i[port] = netlist.resolve_value(self.ports_i[port]) + for port in self.ports_io: + self.ports_io[port] = netlist.resolve_value(self.ports_io[port]) + + +class IOBuffer(Cell): + """An IO buffer cell. ``pad`` must be connected to nets corresponding to an IO port + of the ``Top`` cell. This cell does two things: + + - a tristate buffer is inserted driving ``pad`` based on ``o`` and ``oe`` nets (output buffer) + - the value of ``pad`` is sampled and made available as output of this cell (input buffer) + + Attributes + ---------- + + pad: Value + o: Value + oe: Net + """ + def __init__(self, module_idx, *, pad, o, oe, src_loc): + super().__init__(module_idx, src_loc=src_loc) + + self.pad = Value(pad) + self.o = Value(o) + self.oe = Net.ensure(oe) + + def input_nets(self): + return set(self.pad) | set(self.o) | {self.oe} + + def output_nets(self, self_idx: int): + return {Net.from_cell(self_idx, bit) for bit in range(len(self.pad))} + + def resolve_nets(self, netlist: Netlist): + self.pad = netlist.resolve_value(self.pad) + self.o = netlist.resolve_value(self.o) + self.oe = netlist.resolve_net(self.oe) diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index a8cf574..ec3c057 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -16,7 +16,6 @@ __all__ = ["ValueVisitor", "ValueTransformer", "FragmentTransformer", "TransformedElaboratable", "DomainCollector", "DomainRenamer", "DomainLowerer", - "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", "ResetInserter", "EnableInserter", "AssignmentLegalizer"] @@ -195,7 +194,7 @@ class StatementTransformer(StatementVisitor): return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs)) def on_Property(self, stmt): - return Property(stmt.kind, self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) + return Property(stmt.kind, self.on_value(stmt.test), name=stmt.name) def on_Switch(self, stmt): cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) @@ -533,97 +532,6 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer) return new_fragment -class SwitchCleaner(StatementVisitor): - def on_ignore(self, stmt): - return stmt - - on_Assign = on_ignore - on_Property = on_ignore - - def on_Switch(self, stmt): - cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) - if any(len(s) for s in cases.values()): - return Switch(stmt.test, cases) - - def on_statements(self, stmts): - stmts = flatten(self.on_statement(stmt) for stmt in stmts) - return _StatementList(stmt for stmt in stmts if stmt is not None) - - -class LHSGroupAnalyzer(StatementVisitor): - def __init__(self): - self.signals = SignalDict() - self.unions = OrderedDict() - - def find(self, signal): - if signal not in self.signals: - self.signals[signal] = len(self.signals) - group = self.signals[signal] - while group in self.unions: - group = self.unions[group] - self.signals[signal] = group - return group - - def unify(self, root, *leaves): - root_group = self.find(root) - for leaf in leaves: - leaf_group = self.find(leaf) - if root_group == leaf_group: - continue - self.unions[leaf_group] = root_group - - def groups(self): - groups = OrderedDict() - for signal in self.signals: - group = self.find(signal) - if group not in groups: - groups[group] = SignalSet() - groups[group].add(signal) - return groups - - def on_Assign(self, stmt): - lhs_signals = stmt._lhs_signals() - if lhs_signals: - self.unify(*stmt._lhs_signals()) - - def on_Property(self, stmt): - lhs_signals = stmt._lhs_signals() - if lhs_signals: - self.unify(*stmt._lhs_signals()) - - def on_Switch(self, stmt): - for case_stmts in stmt.cases.values(): - self.on_statements(case_stmts) - - def on_statements(self, stmts): - assert not isinstance(stmts, str) - for stmt in stmts: - self.on_statement(stmt) - - def __call__(self, stmts): - self.on_statements(stmts) - return self.groups() - - -class LHSGroupFilter(SwitchCleaner): - def __init__(self, signals): - self.signals = signals - - def on_Assign(self, stmt): - # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together - # on LHS are a part of the same group, so it is sufficient to check any of them. - lhs_signals = stmt.lhs._lhs_signals() - if lhs_signals: - any_lhs_signal = next(iter(lhs_signals)) - if any_lhs_signal in self.signals: - return stmt - - def on_Property(self, stmt): - any_lhs_signal = next(iter(stmt._lhs_signals())) - if any_lhs_signal in self.signals: - return stmt - - class _ControlInserter(FragmentTransformer): def __init__(self, controls): self.src_loc = None @@ -655,10 +563,23 @@ class ResetInserter(_ControlInserter): fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc)) +class _PropertyEnableInserter(StatementTransformer): + def __init__(self, en): + self.en = en + + def on_Property(self, stmt): + return Switch( + self.en, + {1: [stmt]}, + src_loc=stmt.src_loc, + ) + + class EnableInserter(_ControlInserter): def _insert_control(self, fragment, domain, signals): stmts = [s.eq(s) for s in signals] fragment.add_statements(domain, Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc)) + fragment.statements[domain] = _PropertyEnableInserter(self.controls[domain])(fragment.statements[domain]) def on_fragment(self, fragment): new_fragment = super().on_fragment(fragment) diff --git a/amaranth/hdl/xfrm.py b/amaranth/hdl/xfrm.py index 2b96db6..b490f00 100644 --- a/amaranth/hdl/xfrm.py +++ b/amaranth/hdl/xfrm.py @@ -9,7 +9,6 @@ __all__ = ["ValueVisitor", "ValueTransformer", "FragmentTransformer", "TransformedElaboratable", "DomainCollector", "DomainRenamer", "DomainLowerer", - "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", "ResetInserter", "EnableInserter"] diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 94c50ec..56cc2ec 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -244,136 +244,6 @@ class DomainLowererTestCase(FHDLTestCase): DomainLowerer()(f) -class SwitchCleanerTestCase(FHDLTestCase): - def test_clean(self): - a = Signal() - b = Signal() - c = Signal() - stmts = [ - Switch(a, { - 1: a.eq(0), - 0: [ - b.eq(1), - Switch(b, {1: [ - Switch(a|b, {}) - ]}) - ] - }) - ] - - self.assertRepr(SwitchCleaner()(stmts), """ - ( - (switch (sig a) - (case 1 - (eq (sig a) (const 1'd0))) - (case 0 - (eq (sig b) (const 1'd1))) - ) - ) - """) - - -class LHSGroupAnalyzerTestCase(FHDLTestCase): - def test_no_group_unrelated(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - b.eq(0), - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a,)), - SignalSet((b,)), - ]) - - def test_group_related(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - Cat(a, b).eq(0), - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a, b)), - ]) - - def test_no_loops(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - Cat(a, b).eq(0), - Cat(a, b).eq(0), - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a, b)), - ]) - - def test_switch(self): - a = Signal() - b = Signal() - stmts = [ - a.eq(0), - Switch(a, { - 1: b.eq(0), - }) - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - SignalSet((a,)), - SignalSet((b,)), - ]) - - def test_lhs_empty(self): - stmts = [ - Cat().eq(0) - ] - - groups = LHSGroupAnalyzer()(stmts) - self.assertEqual(list(groups.values()), [ - ]) - - -class LHSGroupFilterTestCase(FHDLTestCase): - def test_filter(self): - a = Signal() - b = Signal() - c = Signal() - stmts = [ - Switch(a, { - 1: a.eq(0), - 0: [ - b.eq(1), - Switch(b, {1: []}) - ] - }) - ] - - self.assertRepr(LHSGroupFilter(SignalSet((a,)))(stmts), """ - ( - (switch (sig a) - (case 1 - (eq (sig a) (const 1'd0))) - (case 0 ) - ) - ) - """) - - def test_lhs_empty(self): - stmts = [ - Cat().eq(0) - ] - - self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()") - - class ResetInserterTestCase(FHDLTestCase): def setUp(self): self.s1 = Signal()