hdl._ast: change Switch to operate on list of cases.

This commit is contained in:
Wanda 2024-04-02 23:04:56 +02:00 committed by Catherine
parent cd6cbd71ca
commit 2eb62a8b49
8 changed files with 95 additions and 90 deletions

View file

@ -638,6 +638,8 @@ class ModuleEmitter:
assert isinstance(matches_cell, _nir.Matches)
assert test == matches_cell.value
patterns = matches_cell.patterns
# RTLIL cannot support empty pattern sets.
assert patterns
with switch.case(*patterns) as subcase:
emit_assignments(subcase, subcond)
emitted_switch = True

View file

@ -2763,38 +2763,34 @@ class _LateBoundStatement(Statement):
@final
class Switch(Statement):
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
if src_loc is None:
super().__init__(src_loc_at=src_loc_at)
else:
# Switch is a bit special in terms of location tracking because it is usually created
# long after the control has left the statement that directly caused its creation.
self.src_loc = src_loc
# Switch is also a bit special in that its parts also have location information. It can't
# be automatically traced, so whatever constructs a Switch may optionally provide it.
self.case_src_locs = {}
self._test = Value.cast(test)
self._cases = OrderedDict()
for orig_keys, stmts in cases.items():
# Map: None -> (); key -> (key,); (key...) -> (key...)
keys = orig_keys
if keys is None:
keys = ()
if not isinstance(keys, tuple):
keys = (keys,)
# Map: 2 -> "0010"; "0010" -> "0010"
new_keys = ()
key_mask = (1 << len(self.test)) - 1
for key in _normalize_patterns(keys, self._test.shape()):
if isinstance(key, int):
key = to_binary(key & key_mask, len(self.test))
new_keys = (*new_keys, key)
self._cases = []
for patterns, stmts, case_src_loc in cases:
if patterns is not None:
# Map: key -> (key,); (key...) -> (key...)
if not isinstance(patterns, tuple):
patterns = (patterns,)
# Map: 2 -> "0010"; "0010" -> "0010"
new_patterns = ()
key_mask = (1 << len(self.test)) - 1
for key in _normalize_patterns(patterns, self._test.shape()):
if isinstance(key, int):
key = to_binary(key & key_mask, len(self.test))
new_patterns = (*new_patterns, key)
else:
new_patterns = None
if not isinstance(stmts, Iterable):
stmts = [stmts]
self._cases[new_keys] = Statement.cast(stmts)
if orig_keys in case_src_locs:
self.case_src_locs[new_keys] = case_src_locs[orig_keys]
self._cases.append((new_patterns, Statement.cast(stmts), case_src_loc))
self._cases = tuple(self._cases)
@property
def test(self):
@ -2805,22 +2801,22 @@ class Switch(Statement):
return self._cases
def _lhs_signals(self):
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())
return union((stmts._lhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet())
def _rhs_signals(self):
signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet())
signals = union((stmts._rhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet())
return self.test._rhs_signals() | signals
def __repr__(self):
def case_repr(keys, stmts):
def case_repr(patterns, stmts):
stmts_repr = " ".join(map(repr, stmts))
if keys == ():
if patterns is None:
return f"(default {stmts_repr})"
elif len(keys) == 1:
return f"(case {keys[0]} {stmts_repr})"
elif len(patterns) == 1:
return f"(case {patterns[0]} {stmts_repr})"
else:
return "(case ({}) {})".format(" ".join(keys), stmts_repr)
case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()]
return "(case ({}) {})".format(" ".join(patterns), stmts_repr)
case_reprs = [case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases]
return "(switch {!r} {})".format(self.test, " ".join(case_reprs))

View file

