hdl._ir: add caches for Matches and PriorityMatch cells.

This commit is contained in:
Wanda 2024-04-06 09:28:32 +02:00 committed by Catherine
parent df589a54e0
commit 7936b87667
3 changed files with 78 additions and 93 deletions

View file

@ -697,6 +697,8 @@ class NetlistEmitter:
self.drivers = _ast.SignalDict()
self.io_ports: dict[_ast.IOPort, int] = {}
self.rhs_cache: dict[int, Tuple[_nir.Value, bool, _ast.Value]] = {}
self.matches_cache = {}
self.priority_match_cache = {}
self.fragment_module_idx: dict[Fragment, int] = {}
# Collected for driver conflict diagnostics only.
@ -774,6 +776,26 @@ class NetlistEmitter:
def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc):
op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc)
return self.netlist.add_value_cell(op.width, op)
def emit_matches(self, module_idx: int, value: _nir.Value, patterns, *, src_loc):
key = module_idx, value, patterns, src_loc
try:
return self.matches_cache[key]
except KeyError:
cell = _nir.Matches(module_idx, value=value, patterns=patterns, src_loc=src_loc)
net, = self.netlist.add_value_cell(1, cell)
self.matches_cache[key] = net
return net
def emit_priority_match(self, module_idx: int, en: _nir.Net, inputs: _nir.Value, *, src_loc):
key = module_idx, en, inputs, src_loc
try:
return self.priority_match_cache[key]
except KeyError:
cell = _nir.PriorityMatch(module_idx, en=en, inputs=inputs, src_loc=src_loc)
res = self.netlist.add_value_cell(len(inputs), cell)
self.priority_match_cache[key] = res
return res
def unify_shapes_bitwise(self,
operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
@ -928,19 +950,13 @@ class NetlistEmitter:
elems = []
for patterns, elem, in value.cases:
if patterns is not None:
for pattern in patterns:
assert len(pattern) == len(test)
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
src_loc=value.src_loc)
net, = self.netlist.add_value_cell(1, cell)
net = self.emit_matches(module_idx, test, patterns, src_loc=value.src_loc)
conds.append(net)
else:
conds.append(_nir.Net.from_const(1))
elems.append(self.emit_rhs(module_idx, elem))
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
inputs=_nir.Value(conds),
src_loc=value.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
conds = self.emit_priority_match(module_idx, _nir.Net.from_const(1),
_nir.Value(conds), src_loc=value.src_loc)
shape = _ast.Shape._unify(
_ast.Shape(len(value), signed)
for value, signed in elems
@ -1043,14 +1059,12 @@ class NetlistEmitter:
num_cases = min((width + lhs.stride - 1) // lhs.stride, 1 << len(offset))
conds = []
for case_index in range(num_cases):
cell = _nir.Matches(module_idx, value=offset,
patterns=(to_binary(case_index, len(offset)),),
src_loc=lhs.src_loc)
subcond, = self.netlist.add_value_cell(1, cell)
subcond = self.emit_matches(module_idx, offset,
(to_binary(case_index, len(offset)),),
src_loc=lhs.src_loc)
conds.append(subcond)
conds = _nir.Value(conds)
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds),
src_loc=lhs.src_loc)
for idx, subcond in enumerate(conds):
start = lhs_start + idx * lhs.stride
if start >= width:
@ -1066,19 +1080,13 @@ class NetlistEmitter:
elems = []
for patterns, elem in lhs.cases:
if patterns is not None:
for pattern in patterns:
assert len(pattern) == len(test)
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
src_loc=lhs.src_loc)
net, = self.netlist.add_value_cell(1, cell)
net = self.emit_matches(module_idx, test, patterns, src_loc=lhs.src_loc)
conds.append(net)
else:
conds.append(_nir.Net.from_const(1))
elems.append(elem)
conds = _nir.Value(conds)
cell = _nir.PriorityMatch(module_idx, en=cond,
inputs=conds, src_loc=lhs.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds),
src_loc=lhs.src_loc)
for subcond, val in zip(conds, elems):
self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc)
elif isinstance(lhs, _ast.Operator):
@ -1163,18 +1171,13 @@ class NetlistEmitter:
case_stmts = []
for patterns, stmts, case_src_loc in stmt.cases:
if patterns is not None:
for pattern in patterns:
assert len(pattern) == len(test)
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
src_loc=case_src_loc)
net, = self.netlist.add_value_cell(1, cell)
net = self.emit_matches(module_idx, test, patterns, src_loc=case_src_loc)
conds.append(net)
else:
conds.append(_nir.Net.from_const(1))
case_stmts.append(stmts)
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds),
src_loc=stmt.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds),
src_loc=stmt.src_loc)
for subcond, substmts in zip(conds, case_stmts):
for substmt in substmts:
self.emit_stmt(module_idx, fragment, domain, substmt, subcond)
@ -1337,15 +1340,13 @@ class NetlistEmitter:
driver.domain.rst is not None and
not driver.domain.async_reset and
not driver.signal.reset_less):
cell = _nir.Matches(driver.module_idx,
value=self.emit_signal(driver.domain.rst),
patterns=("1",),
src_loc=driver.domain.rst.src_loc)
cond, = self.netlist.add_value_cell(1, cell)
cell = _nir.PriorityMatch(driver.module_idx, en=_nir.Net.from_const(1),
inputs=_nir.Value(cond),
src_loc=driver.domain.rst.src_loc)
cond, = self.netlist.add_value_cell(1, cell)
cond = self.emit_matches(driver.module_idx,
self.emit_signal(driver.domain.rst),
("1",),
src_loc=driver.domain.rst.src_loc)
cond, = self.emit_priority_match(driver.module_idx, _nir.Net.from_const(1),
_nir.Value(cond),
src_loc=driver.domain.rst.src_loc)
init = _nir.Value.from_const(driver.signal.init, len(driver.signal))
driver.assignments.append(_nir.Assignment(cond=cond, start=0,
value=init, src_loc=driver.signal.src_loc))

