diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index be0c4da..e1696b1 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -697,6 +697,8 @@ class NetlistEmitter: self.drivers = _ast.SignalDict() self.io_ports: dict[_ast.IOPort, int] = {} self.rhs_cache: dict[int, Tuple[_nir.Value, bool, _ast.Value]] = {} + self.matches_cache = {} + self.priority_match_cache = {} self.fragment_module_idx: dict[Fragment, int] = {} # Collected for driver conflict diagnostics only. @@ -774,6 +776,26 @@ class NetlistEmitter: def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc): op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc) return self.netlist.add_value_cell(op.width, op) + + def emit_matches(self, module_idx: int, value: _nir.Value, patterns, *, src_loc): + key = module_idx, value, patterns, src_loc + try: + return self.matches_cache[key] + except KeyError: + cell = _nir.Matches(module_idx, value=value, patterns=patterns, src_loc=src_loc) + net, = self.netlist.add_value_cell(1, cell) + self.matches_cache[key] = net + return net + + def emit_priority_match(self, module_idx: int, en: _nir.Net, inputs: _nir.Value, *, src_loc): + key = module_idx, en, inputs, src_loc + try: + return self.priority_match_cache[key] + except KeyError: + cell = _nir.PriorityMatch(module_idx, en=en, inputs=inputs, src_loc=src_loc) + res = self.netlist.add_value_cell(len(inputs), cell) + self.priority_match_cache[key] = res + return res def unify_shapes_bitwise(self, operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool): @@ -928,19 +950,13 @@ class NetlistEmitter: elems = [] for patterns, elem, in value.cases: if patterns is not None: - for pattern in patterns: - assert len(pattern) == len(test) - cell = _nir.Matches(module_idx, value=test, patterns=patterns, - src_loc=value.src_loc) - net, = self.netlist.add_value_cell(1, cell) + net = self.emit_matches(module_idx, test, patterns, src_loc=value.src_loc) conds.append(net) else: conds.append(_nir.Net.from_const(1)) elems.append(self.emit_rhs(module_idx, elem)) - cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), - inputs=_nir.Value(conds), - src_loc=value.src_loc) - conds = self.netlist.add_value_cell(len(conds), cell) + conds = self.emit_priority_match(module_idx, _nir.Net.from_const(1), + _nir.Value(conds), src_loc=value.src_loc) shape = _ast.Shape._unify( _ast.Shape(len(value), signed) for value, signed in elems @@ -1043,14 +1059,12 @@ class NetlistEmitter: num_cases = min((width + lhs.stride - 1) // lhs.stride, 1 << len(offset)) conds = [] for case_index in range(num_cases): - cell = _nir.Matches(module_idx, value=offset, - patterns=(to_binary(case_index, len(offset)),), - src_loc=lhs.src_loc) - subcond, = self.netlist.add_value_cell(1, cell) + subcond = self.emit_matches(module_idx, offset, + (to_binary(case_index, len(offset)),), + src_loc=lhs.src_loc) conds.append(subcond) - conds = _nir.Value(conds) - cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc) - conds = self.netlist.add_value_cell(len(conds), cell) + conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds), + src_loc=lhs.src_loc) for idx, subcond in enumerate(conds): start = lhs_start + idx * lhs.stride if start >= width: @@ -1066,19 +1080,13 @@ class NetlistEmitter: elems = [] for patterns, elem in lhs.cases: if patterns is not None: - for pattern in patterns: - assert len(pattern) == len(test) - cell = _nir.Matches(module_idx, value=test, patterns=patterns, - src_loc=lhs.src_loc) - net, = self.netlist.add_value_cell(1, cell) + net = self.emit_matches(module_idx, test, patterns, src_loc=lhs.src_loc) conds.append(net) else: conds.append(_nir.Net.from_const(1)) elems.append(elem) - conds = _nir.Value(conds) - cell = _nir.PriorityMatch(module_idx, en=cond, - inputs=conds, src_loc=lhs.src_loc) - conds = self.netlist.add_value_cell(len(conds), cell) + conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds), + src_loc=lhs.src_loc) for subcond, val in zip(conds, elems): self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc) elif isinstance(lhs, _ast.Operator): @@ -1163,18 +1171,13 @@ class NetlistEmitter: case_stmts = [] for patterns, stmts, case_src_loc in stmt.cases: if patterns is not None: - for pattern in patterns: - assert len(pattern) == len(test) - cell = _nir.Matches(module_idx, value=test, patterns=patterns, - src_loc=case_src_loc) - net, = self.netlist.add_value_cell(1, cell) + net = self.emit_matches(module_idx, test, patterns, src_loc=case_src_loc) conds.append(net) else: conds.append(_nir.Net.from_const(1)) case_stmts.append(stmts) - cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds), - src_loc=stmt.src_loc) - conds = self.netlist.add_value_cell(len(conds), cell) + conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds), + src_loc=stmt.src_loc) for subcond, substmts in zip(conds, case_stmts): for substmt in substmts: self.emit_stmt(module_idx, fragment, domain, substmt, subcond) @@ -1337,15 +1340,13 @@ class NetlistEmitter: driver.domain.rst is not None and not driver.domain.async_reset and not driver.signal.reset_less): - cell = _nir.Matches(driver.module_idx, - value=self.emit_signal(driver.domain.rst), - patterns=("1",), - src_loc=driver.domain.rst.src_loc) - cond, = self.netlist.add_value_cell(1, cell) - cell = _nir.PriorityMatch(driver.module_idx, en=_nir.Net.from_const(1), - inputs=_nir.Value(cond), - src_loc=driver.domain.rst.src_loc) - cond, = self.netlist.add_value_cell(1, cell) + cond = self.emit_matches(driver.module_idx, + self.emit_signal(driver.domain.rst), + ("1",), + src_loc=driver.domain.rst.src_loc) + cond, = self.emit_priority_match(driver.module_idx, _nir.Net.from_const(1), + _nir.Value(cond), + src_loc=driver.domain.rst.src_loc) init = _nir.Value.from_const(driver.signal.init, len(driver.signal)) driver.assignments.append(_nir.Assignment(cond=cond, start=0, value=init, src_loc=driver.signal.src_loc)) diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py index 828e858..0ee5b1d 100644 --- a/amaranth/hdl/_nir.py +++ b/amaranth/hdl/_nir.py @@ -659,6 +659,8 @@ class Matches(Cell): def __init__(self, module_idx, *, value, patterns, src_loc): super().__init__(module_idx, src_loc=src_loc) + for pattern in patterns: + assert len(pattern) == len(value) self.value = Value(value) self.patterns = tuple(patterns) diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index 0483d7a..6cb3071 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -3247,8 +3247,6 @@ class SwitchTestCase(FHDLTestCase): ClockSignal("b"), ResetSignal("b"), ClockSignal("c"), ]) - # TODO: inefficiency in NIR emitter: - # matches and priority_match duplicated between clock domains — add cache? self.assertRepr(nl, """ ( (module 0 None ('top') @@ -3259,11 +3257,11 @@ class SwitchTestCase(FHDLTestCase): (input 'b_clk' 0.13) (input 'b_rst' 0.14) (input 'c_clk' 0.15) - (output 'o1' 8.0:8) - (output 'o2' 12.0:8) - (output 'o3' 14.0:8) - (output 'o4' 16.0:8) - (output 'o5' 18.0:8) + (output 'o1' 4.0:8) + (output 'o2' 8.0:8) + (output 'o3' 10.0:8) + (output 'o4' 12.0:8) + (output 'o5' 14.0:8) ) (cell 0 0 (top (input 'i1' 2:10) @@ -3273,30 +3271,26 @@ class SwitchTestCase(FHDLTestCase): (input 'b_clk' 13:14) (input 'b_rst' 14:15) (input 'c_clk' 15:16) - (output 'o1' 8.0:8) - (output 'o2' 12.0:8) - (output 'o3' 14.0:8) - (output 'o4' 16.0:8) - (output 'o5' 18.0:8) + (output 'o1' 4.0:8) + (output 'o2' 8.0:8) + (output 'o3' 10.0:8) + (output 'o4' 12.0:8) + (output 'o5' 14.0:8) )) (cell 1 0 (matches 0.10 1)) (cell 2 0 (priority_match 1 1.0)) - (cell 3 0 (matches 0.10 1)) - (cell 4 0 (priority_match 1 3.0)) - (cell 5 0 (matches 0.10 1)) + (cell 3 0 (assignment_list 4.0:8 (2.0 0:8 0.2:10))) + (cell 4 0 (flipflop 3.0:8 0 pos 0.11 0)) + (cell 5 0 (matches 0.12 1)) (cell 6 0 (priority_match 1 5.0)) - (cell 7 0 (assignment_list 8.0:8 (2.0 0:8 0.2:10))) - (cell 8 0 (flipflop 7.0:8 0 pos 0.11 0)) - (cell 9 0 (matches 0.12 1)) - (cell 10 0 (priority_match 1 9.0)) - (cell 11 0 (assignment_list 12.0:8 (2.0 0:8 0.2:10) (10.0 0:8 8'd123))) - (cell 12 0 (flipflop 11.0:8 123 pos 0.11 0)) - (cell 13 0 (assignment_list 14.0:8 (4.0 0:8 0.2:10))) - (cell 14 0 (flipflop 13.0:8 45 pos 0.13 0)) - (cell 15 0 (assignment_list 16.0:8 (4.0 0:8 0.2:10))) - (cell 16 0 (flipflop 15.0:8 67 pos 0.13 0.14)) - (cell 17 0 (assignment_list 18.0:8 (6.0 0:8 0.2:10))) - (cell 18 0 (flipflop 17.0:8 89 neg 0.15 0)) + (cell 7 0 (assignment_list 8.0:8 (2.0 0:8 0.2:10) (6.0 0:8 8'd123))) + (cell 8 0 (flipflop 7.0:8 123 pos 0.11 0)) + (cell 9 0 (assignment_list 10.0:8 (2.0 0:8 0.2:10))) + (cell 10 0 (flipflop 9.0:8 45 pos 0.13 0)) + (cell 11 0 (assignment_list 12.0:8 (2.0 0:8 0.2:10))) + (cell 12 0 (flipflop 11.0:8 67 pos 0.13 0.14)) + (cell 13 0 (assignment_list 14.0:8 (2.0 0:8 0.2:10))) + (cell 14 0 (flipflop 13.0:8 89 neg 0.15 0)) ) """) @@ -3348,18 +3342,12 @@ class SwitchTestCase(FHDLTestCase): (cell 4 0 (print 3.0 ((u 0.2:8 '')))) (cell 5 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) (cell 6 0 (print 5.0 ((s 0.8:16 '') '\\n'))) - (cell 7 0 (matches 0.16 1)) - (cell 8 0 (priority_match 1 7.0)) - (cell 9 0 (assignment_list 1'd0 (8.0 0:1 1'd1))) - (cell 10 0 (print 9.0 pos 0.17 ((u 0.2:8 '') ' ' (s 0.8:16 '') '\\n'))) - (cell 11 0 (matches 0.16 1)) - (cell 12 0 (priority_match 1 11.0)) - (cell 13 0 (assignment_list 1'd0 (12.0 0:1 1'd1))) - (cell 14 0 (print 13.0 pos 0.19 ('values: ' (u 0.2:8 '02x') ', ' (s 0.8:16 '+d') '\\n'))) - (cell 15 0 (matches 0.16 1)) - (cell 16 0 (priority_match 1 15.0)) - (cell 17 0 (assignment_list 1'd0 (16.0 0:1 1'd1))) - (cell 18 0 (print 17.0 neg 0.21 ('meow\\n'))) + (cell 7 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) + (cell 8 0 (print 7.0 pos 0.17 ((u 0.2:8 '') ' ' (s 0.8:16 '') '\\n'))) + (cell 9 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) + (cell 10 0 (print 9.0 pos 0.19 ('values: ' (u 0.2:8 '02x') ', ' (s 0.8:16 '+d') '\\n'))) + (cell 11 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) + (cell 12 0 (print 11.0 neg 0.21 ('meow\\n'))) ) """) @@ -3409,18 +3397,12 @@ class SwitchTestCase(FHDLTestCase): (cell 7 0 (b 0.2:8)) (cell 8 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) (cell 9 0 (cover 7.0 8.0 ('d'))) - (cell 10 0 (matches 0.7 1)) - (cell 11 0 (priority_match 1 10.0)) - (cell 12 0 (assignment_list 1'd0 (11.0 0:1 1'd1))) - (cell 13 0 (assert 0.4 12.0 pos 0.8 None)) - (cell 14 0 (matches 0.7 1)) - (cell 15 0 (priority_match 1 14.0)) - (cell 16 0 (assignment_list 1'd0 (15.0 0:1 1'd1))) - (cell 17 0 (assume 0.5 16.0 pos 0.10 ('value: ' (u 0.2:8 '')))) - (cell 18 0 (matches 0.7 1)) - (cell 19 0 (priority_match 1 18.0)) - (cell 20 0 (assignment_list 1'd0 (19.0 0:1 1'd1))) - (cell 21 0 (cover 0.6 20.0 neg 0.12 ('c'))) + (cell 10 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) + (cell 11 0 (assert 0.4 10.0 pos 0.8 None)) + (cell 12 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) + (cell 13 0 (assume 0.5 12.0 pos 0.10 ('value: ' (u 0.2:8 '')))) + (cell 14 0 (assignment_list 1'd0 (2.0 0:1 1'd1))) + (cell 15 0 (cover 0.6 14.0 neg 0.12 ('c'))) ) """)