@ -169,12 +169,11 @@ def resolve_statement(stmt):
elif isinstance(stmt, Switch):
return Switch(
test=stmt.test,
cases=OrderedDict(
(patterns, resolve_statements(stmts))
for patterns, stmts in stmt.cases.items()
),
cases=[
(patterns, resolve_statements(stmts), src_loc)
for patterns, stmts, src_loc in stmt.cases
],
src_loc=stmt.src_loc,
case_src_locs=stmt.case_src_locs,
)
elif isinstance(stmt, (Assign, Property, Print)):
return stmt
@ -318,9 +317,9 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._check_context("Switch", context=None)
switch_data = self._set_ctrl("Switch", {
"test": Value.cast(test),
"cases": OrderedDict(),
"cases": [],
"src_loc": tracer.get_src_loc(src_loc_at=1),
"case_src_locs": {},
"got_default": False,
})
try:
self._ctrl_context = "Switch"
@ -336,7 +335,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._check_context("Case", context="Switch")
src_loc = tracer.get_src_loc(src_loc_at=1)
switch_data = self._get_ctrl("Switch")
if () in switch_data["cases"]:
if switch_data["got_default"]:
warnings.warn("A case defined after the default case will never be active",
SyntaxWarning, stacklevel=3)
new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
@ -345,12 +344,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._ctrl_context = None
yield
self._flush_ctrl()
# If none of the provided cases can possibly be true, omit this branch completely.
# Likewise, omit this branch if another branch with this exact set of patterns already
# exists (since otherwise we'd overwrite the previous branch's slot in the dict).
if new_patterns and new_patterns not in switch_data["cases"]:
switch_data["cases"][new_patterns] = self._statements
switch_data["case_src_locs"][new_patterns] = src_loc
switch_data["cases"].append((new_patterns, self._statements, src_loc))
finally:
self._ctrl_context = "Switch"
self._statements = _outer_case
@ -360,7 +354,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._check_context("Default", context="Switch")
src_loc = tracer.get_src_loc(src_loc_at=1)
switch_data = self._get_ctrl("Switch")
if () in switch_data["cases"]:
if switch_data["got_default"]:
warnings.warn("A case defined after the default case will never be active",
SyntaxWarning, stacklevel=3)
try:
@ -368,9 +362,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._ctrl_context = None
yield
self._flush_ctrl()
if () not in switch_data["cases"]:
switch_data["cases"][()] = self._statements
switch_data["case_src_locs"][()] = src_loc
switch_data["cases"].append((None, self._statements, src_loc))
switch_data["got_default"] = True
finally:
self._ctrl_context = "Switch"
self._statements = _outer_case
@ -471,8 +464,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
domains[domain] = None
for domain in domains:
tests, cases = [], OrderedDict()
for if_test, if_case in zip(if_tests + [None], if_bodies):
tests, cases = [], []
for if_test, if_case, if_src_loc in zip(if_tests + [None], if_bodies, if_src_locs):
if if_test is not None:
if len(if_test) != 1:
if_test = if_test.bool()
@ -482,27 +475,26 @@ class Module(_ModuleBuilderRoot, Elaboratable):
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
else:
match = None
cases[match] = if_case.get(domain, [])
cases.append((match, if_case.get(domain, []), if_src_loc))
self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases,
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
src_loc=src_loc))
if name == "Switch":
switch_test, switch_cases = data["test"], data["cases"]
switch_case_src_locs = data["case_src_locs"]
domains = {}
for stmts in switch_cases.values():
for _patterns, stmts, _src_loc in switch_cases:
for domain in stmts:
domains[domain] = None
for domain in domains:
domain_cases = OrderedDict()
for patterns, stmts in switch_cases.items():
domain_cases[patterns] = stmts.get(domain, [])
domain_cases = []
for patterns, stmts, case_src_loc in switch_cases:
domain_cases.append((patterns, stmts.get(domain, []), case_src_loc))
self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases,
src_loc=src_loc, case_src_locs=switch_case_src_locs))
src_loc=src_loc))
if name == "FSM":
fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
@ -536,9 +528,11 @@ class Module(_ModuleBuilderRoot, Elaboratable):
domain_states[state] = stmts.get(domain, [])
self._statements.setdefault(domain, []).append(Switch(fsm_signal,
OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()),
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
for name in fsm_states}))
[
(fsm_encoding[name], stmts, fsm_state_src_locs[name])
for name, stmts in domain_states.items()
],
src_loc=src_loc))
def _add_statement(self, assigns, domain, depth):
while len(self._ctrl_stack) > self.domain._depth:

View file

@ -1110,20 +1110,25 @@ class NetlistEmitter:
elif isinstance(stmt, _ast.Switch):
test, _signed = self.emit_rhs(module_idx, stmt.test)
conds = []
for patterns in stmt.cases:
if patterns:
case_stmts = []
for patterns, stmts, case_src_loc in stmt.cases:
if patterns is not None:
if not patterns:
# Hack: empty pattern set cannot be supported by RTLIL.
continue
for pattern in patterns:
assert len(pattern) == len(test)
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
src_loc=stmt.case_src_locs.get(patterns))
src_loc=case_src_loc)
net, = self.netlist.add_value_cell(1, cell)
conds.append(net)
else:
conds.append(_nir.Net.from_const(1))
case_stmts.append(stmts)
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds),
src_loc=stmt.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
for subcond, substmts in zip(conds, stmt.cases.values()):
for subcond, substmts in zip(conds, case_stmts):
for substmt in substmts:
self.emit_stmt(module_idx, fragment, domain, substmt, subcond)
else:

View file

