From 2eb62a8b4977f6ca46ab49052365bb3e61b261e3 Mon Sep 17 00:00:00 2001 From: Wanda Date: Tue, 2 Apr 2024 23:04:56 +0200 Subject: [PATCH] hdl._ast: change `Switch` to operate on list of cases. --- amaranth/back/rtlil.py | 2 ++ amaranth/hdl/_ast.py | 56 ++++++++++++++++++++---------------------- amaranth/hdl/_dsl.py | 56 +++++++++++++++++++----------------------- amaranth/hdl/_ir.py | 13 +++++++--- amaranth/hdl/_xfrm.py | 10 +++----- amaranth/sim/_pyrtl.py | 6 +++-- tests/test_hdl_ast.py | 32 ++++++++++++------------ tests/test_hdl_dsl.py | 10 +++++++- 8 files changed, 95 insertions(+), 90 deletions(-) diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 5c5b7dd..483ba13 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -638,6 +638,8 @@ class ModuleEmitter: assert isinstance(matches_cell, _nir.Matches) assert test == matches_cell.value patterns = matches_cell.patterns + # RTLIL cannot support empty pattern sets. + assert patterns with switch.case(*patterns) as subcase: emit_assignments(subcase, subcond) emitted_switch = True diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 5b48ee0..0b7f15e 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -2763,38 +2763,34 @@ class _LateBoundStatement(Statement): @final class Switch(Statement): - def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}): + def __init__(self, test, cases, *, src_loc=None, src_loc_at=0): if src_loc is None: super().__init__(src_loc_at=src_loc_at) else: # Switch is a bit special in terms of location tracking because it is usually created # long after the control has left the statement that directly caused its creation. self.src_loc = src_loc - # Switch is also a bit special in that its parts also have location information. It can't - # be automatically traced, so whatever constructs a Switch may optionally provide it. - self.case_src_locs = {} self._test = Value.cast(test) - self._cases = OrderedDict() - for orig_keys, stmts in cases.items(): - # Map: None -> (); key -> (key,); (key...) -> (key...) - keys = orig_keys - if keys is None: - keys = () - if not isinstance(keys, tuple): - keys = (keys,) - # Map: 2 -> "0010"; "0010" -> "0010" - new_keys = () - key_mask = (1 << len(self.test)) - 1 - for key in _normalize_patterns(keys, self._test.shape()): - if isinstance(key, int): - key = to_binary(key & key_mask, len(self.test)) - new_keys = (*new_keys, key) + self._cases = [] + for patterns, stmts, case_src_loc in cases: + if patterns is not None: + # Map: key -> (key,); (key...) -> (key...) + if not isinstance(patterns, tuple): + patterns = (patterns,) + # Map: 2 -> "0010"; "0010" -> "0010" + new_patterns = () + key_mask = (1 << len(self.test)) - 1 + for key in _normalize_patterns(patterns, self._test.shape()): + if isinstance(key, int): + key = to_binary(key & key_mask, len(self.test)) + new_patterns = (*new_patterns, key) + else: + new_patterns = None if not isinstance(stmts, Iterable): stmts = [stmts] - self._cases[new_keys] = Statement.cast(stmts) - if orig_keys in case_src_locs: - self.case_src_locs[new_keys] = case_src_locs[orig_keys] + self._cases.append((new_patterns, Statement.cast(stmts), case_src_loc)) + self._cases = tuple(self._cases) @property def test(self): @@ -2805,22 +2801,22 @@ class Switch(Statement): return self._cases def _lhs_signals(self): - return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet()) + return union((stmts._lhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet()) def _rhs_signals(self): - signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet()) + signals = union((stmts._rhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet()) return self.test._rhs_signals() | signals def __repr__(self): - def case_repr(keys, stmts): + def case_repr(patterns, stmts): stmts_repr = " ".join(map(repr, stmts)) - if keys == (): + if patterns is None: return f"(default {stmts_repr})" - elif len(keys) == 1: - return f"(case {keys[0]} {stmts_repr})" + elif len(patterns) == 1: + return f"(case {patterns[0]} {stmts_repr})" else: - return "(case ({}) {})".format(" ".join(keys), stmts_repr) - case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()] + return "(case ({}) {})".format(" ".join(patterns), stmts_repr) + case_reprs = [case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases] return "(switch {!r} {})".format(self.test, " ".join(case_reprs)) diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index af636b8..d92255f 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -169,12 +169,11 @@ def resolve_statement(stmt): elif isinstance(stmt, Switch): return Switch( test=stmt.test, - cases=OrderedDict( - (patterns, resolve_statements(stmts)) - for patterns, stmts in stmt.cases.items() - ), + cases=[ + (patterns, resolve_statements(stmts), src_loc) + for patterns, stmts, src_loc in stmt.cases + ], src_loc=stmt.src_loc, - case_src_locs=stmt.case_src_locs, ) elif isinstance(stmt, (Assign, Property, Print)): return stmt @@ -318,9 +317,9 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._check_context("Switch", context=None) switch_data = self._set_ctrl("Switch", { "test": Value.cast(test), - "cases": OrderedDict(), + "cases": [], "src_loc": tracer.get_src_loc(src_loc_at=1), - "case_src_locs": {}, + "got_default": False, }) try: self._ctrl_context = "Switch" @@ -336,7 +335,7 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._check_context("Case", context="Switch") src_loc = tracer.get_src_loc(src_loc_at=1) switch_data = self._get_ctrl("Switch") - if () in switch_data["cases"]: + if switch_data["got_default"]: warnings.warn("A case defined after the default case will never be active", SyntaxWarning, stacklevel=3) new_patterns = _normalize_patterns(patterns, switch_data["test"].shape()) @@ -345,12 +344,7 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._ctrl_context = None yield self._flush_ctrl() - # If none of the provided cases can possibly be true, omit this branch completely. - # Likewise, omit this branch if another branch with this exact set of patterns already - # exists (since otherwise we'd overwrite the previous branch's slot in the dict). - if new_patterns and new_patterns not in switch_data["cases"]: - switch_data["cases"][new_patterns] = self._statements - switch_data["case_src_locs"][new_patterns] = src_loc + switch_data["cases"].append((new_patterns, self._statements, src_loc)) finally: self._ctrl_context = "Switch" self._statements = _outer_case @@ -360,7 +354,7 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._check_context("Default", context="Switch") src_loc = tracer.get_src_loc(src_loc_at=1) switch_data = self._get_ctrl("Switch") - if () in switch_data["cases"]: + if switch_data["got_default"]: warnings.warn("A case defined after the default case will never be active", SyntaxWarning, stacklevel=3) try: @@ -368,9 +362,8 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._ctrl_context = None yield self._flush_ctrl() - if () not in switch_data["cases"]: - switch_data["cases"][()] = self._statements - switch_data["case_src_locs"][()] = src_loc + switch_data["cases"].append((None, self._statements, src_loc)) + switch_data["got_default"] = True finally: self._ctrl_context = "Switch" self._statements = _outer_case @@ -471,8 +464,8 @@ class Module(_ModuleBuilderRoot, Elaboratable): domains[domain] = None for domain in domains: - tests, cases = [], OrderedDict() - for if_test, if_case in zip(if_tests + [None], if_bodies): + tests, cases = [], [] + for if_test, if_case, if_src_loc in zip(if_tests + [None], if_bodies, if_src_locs): if if_test is not None: if len(if_test) != 1: if_test = if_test.bool() @@ -482,27 +475,26 @@ class Module(_ModuleBuilderRoot, Elaboratable): match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-") else: match = None - cases[match] = if_case.get(domain, []) + cases.append((match, if_case.get(domain, []), if_src_loc)) self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases, - src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs)))) + src_loc=src_loc)) if name == "Switch": switch_test, switch_cases = data["test"], data["cases"] - switch_case_src_locs = data["case_src_locs"] domains = {} - for stmts in switch_cases.values(): + for _patterns, stmts, _src_loc in switch_cases: for domain in stmts: domains[domain] = None for domain in domains: - domain_cases = OrderedDict() - for patterns, stmts in switch_cases.items(): - domain_cases[patterns] = stmts.get(domain, []) + domain_cases = [] + for patterns, stmts, case_src_loc in switch_cases: + domain_cases.append((patterns, stmts.get(domain, []), case_src_loc)) self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases, - src_loc=src_loc, case_src_locs=switch_case_src_locs)) + src_loc=src_loc)) if name == "FSM": fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \ @@ -536,9 +528,11 @@ class Module(_ModuleBuilderRoot, Elaboratable): domain_states[state] = stmts.get(domain, []) self._statements.setdefault(domain, []).append(Switch(fsm_signal, - OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()), - src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name] - for name in fsm_states})) + [ + (fsm_encoding[name], stmts, fsm_state_src_locs[name]) + for name, stmts in domain_states.items() + ], + src_loc=src_loc)) def _add_statement(self, assigns, domain, depth): while len(self._ctrl_stack) > self.domain._depth: diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 5037afb..717db2e 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -1110,20 +1110,25 @@ class NetlistEmitter: elif isinstance(stmt, _ast.Switch): test, _signed = self.emit_rhs(module_idx, stmt.test) conds = [] - for patterns in stmt.cases: - if patterns: + case_stmts = [] + for patterns, stmts, case_src_loc in stmt.cases: + if patterns is not None: + if not patterns: + # Hack: empty pattern set cannot be supported by RTLIL. + continue for pattern in patterns: assert len(pattern) == len(test) cell = _nir.Matches(module_idx, value=test, patterns=patterns, - src_loc=stmt.case_src_locs.get(patterns)) + src_loc=case_src_loc) net, = self.netlist.add_value_cell(1, cell) conds.append(net) else: conds.append(_nir.Net.from_const(1)) + case_stmts.append(stmts) cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds), src_loc=stmt.src_loc) conds = self.netlist.add_value_cell(len(conds), cell) - for subcond, substmts in zip(conds, stmt.cases.values()): + for subcond, substmts in zip(conds, case_stmts): for substmt in substmts: self.emit_stmt(module_idx, fragment, domain, substmt, subcond) else: diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index f96ec51..01b4fa1 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -183,8 +183,6 @@ class StatementVisitor(metaclass=ABCMeta): new_stmt = self.on_unknown_statement(stmt) if isinstance(new_stmt, Statement) and self.replace_statement_src_loc(stmt, new_stmt): new_stmt.src_loc = stmt.src_loc - if isinstance(new_stmt, Switch) and isinstance(stmt, Switch): - new_stmt.case_src_locs = stmt.case_src_locs if isinstance(new_stmt, (Print, Property)): new_stmt._MustUse__used = True return new_stmt @@ -221,7 +219,7 @@ class StatementTransformer(StatementVisitor): return Property(stmt.kind, self.on_value(stmt.test), message) def on_Switch(self, stmt): - cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) + cases = [(k, self.on_statement(s), l) for k, s, l in stmt.cases] return Switch(self.on_value(stmt.test), cases) def on_statements(self, stmts): @@ -429,7 +427,7 @@ class DomainCollector(ValueVisitor, StatementVisitor): def on_Switch(self, stmt): self.on_value(stmt.test) - for stmts in stmt.cases.values(): + for _patterns, stmts, _src_loc in stmt.cases: self.on_statement(stmts) def on_statements(self, stmts): @@ -624,7 +622,7 @@ class _ControlInserter(FragmentTransformer): class ResetInserter(_ControlInserter): def _insert_control(self, fragment, domain, signals): stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less] - fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc)) + fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc)) class EnableInserter(_ControlInserter): @@ -632,7 +630,7 @@ class EnableInserter(_ControlInserter): if domain in fragment.statements: fragment.statements[domain] = _StatementList([Switch( self.controls[domain], - {1: fragment.statements[domain]}, + [(1, fragment.statements[domain], None)], src_loc=self.src_loc, )]) diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 363cec6..3f41f34 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -396,10 +396,12 @@ class _StatementCompiler(StatementVisitor, _Compiler): def on_Switch(self, stmt): gen_test_value = self.rhs(stmt.test) # check for oversized value before generating mask gen_test = self.emitter.def_var("test", f"{(1 << len(stmt.test)) - 1:#x} & {gen_test_value}") - for index, (patterns, stmts) in enumerate(stmt.cases.items()): + for index, (patterns, stmts, _src_loc) in enumerate(stmt.cases): gen_checks = [] - if not patterns: + if patterns is None: gen_checks.append(f"True") + elif not patterns: + gen_checks.append(f"False") else: for pattern in patterns: if "-" in pattern: diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index b173698..23e8e76 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -1687,38 +1687,38 @@ class AssertTestCase(FHDLTestCase): class SwitchTestCase(FHDLTestCase): def test_default_case(self): - s = Switch(Const(0), {None: []}) - self.assertEqual(s.cases, {(): []}) + s = Switch(Const(0), [(None, [], None)]) + self.assertEqual(s.cases, ((None, [], None),)) def test_int_case(self): - s = Switch(Const(0, 8), {10: []}) - self.assertEqual(s.cases, {("00001010",): []}) + s = Switch(Const(0, 8), [(10, [], None)]) + self.assertEqual(s.cases, ((("00001010",), [], None),)) def test_int_neg_case(self): - s = Switch(Const(0, signed(8)), {-10: []}) - self.assertEqual(s.cases, {("11110110",): []}) + s = Switch(Const(0, signed(8)), [(-10, [], None)]) + self.assertEqual(s.cases, ((("11110110",), [], None),)) def test_int_zero_width(self): - s = Switch(Const(0, 0), {0: []}) - self.assertEqual(s.cases, {("",): []}) + s = Switch(Const(0, 0), [(0, [], None)]) + self.assertEqual(s.cases, ((("",), [], None),)) def test_int_zero_width_enum(self): class ZeroEnum(Enum): A = 0 - s = Switch(Const(0, 0), {ZeroEnum.A: []}) - self.assertEqual(s.cases, {("",): []}) + s = Switch(Const(0, 0), [(ZeroEnum.A, [], None)]) + self.assertEqual(s.cases, ((("",), [], None),)) def test_enum_case(self): - s = Switch(Const(0, UnsignedEnum), {UnsignedEnum.FOO: []}) - self.assertEqual(s.cases, {("01",): []}) + s = Switch(Const(0, UnsignedEnum), [(UnsignedEnum.FOO, [], None)]) + self.assertEqual(s.cases, ((("01",), [], None),)) def test_str_case(self): - s = Switch(Const(0, 8), {"0000 11\t01": []}) - self.assertEqual(s.cases, {("00001101",): []}) + s = Switch(Const(0, 8), [("0000 11\t01", [], None)]) + self.assertEqual(s.cases, ((("00001101",), [], None),)) def test_two_cases(self): - s = Switch(Const(0, 8), {("00001111", 123): []}) - self.assertEqual(s.cases, {("00001111", "01111011"): []}) + s = Switch(Const(0, 8), [(("00001111", 123), [], None)]) + self.assertEqual(s.cases, ((("00001111", "01111011"), [], None),)) class IOValueTestCase(FHDLTestCase): diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index ff7cfa0..ad373ee 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -411,6 +411,7 @@ class DSLTestCase(FHDLTestCase): ( (switch (sig w1) (case 0011 (eq (sig c1) (const 1'd1))) + (case () (eq (sig c2) (const 1'd1))) ) ) """) @@ -500,7 +501,14 @@ class DSLTestCase(FHDLTestCase): r"match value shape \(unsigned\(4\)\); comparison will never be true$"): with m.Case(Color.RED): m.d.comb += dummy.eq(0) - self.assertEqual(m._statements, {}) + self.assertRepr(m._statements["comb"], """ + ( + (switch (sig w1) + (case () (eq (sig dummy) (const 1'd0))) + (case () (eq (sig dummy) (const 1'd0))) + ) + ) + """) def test_Switch_zero_width(self): m = Module()