hdl._ast: change Switch
to operate on list of cases.
This commit is contained in:
parent
cd6cbd71ca
commit
2eb62a8b49
|
@ -638,6 +638,8 @@ class ModuleEmitter:
|
|||
assert isinstance(matches_cell, _nir.Matches)
|
||||
assert test == matches_cell.value
|
||||
patterns = matches_cell.patterns
|
||||
# RTLIL cannot support empty pattern sets.
|
||||
assert patterns
|
||||
with switch.case(*patterns) as subcase:
|
||||
emit_assignments(subcase, subcond)
|
||||
emitted_switch = True
|
||||
|
|
|
@ -2763,38 +2763,34 @@ class _LateBoundStatement(Statement):
|
|||
|
||||
@final
|
||||
class Switch(Statement):
|
||||
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):
|
||||
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
|
||||
if src_loc is None:
|
||||
super().__init__(src_loc_at=src_loc_at)
|
||||
else:
|
||||
# Switch is a bit special in terms of location tracking because it is usually created
|
||||
# long after the control has left the statement that directly caused its creation.
|
||||
self.src_loc = src_loc
|
||||
# Switch is also a bit special in that its parts also have location information. It can't
|
||||
# be automatically traced, so whatever constructs a Switch may optionally provide it.
|
||||
self.case_src_locs = {}
|
||||
|
||||
self._test = Value.cast(test)
|
||||
self._cases = OrderedDict()
|
||||
for orig_keys, stmts in cases.items():
|
||||
# Map: None -> (); key -> (key,); (key...) -> (key...)
|
||||
keys = orig_keys
|
||||
if keys is None:
|
||||
keys = ()
|
||||
if not isinstance(keys, tuple):
|
||||
keys = (keys,)
|
||||
# Map: 2 -> "0010"; "0010" -> "0010"
|
||||
new_keys = ()
|
||||
key_mask = (1 << len(self.test)) - 1
|
||||
for key in _normalize_patterns(keys, self._test.shape()):
|
||||
if isinstance(key, int):
|
||||
key = to_binary(key & key_mask, len(self.test))
|
||||
new_keys = (*new_keys, key)
|
||||
self._cases = []
|
||||
for patterns, stmts, case_src_loc in cases:
|
||||
if patterns is not None:
|
||||
# Map: key -> (key,); (key...) -> (key...)
|
||||
if not isinstance(patterns, tuple):
|
||||
patterns = (patterns,)
|
||||
# Map: 2 -> "0010"; "0010" -> "0010"
|
||||
new_patterns = ()
|
||||
key_mask = (1 << len(self.test)) - 1
|
||||
for key in _normalize_patterns(patterns, self._test.shape()):
|
||||
if isinstance(key, int):
|
||||
key = to_binary(key & key_mask, len(self.test))
|
||||
new_patterns = (*new_patterns, key)
|
||||
else:
|
||||
new_patterns = None
|
||||
if not isinstance(stmts, Iterable):
|
||||
stmts = [stmts]
|
||||
self._cases[new_keys] = Statement.cast(stmts)
|
||||
if orig_keys in case_src_locs:
|
||||
self.case_src_locs[new_keys] = case_src_locs[orig_keys]
|
||||
self._cases.append((new_patterns, Statement.cast(stmts), case_src_loc))
|
||||
self._cases = tuple(self._cases)
|
||||
|
||||
@property
|
||||
def test(self):
|
||||
|
@ -2805,22 +2801,22 @@ class Switch(Statement):
|
|||
return self._cases
|
||||
|
||||
def _lhs_signals(self):
|
||||
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())
|
||||
return union((stmts._lhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet())
|
||||
|
||||
def _rhs_signals(self):
|
||||
signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet())
|
||||
signals = union((stmts._rhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet())
|
||||
return self.test._rhs_signals() | signals
|
||||
|
||||
def __repr__(self):
|
||||
def case_repr(keys, stmts):
|
||||
def case_repr(patterns, stmts):
|
||||
stmts_repr = " ".join(map(repr, stmts))
|
||||
if keys == ():
|
||||
if patterns is None:
|
||||
return f"(default {stmts_repr})"
|
||||
elif len(keys) == 1:
|
||||
return f"(case {keys[0]} {stmts_repr})"
|
||||
elif len(patterns) == 1:
|
||||
return f"(case {patterns[0]} {stmts_repr})"
|
||||
else:
|
||||
return "(case ({}) {})".format(" ".join(keys), stmts_repr)
|
||||
case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()]
|
||||
return "(case ({}) {})".format(" ".join(patterns), stmts_repr)
|
||||
case_reprs = [case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases]
|
||||
return "(switch {!r} {})".format(self.test, " ".join(case_reprs))
|
||||
|
||||
|
||||
|
|
|
@ -169,12 +169,11 @@ def resolve_statement(stmt):
|
|||
elif isinstance(stmt, Switch):
|
||||
return Switch(
|
||||
test=stmt.test,
|
||||
cases=OrderedDict(
|
||||
(patterns, resolve_statements(stmts))
|
||||
for patterns, stmts in stmt.cases.items()
|
||||
),
|
||||
cases=[
|
||||
(patterns, resolve_statements(stmts), src_loc)
|
||||
for patterns, stmts, src_loc in stmt.cases
|
||||
],
|
||||
src_loc=stmt.src_loc,
|
||||
case_src_locs=stmt.case_src_locs,
|
||||
)
|
||||
elif isinstance(stmt, (Assign, Property, Print)):
|
||||
return stmt
|
||||
|
@ -318,9 +317,9 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
self._check_context("Switch", context=None)
|
||||
switch_data = self._set_ctrl("Switch", {
|
||||
"test": Value.cast(test),
|
||||
"cases": OrderedDict(),
|
||||
"cases": [],
|
||||
"src_loc": tracer.get_src_loc(src_loc_at=1),
|
||||
"case_src_locs": {},
|
||||
"got_default": False,
|
||||
})
|
||||
try:
|
||||
self._ctrl_context = "Switch"
|
||||
|
@ -336,7 +335,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
self._check_context("Case", context="Switch")
|
||||
src_loc = tracer.get_src_loc(src_loc_at=1)
|
||||
switch_data = self._get_ctrl("Switch")
|
||||
if () in switch_data["cases"]:
|
||||
if switch_data["got_default"]:
|
||||
warnings.warn("A case defined after the default case will never be active",
|
||||
SyntaxWarning, stacklevel=3)
|
||||
new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
|
||||
|
@ -345,12 +344,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
self._ctrl_context = None
|
||||
yield
|
||||
self._flush_ctrl()
|
||||
# If none of the provided cases can possibly be true, omit this branch completely.
|
||||
# Likewise, omit this branch if another branch with this exact set of patterns already
|
||||
# exists (since otherwise we'd overwrite the previous branch's slot in the dict).
|
||||
if new_patterns and new_patterns not in switch_data["cases"]:
|
||||
switch_data["cases"][new_patterns] = self._statements
|
||||
switch_data["case_src_locs"][new_patterns] = src_loc
|
||||
switch_data["cases"].append((new_patterns, self._statements, src_loc))
|
||||
finally:
|
||||
self._ctrl_context = "Switch"
|
||||
self._statements = _outer_case
|
||||
|
@ -360,7 +354,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
self._check_context("Default", context="Switch")
|
||||
src_loc = tracer.get_src_loc(src_loc_at=1)
|
||||
switch_data = self._get_ctrl("Switch")
|
||||
if () in switch_data["cases"]:
|
||||
if switch_data["got_default"]:
|
||||
warnings.warn("A case defined after the default case will never be active",
|
||||
SyntaxWarning, stacklevel=3)
|
||||
try:
|
||||
|
@ -368,9 +362,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
self._ctrl_context = None
|
||||
yield
|
||||
self._flush_ctrl()
|
||||
if () not in switch_data["cases"]:
|
||||
switch_data["cases"][()] = self._statements
|
||||
switch_data["case_src_locs"][()] = src_loc
|
||||
switch_data["cases"].append((None, self._statements, src_loc))
|
||||
switch_data["got_default"] = True
|
||||
finally:
|
||||
self._ctrl_context = "Switch"
|
||||
self._statements = _outer_case
|
||||
|
@ -471,8 +464,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
domains[domain] = None
|
||||
|
||||
for domain in domains:
|
||||
tests, cases = [], OrderedDict()
|
||||
for if_test, if_case in zip(if_tests + [None], if_bodies):
|
||||
tests, cases = [], []
|
||||
for if_test, if_case, if_src_loc in zip(if_tests + [None], if_bodies, if_src_locs):
|
||||
if if_test is not None:
|
||||
if len(if_test) != 1:
|
||||
if_test = if_test.bool()
|
||||
|
@ -482,27 +475,26 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
|
||||
else:
|
||||
match = None
|
||||
cases[match] = if_case.get(domain, [])
|
||||
cases.append((match, if_case.get(domain, []), if_src_loc))
|
||||
|
||||
self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases,
|
||||
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
|
||||
src_loc=src_loc))
|
||||
|
||||
if name == "Switch":
|
||||
switch_test, switch_cases = data["test"], data["cases"]
|
||||
switch_case_src_locs = data["case_src_locs"]
|
||||
|
||||
domains = {}
|
||||
for stmts in switch_cases.values():
|
||||
for _patterns, stmts, _src_loc in switch_cases:
|
||||
for domain in stmts:
|
||||
domains[domain] = None
|
||||
|
||||
for domain in domains:
|
||||
domain_cases = OrderedDict()
|
||||
for patterns, stmts in switch_cases.items():
|
||||
domain_cases[patterns] = stmts.get(domain, [])
|
||||
domain_cases = []
|
||||
for patterns, stmts, case_src_loc in switch_cases:
|
||||
domain_cases.append((patterns, stmts.get(domain, []), case_src_loc))
|
||||
|
||||
self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases,
|
||||
src_loc=src_loc, case_src_locs=switch_case_src_locs))
|
||||
src_loc=src_loc))
|
||||
|
||||
if name == "FSM":
|
||||
fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
|
||||
|
@ -536,9 +528,11 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
domain_states[state] = stmts.get(domain, [])
|
||||
|
||||
self._statements.setdefault(domain, []).append(Switch(fsm_signal,
|
||||
OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()),
|
||||
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
|
||||
for name in fsm_states}))
|
||||
[
|
||||
(fsm_encoding[name], stmts, fsm_state_src_locs[name])
|
||||
for name, stmts in domain_states.items()
|
||||
],
|
||||
src_loc=src_loc))
|
||||
|
||||
def _add_statement(self, assigns, domain, depth):
|
||||
while len(self._ctrl_stack) > self.domain._depth:
|
||||
|
|
|
@ -1110,20 +1110,25 @@ class NetlistEmitter:
|
|||
elif isinstance(stmt, _ast.Switch):
|
||||
test, _signed = self.emit_rhs(module_idx, stmt.test)
|
||||
conds = []
|
||||
for patterns in stmt.cases:
|
||||
if patterns:
|
||||
case_stmts = []
|
||||
for patterns, stmts, case_src_loc in stmt.cases:
|
||||
if patterns is not None:
|
||||
if not patterns:
|
||||
# Hack: empty pattern set cannot be supported by RTLIL.
|
||||
continue
|
||||
for pattern in patterns:
|
||||
assert len(pattern) == len(test)
|
||||
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
|
||||
src_loc=stmt.case_src_locs.get(patterns))
|
||||
src_loc=case_src_loc)
|
||||
net, = self.netlist.add_value_cell(1, cell)
|
||||
conds.append(net)
|
||||
else:
|
||||
conds.append(_nir.Net.from_const(1))
|
||||
case_stmts.append(stmts)
|
||||
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds),
|
||||
src_loc=stmt.src_loc)
|
||||
conds = self.netlist.add_value_cell(len(conds), cell)
|
||||
for subcond, substmts in zip(conds, stmt.cases.values()):
|
||||
for subcond, substmts in zip(conds, case_stmts):
|
||||
for substmt in substmts:
|
||||
self.emit_stmt(module_idx, fragment, domain, substmt, subcond)
|
||||
else:
|
||||
|
|
|
@ -183,8 +183,6 @@ class StatementVisitor(metaclass=ABCMeta):
|
|||
new_stmt = self.on_unknown_statement(stmt)
|
||||
if isinstance(new_stmt, Statement) and self.replace_statement_src_loc(stmt, new_stmt):
|
||||
new_stmt.src_loc = stmt.src_loc
|
||||
if isinstance(new_stmt, Switch) and isinstance(stmt, Switch):
|
||||
new_stmt.case_src_locs = stmt.case_src_locs
|
||||
if isinstance(new_stmt, (Print, Property)):
|
||||
new_stmt._MustUse__used = True
|
||||
return new_stmt
|
||||
|
@ -221,7 +219,7 @@ class StatementTransformer(StatementVisitor):
|
|||
return Property(stmt.kind, self.on_value(stmt.test), message)
|
||||
|
||||
def on_Switch(self, stmt):
|
||||
cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
|
||||
cases = [(k, self.on_statement(s), l) for k, s, l in stmt.cases]
|
||||
return Switch(self.on_value(stmt.test), cases)
|
||||
|
||||
def on_statements(self, stmts):
|
||||
|
@ -429,7 +427,7 @@ class DomainCollector(ValueVisitor, StatementVisitor):
|
|||
|
||||
def on_Switch(self, stmt):
|
||||
self.on_value(stmt.test)
|
||||
for stmts in stmt.cases.values():
|
||||
for _patterns, stmts, _src_loc in stmt.cases:
|
||||
self.on_statement(stmts)
|
||||
|
||||
def on_statements(self, stmts):
|
||||
|
@ -624,7 +622,7 @@ class _ControlInserter(FragmentTransformer):
|
|||
class ResetInserter(_ControlInserter):
|
||||
def _insert_control(self, fragment, domain, signals):
|
||||
stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less]
|
||||
fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))
|
||||
fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc))
|
||||
|
||||
|
||||
class EnableInserter(_ControlInserter):
|
||||
|
@ -632,7 +630,7 @@ class EnableInserter(_ControlInserter):
|
|||
if domain in fragment.statements:
|
||||
fragment.statements[domain] = _StatementList([Switch(
|
||||
self.controls[domain],
|
||||
{1: fragment.statements[domain]},
|
||||
[(1, fragment.statements[domain], None)],
|
||||
src_loc=self.src_loc,
|
||||
)])
|
||||
|
||||
|
|
|
@ -396,10 +396,12 @@ class _StatementCompiler(StatementVisitor, _Compiler):
|
|||
def on_Switch(self, stmt):
|
||||
gen_test_value = self.rhs(stmt.test) # check for oversized value before generating mask
|
||||
gen_test = self.emitter.def_var("test", f"{(1 << len(stmt.test)) - 1:#x} & {gen_test_value}")
|
||||
for index, (patterns, stmts) in enumerate(stmt.cases.items()):
|
||||
for index, (patterns, stmts, _src_loc) in enumerate(stmt.cases):
|
||||
gen_checks = []
|
||||
if not patterns:
|
||||
if patterns is None:
|
||||
gen_checks.append(f"True")
|
||||
elif not patterns:
|
||||
gen_checks.append(f"False")
|
||||
else:
|
||||
for pattern in patterns:
|
||||
if "-" in pattern:
|
||||
|
|
|
@ -1687,38 +1687,38 @@ class AssertTestCase(FHDLTestCase):
|
|||
|
||||
class SwitchTestCase(FHDLTestCase):
|
||||
def test_default_case(self):
|
||||
s = Switch(Const(0), {None: []})
|
||||
self.assertEqual(s.cases, {(): []})
|
||||
s = Switch(Const(0), [(None, [], None)])
|
||||
self.assertEqual(s.cases, ((None, [], None),))
|
||||
|
||||
def test_int_case(self):
|
||||
s = Switch(Const(0, 8), {10: []})
|
||||
self.assertEqual(s.cases, {("00001010",): []})
|
||||
s = Switch(Const(0, 8), [(10, [], None)])
|
||||
self.assertEqual(s.cases, ((("00001010",), [], None),))
|
||||
|
||||
def test_int_neg_case(self):
|
||||
s = Switch(Const(0, signed(8)), {-10: []})
|
||||
self.assertEqual(s.cases, {("11110110",): []})
|
||||
s = Switch(Const(0, signed(8)), [(-10, [], None)])
|
||||
self.assertEqual(s.cases, ((("11110110",), [], None),))
|
||||
|
||||
def test_int_zero_width(self):
|
||||
s = Switch(Const(0, 0), {0: []})
|
||||
self.assertEqual(s.cases, {("",): []})
|
||||
s = Switch(Const(0, 0), [(0, [], None)])
|
||||
self.assertEqual(s.cases, ((("",), [], None),))
|
||||
|
||||
def test_int_zero_width_enum(self):
|
||||
class ZeroEnum(Enum):
|
||||
A = 0
|
||||
s = Switch(Const(0, 0), {ZeroEnum.A: []})
|
||||
self.assertEqual(s.cases, {("",): []})
|
||||
s = Switch(Const(0, 0), [(ZeroEnum.A, [], None)])
|
||||
self.assertEqual(s.cases, ((("",), [], None),))
|
||||
|
||||
def test_enum_case(self):
|
||||
s = Switch(Const(0, UnsignedEnum), {UnsignedEnum.FOO: []})
|
||||
self.assertEqual(s.cases, {("01",): []})
|
||||
s = Switch(Const(0, UnsignedEnum), [(UnsignedEnum.FOO, [], None)])
|
||||
self.assertEqual(s.cases, ((("01",), [], None),))
|
||||
|
||||
def test_str_case(self):
|
||||
s = Switch(Const(0, 8), {"0000 11\t01": []})
|
||||
self.assertEqual(s.cases, {("00001101",): []})
|
||||
s = Switch(Const(0, 8), [("0000 11\t01", [], None)])
|
||||
self.assertEqual(s.cases, ((("00001101",), [], None),))
|
||||
|
||||
def test_two_cases(self):
|
||||
s = Switch(Const(0, 8), {("00001111", 123): []})
|
||||
self.assertEqual(s.cases, {("00001111", "01111011"): []})
|
||||
s = Switch(Const(0, 8), [(("00001111", 123), [], None)])
|
||||
self.assertEqual(s.cases, ((("00001111", "01111011"), [], None),))
|
||||
|
||||
|
||||
class IOValueTestCase(FHDLTestCase):
|
||||
|
|
|
@ -411,6 +411,7 @@ class DSLTestCase(FHDLTestCase):
|
|||
(
|
||||
(switch (sig w1)
|
||||
(case 0011 (eq (sig c1) (const 1'd1)))
|
||||
(case () (eq (sig c2) (const 1'd1)))
|
||||
)
|
||||
)
|
||||
""")
|
||||
|
@ -500,7 +501,14 @@ class DSLTestCase(FHDLTestCase):
|
|||
r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
|
||||
with m.Case(Color.RED):
|
||||
m.d.comb += dummy.eq(0)
|
||||
self.assertEqual(m._statements, {})
|
||||
self.assertRepr(m._statements["comb"], """
|
||||
(
|
||||
(switch (sig w1)
|
||||
(case () (eq (sig dummy) (const 1'd0)))
|
||||
(case () (eq (sig dummy) (const 1'd0)))
|
||||
)
|
||||
)
|
||||
""")
|
||||
|
||||
def test_Switch_zero_width(self):
|
||||
m = Module()
|
||||
|
|
Loading…
Reference in a new issue