@ -183,8 +183,6 @@ class StatementVisitor(metaclass=ABCMeta):
new_stmt = self.on_unknown_statement(stmt)
if isinstance(new_stmt, Statement) and self.replace_statement_src_loc(stmt, new_stmt):
new_stmt.src_loc = stmt.src_loc
if isinstance(new_stmt, Switch) and isinstance(stmt, Switch):
new_stmt.case_src_locs = stmt.case_src_locs
if isinstance(new_stmt, (Print, Property)):
new_stmt._MustUse__used = True
return new_stmt
@ -221,7 +219,7 @@ class StatementTransformer(StatementVisitor):
return Property(stmt.kind, self.on_value(stmt.test), message)
def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
cases = [(k, self.on_statement(s), l) for k, s, l in stmt.cases]
return Switch(self.on_value(stmt.test), cases)
def on_statements(self, stmts):
@ -429,7 +427,7 @@ class DomainCollector(ValueVisitor, StatementVisitor):
def on_Switch(self, stmt):
self.on_value(stmt.test)
for stmts in stmt.cases.values():
for _patterns, stmts, _src_loc in stmt.cases:
self.on_statement(stmts)
def on_statements(self, stmts):
@ -624,7 +622,7 @@ class _ControlInserter(FragmentTransformer):
class ResetInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals):
stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less]
fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))
fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc))
class EnableInserter(_ControlInserter):
@ -632,7 +630,7 @@ class EnableInserter(_ControlInserter):
if domain in fragment.statements:
fragment.statements[domain] = _StatementList([Switch(
self.controls[domain],
{1: fragment.statements[domain]},
[(1, fragment.statements[domain], None)],
src_loc=self.src_loc,
)])

View file

@ -396,10 +396,12 @@ class _StatementCompiler(StatementVisitor, _Compiler):
def on_Switch(self, stmt):
gen_test_value = self.rhs(stmt.test) # check for oversized value before generating mask
gen_test = self.emitter.def_var("test", f"{(1 << len(stmt.test)) - 1:#x} & {gen_test_value}")
for index, (patterns, stmts) in enumerate(stmt.cases.items()):
for index, (patterns, stmts, _src_loc) in enumerate(stmt.cases):
gen_checks = []
if not patterns:
if patterns is None:
gen_checks.append(f"True")
elif not patterns:
gen_checks.append(f"False")
else:
for pattern in patterns:
if "-" in pattern:

View file

@ -1687,38 +1687,38 @@ class AssertTestCase(FHDLTestCase):
class SwitchTestCase(FHDLTestCase):
def test_default_case(self):
s = Switch(Const(0), {None: []})
self.assertEqual(s.cases, {(): []})
s = Switch(Const(0), [(None, [], None)])
self.assertEqual(s.cases, ((None, [], None),))
def test_int_case(self):
s = Switch(Const(0, 8), {10: []})
self.assertEqual(s.cases, {("00001010",): []})
s = Switch(Const(0, 8), [(10, [], None)])
self.assertEqual(s.cases, ((("00001010",), [], None),))
def test_int_neg_case(self):
s = Switch(Const(0, signed(8)), {-10: []})
self.assertEqual(s.cases, {("11110110",): []})
s = Switch(Const(0, signed(8)), [(-10, [], None)])
self.assertEqual(s.cases, ((("11110110",), [], None),))
def test_int_zero_width(self):
s = Switch(Const(0, 0), {0: []})
self.assertEqual(s.cases, {("",): []})
s = Switch(Const(0, 0), [(0, [], None)])
self.assertEqual(s.cases, ((("",), [], None),))
def test_int_zero_width_enum(self):
class ZeroEnum(Enum):
A = 0
s = Switch(Const(0, 0), {ZeroEnum.A: []})
self.assertEqual(s.cases, {("",): []})
s = Switch(Const(0, 0), [(ZeroEnum.A, [], None)])
self.assertEqual(s.cases, ((("",), [], None),))
def test_enum_case(self):
s = Switch(Const(0, UnsignedEnum), {UnsignedEnum.FOO: []})
self.assertEqual(s.cases, {("01",): []})
s = Switch(Const(0, UnsignedEnum), [(UnsignedEnum.FOO, [], None)])
self.assertEqual(s.cases, ((("01",), [], None),))
def test_str_case(self):
s = Switch(Const(0, 8), {"0000 11\t01": []})
self.assertEqual(s.cases, {("00001101",): []})
s = Switch(Const(0, 8), [("0000 11\t01", [], None)])
self.assertEqual(s.cases, ((("00001101",), [], None),))
def test_two_cases(self):
s = Switch(Const(0, 8), {("00001111", 123): []})
self.assertEqual(s.cases, {("00001111", "01111011"): []})
s = Switch(Const(0, 8), [(("00001111", 123), [], None)])
self.assertEqual(s.cases, ((("00001111", "01111011"), [], None),))
class IOValueTestCase(FHDLTestCase):

View file

@ -411,6 +411,7 @@ class DSLTestCase(FHDLTestCase):
(
(switch (sig w1)
(case 0011 (eq (sig c1) (const 1'd1)))
(case () (eq (sig c2) (const 1'd1)))
)
)
""")
@ -500,7 +501,14 @@ class DSLTestCase(FHDLTestCase):
r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
with m.Case(Color.RED):
m.d.comb += dummy.eq(0)
self.assertEqual(m._statements, {})
self.assertRepr(m._statements["comb"], """
(
(switch (sig w1)
(case () (eq (sig dummy) (const 1'd0)))
(case () (eq (sig dummy) (const 1'd0)))
)
)
""")
def test_Switch_zero_width(self):
m = Module()