From ea561378acb99554b52a57ac27f49a2f23761a2f Mon Sep 17 00:00:00 2001 From: Wanda Date: Sun, 3 Mar 2024 19:46:15 +0100 Subject: [PATCH] hdl._nir: Remove `ArrayMux`, use `AssignmentList` instead. --- amaranth/back/rtlil.py | 19 +------------- amaranth/hdl/_ir.py | 16 +++++++++++- amaranth/hdl/_nir.py | 38 +--------------------------- tests/test_hdl_ir.py | 56 +++++++++++++++++++++++++++++++++--------- 4 files changed, 61 insertions(+), 68 deletions(-) diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index f495eb2..37989bf 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -446,7 +446,7 @@ class ModuleEmitter: continue # No outputs. elif isinstance(cell, _nir.AssignmentList): width = len(cell.default) - elif isinstance(cell, (_nir.Operator, _nir.Part, _nir.ArrayMux, _nir.AnyValue, + elif isinstance(cell, (_nir.Operator, _nir.Part, _nir.AnyValue, _nir.SyncReadPort, _nir.AsyncReadPort)): width = cell.width elif isinstance(cell, _nir.FlipFlop): @@ -738,21 +738,6 @@ class ModuleEmitter: "Y_WIDTH": cell.width, }, src=_src(cell.src_loc)) - def emit_array_mux(self, cell_idx, cell): - wire = self.cell_wires[cell_idx] - with self.builder.process(src=_src(cell.src_loc)) as proc: - with proc.case() as root_case: - with root_case.switch(self.sigspec(cell.index)) as switch: - for index, elem in enumerate(cell.elems): - if len(cell.index) > 0: - pattern = "{:0{}b}".format(index, len(cell.index)) - else: - pattern = "" - with switch.case(pattern) as case: - case.assign(wire, self.sigspec(elem)) - with switch.case() as case: - case.assign(wire, self.sigspec(cell.elems[0])) - def emit_flip_flop(self, cell_idx, cell): ports = { "D": self.sigspec(cell.data), @@ -944,8 +929,6 @@ class ModuleEmitter: self.emit_operator(cell_idx, cell) elif isinstance(cell, _nir.Part): self.emit_part(cell_idx, cell) - elif isinstance(cell, _nir.ArrayMux): - self.emit_array_mux(cell_idx, cell) elif isinstance(cell, _nir.FlipFlop): self.emit_flip_flop(cell_idx, cell) elif isinstance(cell, _nir.IOBuffer): diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index bace04a..f46f6f8 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -842,7 +842,21 @@ class NetlistEmitter: width = max(width, len(elem)) elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems) index, _signed = self.emit_rhs(module_idx, value.index) - cell = _nir.ArrayMux(module_idx, width=width, elems=elems, index=index, + conds = [] + for case_index in range(len(elems)): + cell = _nir.Matches(module_idx, value=index, + patterns=(f"{case_index:0{len(index)}b}",), + src_loc=value.src_loc) + subcond, = self.netlist.add_value_cell(1, cell) + conds.append(subcond) + conds = _nir.Value(conds) + cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), inputs=conds, src_loc=value.src_loc) + conds = self.netlist.add_value_cell(len(conds), cell) + assignments = [ + _nir.Assignment(cond=cond, start=0, value=elem, src_loc=value.src_loc) + for cond, elem in zip(conds, elems) + ] + cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments, src_loc=value.src_loc) result = self.netlist.add_value_cell(width, cell) elif isinstance(value, _ast.Cat): diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py index eefce58..cdc7229 100644 --- a/amaranth/hdl/_nir.py +++ b/amaranth/hdl/_nir.py @@ -8,7 +8,7 @@ __all__ = [ # Netlist core "Net", "Value", "Netlist", "ModuleNetFlow", "Module", "Cell", "Top", # Computation cells - "Operator", "Part", "ArrayMux", + "Operator", "Part", # Decision tree cells "Matches", "PriorityMatch", "Assignment", "AssignmentList", # Storage cells @@ -501,42 +501,6 @@ class Part(Cell): return f"(part {self.value} {value_signed} {self.offset} {self.width} {self.stride})" -class ArrayMux(Cell): - """Corresponds to ``hdl.ast.ArrayProxy``. All values in the ``elems`` array need to have - the same width as the output. - - Attributes - ---------- - - width: int (width of output and all inputs) - elems: tuple of Value - index: Value - """ - def __init__(self, module_idx, *, width, elems, index, src_loc): - super().__init__(module_idx, src_loc=src_loc) - - self.width = width - self.elems = tuple(Value(val) for val in elems) - self.index = Value(index) - - def input_nets(self): - nets = set(self.index) - for value in self.elems: - nets |= set(value) - return nets - - def output_nets(self, self_idx: int): - return {Net.from_cell(self_idx, bit) for bit in range(self.width)} - - def resolve_nets(self, netlist: Netlist): - self.elems = tuple(netlist.resolve_value(val) for val in self.elems) - self.index = netlist.resolve_value(self.index) - - def __repr__(self): - elems = " ".join(repr(elem) for elem in self.elems) - return f"(array_mux {self.width} {self.index} ({elems}))" - - class Matches(Cell): """A combinatorial cell performing a comparison like ``Value.matches`` (or, equivalently, a case condition). diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index 8c37a7a..1086bb2 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -2901,10 +2901,10 @@ class RhsTestCase(FHDLTestCase): (input 'i8sb' 0.34:42) (input 'i8sc' 0.42:50) (input 'i4' 0.50:54) - (output 'o1' (cat 1.0:8 2'd0)) - (output 'o2' (cat 2.0:9 2.8)) - (output 'o3' (cat 3.0:8 3.7 3.7)) - (output 'o4' (cat 4.0:8 4.7 4.7)) + (output 'o1' (cat 5.0:8 2'd0)) + (output 'o2' (cat 10.0:9 10.8)) + (output 'o3' (cat 15.0:8 15.7 15.7)) + (output 'o4' (cat 20.0:8 20.7 20.7)) ) (cell 0 0 (top (input 'i8ua' 2:10) @@ -2914,15 +2914,47 @@ class RhsTestCase(FHDLTestCase): (input 'i8sb' 34:42) (input 'i8sc' 42:50) (input 'i4' 50:54) - (output 'o1' (cat 1.0:8 2'd0)) - (output 'o2' (cat 2.0:9 2.8)) - (output 'o3' (cat 3.0:8 3.7 3.7)) - (output 'o4' (cat 4.0:8 4.7 4.7)) + (output 'o1' (cat 5.0:8 2'd0)) + (output 'o2' (cat 10.0:9 10.8)) + (output 'o3' (cat 15.0:8 15.7 15.7)) + (output 'o4' (cat 20.0:8 20.7 20.7)) + )) + (cell 1 0 (matches 0.50:54 0000)) + (cell 2 0 (matches 0.50:54 0001)) + (cell 3 0 (matches 0.50:54 0010)) + (cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0))) + (cell 5 0 (assignment_list 0.2:10 + (4.0 0:8 0.2:10) + (4.1 0:8 0.10:18) + (4.2 0:8 0.18:26) + )) + (cell 6 0 (matches 0.50:54 0000)) + (cell 7 0 (matches 0.50:54 0001)) + (cell 8 0 (matches 0.50:54 0010)) + (cell 9 0 (priority_match 1 (cat 6.0 7.0 8.0))) + (cell 10 0 (assignment_list (cat 0.2:10 1'd0) + (9.0 0:9 (cat 0.2:10 1'd0)) + (9.1 0:9 (cat 0.10:18 1'd0)) + (9.2 0:9 (cat 0.42:50 0.49)) + )) + (cell 11 0 (matches 0.50:54 0000)) + (cell 12 0 (matches 0.50:54 0001)) + (cell 13 0 (matches 0.50:54 0010)) + (cell 14 0 (priority_match 1 (cat 11.0 12.0 13.0))) + (cell 15 0 (assignment_list 0.26:34 + (14.0 0:8 0.26:34) + (14.1 0:8 0.34:42) + (14.2 0:8 0.42:50) + )) + (cell 16 0 (matches 0.50:54 0000)) + (cell 17 0 (matches 0.50:54 0001)) + (cell 18 0 (matches 0.50:54 0010)) + (cell 19 0 (priority_match 1 (cat 16.0 17.0 18.0))) + (cell 20 0 (assignment_list 0.26:34 + (19.0 0:8 0.26:34) + (19.1 0:8 0.34:42) + (19.2 0:8 (cat 0.50:54 4'd0)) )) - (cell 1 0 (array_mux 8 0.50:54 (0.2:10 0.10:18 0.18:26))) - (cell 2 0 (array_mux 9 0.50:54 ((cat 0.2:10 1'd0) (cat 0.10:18 1'd0) (cat 0.42:50 0.49)))) - (cell 3 0 (array_mux 8 0.50:54 (0.26:34 0.34:42 0.42:50))) - (cell 4 0 (array_mux 8 0.50:54 (0.26:34 0.34:42 (cat 0.50:54 4'd0)))) ) """)