View file

@ -659,6 +659,8 @@ class Matches(Cell):
def __init__(self, module_idx, *, value, patterns, src_loc):
super().__init__(module_idx, src_loc=src_loc)
for pattern in patterns:
assert len(pattern) == len(value)
self.value = Value(value)
self.patterns = tuple(patterns)

View file

@ -3247,8 +3247,6 @@ class SwitchTestCase(FHDLTestCase):
ClockSignal("b"), ResetSignal("b"),
ClockSignal("c"),
])
# TODO: inefficiency in NIR emitter:
# matches and priority_match duplicated between clock domains — add cache?
self.assertRepr(nl, """
(
(module 0 None ('top')
@ -3259,11 +3257,11 @@ class SwitchTestCase(FHDLTestCase):
(input 'b_clk' 0.13)
(input 'b_rst' 0.14)
(input 'c_clk' 0.15)
(output 'o1' 8.0:8)
(output 'o2' 12.0:8)
(output 'o3' 14.0:8)
(output 'o4' 16.0:8)
(output 'o5' 18.0:8)
(output 'o1' 4.0:8)
(output 'o2' 8.0:8)
(output 'o3' 10.0:8)
(output 'o4' 12.0:8)
(output 'o5' 14.0:8)
)
(cell 0 0 (top
(input 'i1' 2:10)
@ -3273,30 +3271,26 @@ class SwitchTestCase(FHDLTestCase):
(input 'b_clk' 13:14)
(input 'b_rst' 14:15)
(input 'c_clk' 15:16)
(output 'o1' 8.0:8)
(output 'o2' 12.0:8)
(output 'o3' 14.0:8)
(output 'o4' 16.0:8)
(output 'o5' 18.0:8)
(output 'o1' 4.0:8)
(output 'o2' 8.0:8)
(output 'o3' 10.0:8)
(output 'o4' 12.0:8)
(output 'o5' 14.0:8)
))
(cell 1 0 (matches 0.10 1))
(cell 2 0 (priority_match 1 1.0))
(cell 3 0 (matches 0.10 1))
(cell 4 0 (priority_match 1 3.0))
(cell 5 0 (matches 0.10 1))
(cell 3 0 (assignment_list 4.0:8 (2.0 0:8 0.2:10)))
(cell 4 0 (flipflop 3.0:8 0 pos 0.11 0))
(cell 5 0 (matches 0.12 1))
(cell 6 0 (priority_match 1 5.0))
(cell 7 0 (assignment_list 8.0:8 (2.0 0:8 0.2:10)))
(cell 8 0 (flipflop 7.0:8 0 pos 0.11 0))
(cell 9 0 (matches 0.12 1))
(cell 10 0 (priority_match 1 9.0))
(cell 11 0 (assignment_list 12.0:8 (2.0 0:8 0.2:10) (10.0 0:8 8'd123)))
(cell 12 0 (flipflop 11.0:8 123 pos 0.11 0))
(cell 13 0 (assignment_list 14.0:8 (4.0 0:8 0.2:10)))
(cell 14 0 (flipflop 13.0:8 45 pos 0.13 0))
(cell 15 0 (assignment_list 16.0:8 (4.0 0:8 0.2:10)))
(cell 16 0 (flipflop 15.0:8 67 pos 0.13 0.14))
(cell 17 0 (assignment_list 18.0:8 (6.0 0:8 0.2:10)))
(cell 18 0 (flipflop 17.0:8 89 neg 0.15 0))
(cell 7 0 (assignment_list 8.0:8 (2.0 0:8 0.2:10) (6.0 0:8 8'd123)))
(cell 8 0 (flipflop 7.0:8 123 pos 0.11 0))
(cell 9 0 (assignment_list 10.0:8 (2.0 0:8 0.2:10)))
(cell 10 0 (flipflop 9.0:8 45 pos 0.13 0))
(cell 11 0 (assignment_list 12.0:8 (2.0 0:8 0.2:10)))
(cell 12 0 (flipflop 11.0:8 67 pos 0.13 0.14))
(cell 13 0 (assignment_list 14.0:8 (2.0 0:8 0.2:10)))
(cell 14 0 (flipflop 13.0:8 89 neg 0.15 0))
)
""")
@ -3348,18 +3342,12 @@ class SwitchTestCase(FHDLTestCase):
(cell 4 0 (print 3.0 ((u 0.2:8 ''))))
(cell 5 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 6 0 (print 5.0 ((s 0.8:16 '') '\\n')))
(cell 7 0 (matches 0.16 1))
(cell 8 0 (priority_match 1 7.0))
(cell 9 0 (assignment_list 1'd0 (8.0 0:1 1'd1)))
(cell 10 0 (print 9.0 pos 0.17 ((u 0.2:8 '') ' ' (s 0.8:16 '') '\\n')))
(cell 11 0 (matches 0.16 1))
(cell 12 0 (priority_match 1 11.0))
(cell 13 0 (assignment_list 1'd0 (12.0 0:1 1'd1)))
(cell 14 0 (print 13.0 pos 0.19 ('values: ' (u 0.2:8 '02x') ', ' (s 0.8:16 '+d') '\\n')))
(cell 15 0 (matches 0.16 1))
(cell 16 0 (priority_match 1 15.0))
(cell 17 0 (assignment_list 1'd0 (16.0 0:1 1'd1)))
(cell 18 0 (print 17.0 neg 0.21 ('meow\\n')))
(cell 7 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 8 0 (print 7.0 pos 0.17 ((u 0.2:8 '') ' ' (s 0.8:16 '') '\\n')))
(cell 9 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 10 0 (print 9.0 pos 0.19 ('values: ' (u 0.2:8 '02x') ', ' (s 0.8:16 '+d') '\\n')))
(cell 11 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 12 0 (print 11.0 neg 0.21 ('meow\\n')))
)
""")
@ -3409,18 +3397,12 @@ class SwitchTestCase(FHDLTestCase):
(cell 7 0 (b 0.2:8))
(cell 8 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 9 0 (cover 7.0 8.0 ('d')))
(cell 10 0 (matches 0.7 1))
(cell 11 0 (priority_match 1 10.0))
(cell 12 0 (assignment_list 1'd0 (11.0 0:1 1'd1)))
(cell 13 0 (assert 0.4 12.0 pos 0.8 None))
(cell 14 0 (matches 0.7 1))
(cell 15 0 (priority_match 1 14.0))
(cell 16 0 (assignment_list 1'd0 (15.0 0:1 1'd1)))
(cell 17 0 (assume 0.5 16.0 pos 0.10 ('value: ' (u 0.2:8 ''))))
(cell 18 0 (matches 0.7 1))
(cell 19 0 (priority_match 1 18.0))
(cell 20 0 (assignment_list 1'd0 (19.0 0:1 1'd1)))
(cell 21 0 (cover 0.6 20.0 neg 0.12 ('c')))
(cell 10 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 11 0 (assert 0.4 10.0 pos 0.8 None))
(cell 12 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 13 0 (assume 0.5 12.0 pos 0.10 ('value: ' (u 0.2:8 ''))))
(cell 14 0 (assignment_list 1'd0 (2.0 0:1 1'd1)))
(cell 15 0 (cover 0.6 14.0 neg 0.12 ('c')))
)
""")