hdl.{ast,dsl}, back.{pysim,rtlil}: allow multiple case values.

This means that instead of:

    with m.Case(0b00):
        <body>
    with m.Case(0b01):
        <body>

it is legal to write:

    with m.Case(0b00, 0b01):
        <body>

with no change in semantics, and slightly nicer RTLIL or Verilog
output.

Fixes #103.
This commit is contained in:
whitequark 2019-06-28 04:37:08 +00:00
parent 48d4ee4031
commit 32446831b4
6 changed files with 73 additions and 55 deletions

View file

@ -318,10 +318,14 @@ class _StatementCompiler(StatementVisitor):
def on_Switch(self, stmt): def on_Switch(self, stmt):
test = self.rrhs_compiler(stmt.test) test = self.rrhs_compiler(stmt.test)
cases = [] cases = []
for value, stmts in stmt.cases.items(): for values, stmts in stmt.cases.items():
if value is None: if values == ():
check = lambda test: True check = lambda test: True
else: else:
check = lambda test: False
def make_check(mask, value, prev_check):
return lambda test: prev_check(test) or test & mask == value
for value in values:
if "-" in value: if "-" in value:
mask = "".join("0" if b == "-" else "1" for b in value) mask = "".join("0" if b == "-" else "1" for b in value)
value = "".join("0" if b == "-" else b for b in value) value = "".join("0" if b == "-" else b for b in value)
@ -329,9 +333,7 @@ class _StatementCompiler(StatementVisitor):
mask = "1" * len(value) mask = "1" * len(value)
mask = int(mask, 2) mask = int(mask, 2)
value = int(value, 2) value = int(value, 2)
def make_check(mask, value): check = make_check(mask, value, check)
return lambda test: test & mask == value
check = make_check(mask, value)
cases.append((check, self.on_statements(stmts))) cases.append((check, self.on_statements(stmts)))
def run(state): def run(state):
test_value = test(state) test_value = test(state)

View file

@ -188,12 +188,12 @@ class _SwitchBuilder:
def __exit__(self, *args): def __exit__(self, *args):
self.rtlil._append("{}end\n", " " * self.indent) self.rtlil._append("{}end\n", " " * self.indent)
def case(self, value=None): def case(self, *values):
if value is None: if values == ():
self.rtlil._append("{}case\n", " " * (self.indent + 1)) self.rtlil._append("{}case\n", " " * (self.indent + 1))
else: else:
self.rtlil._append("{}case {}'{}\n", " " * (self.indent + 1), self.rtlil._append("{}case {}\n", " " * (self.indent + 1),
len(value), value) ", ".join("{}'{}".format(len(value), value) for value in values))
return _CaseBuilder(self.rtlil, self.indent + 2) return _CaseBuilder(self.rtlil, self.indent + 2)
@ -590,10 +590,10 @@ class _StatementCompiler(xfrm.StatementVisitor):
self._has_rhs = False self._has_rhs = False
@contextmanager @contextmanager
def case(self, switch, value): def case(self, switch, values):
try: try:
old_case = self._case old_case = self._case
with switch.case(value) as self._case: with switch.case(*values) as self._case:
yield yield
finally: finally:
self._case = old_case self._case = old_case
@ -645,8 +645,8 @@ class _StatementCompiler(xfrm.StatementVisitor):
test_sigspec = self._test_cache[stmt] test_sigspec = self._test_cache[stmt]
with self._case.switch(test_sigspec) as switch: with self._case.switch(test_sigspec) as switch:
for value, stmts in stmt.cases.items(): for values, stmts in stmt.cases.items():
with self.case(switch, value): with self.case(switch, values):
self.on_statements(stmts) self.on_statements(stmts)
def on_statement(self, stmt): def on_statement(self, stmt):

View file

@ -106,12 +106,12 @@ class Case(ast.Switch):
or choice > key): or choice > key):
key = choice key = choice
elif isinstance(key, str) and key == "default": elif isinstance(key, str) and key == "default":
key = None key = ()
else: else:
key = "{:0{}b}".format(wrap(key).value, len(self.test)) key = ("{:0{}b}".format(wrap(key).value, len(self.test)),)
stmts = self.cases[key] stmts = self.cases[key]
del self.cases[key] del self.cases[key]
self.cases[None] = stmts self.cases[()] = stmts
return self return self

View file

@ -1019,20 +1019,27 @@ class Switch(Statement):
def __init__(self, test, cases): def __init__(self, test, cases):
self.test = Value.wrap(test) self.test = Value.wrap(test)
self.cases = OrderedDict() self.cases = OrderedDict()
for key, stmts in cases.items(): for keys, stmts in cases.items():
# Map: None -> (); key -> (key,); (key...) -> (key...)
if keys is None:
keys = ()
if not isinstance(keys, tuple):
keys = (keys,)
# Map: 2 -> "0010"; "0010" -> "0010"
new_keys = ()
for key in keys:
if isinstance(key, (bool, int)): if isinstance(key, (bool, int)):
key = "{:0{}b}".format(key, len(self.test)) key = "{:0{}b}".format(key, len(self.test))
elif isinstance(key, str): elif isinstance(key, str):
pass pass
elif key is None:
pass
else: else:
raise TypeError("Object '{!r}' cannot be used as a switch key" raise TypeError("Object '{!r}' cannot be used as a switch key"
.format(key)) .format(key))
assert key is None or len(key) == len(self.test) assert len(key) == len(self.test)
new_keys = (*new_keys, key)
if not isinstance(stmts, Iterable): if not isinstance(stmts, Iterable):
stmts = [stmts] stmts = [stmts]
self.cases[key] = Statement.wrap(stmts) self.cases[new_keys] = Statement.wrap(stmts)
def _lhs_signals(self): def _lhs_signals(self):
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss), signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
@ -1045,11 +1052,16 @@ class Switch(Statement):
return self.test._rhs_signals() | signals return self.test._rhs_signals() | signals
def __repr__(self): def __repr__(self):
cases = ["(default {})".format(" ".join(map(repr, stmts))) def case_repr(keys, stmts):
if key is None else stmts_repr = " ".join(map(repr, stmts))
"(case {} {})".format(key, " ".join(map(repr, stmts))) if keys == ():
for key, stmts in self.cases.items()] return "(default {})".format(stmts_repr)
return "(switch {!r} {})".format(self.test, " ".join(cases)) elif len(keys) == 1:
return "(case {} {})".format(keys[0], stmts_repr)
else:
return "(case ({}) {})".format(" ".join(keys), stmts_repr)
case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()]
return "(switch {!r} {})".format(self.test, " ".join(case_reprs))
@final @final

View file

@ -214,27 +214,31 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._pop_ctrl() self._pop_ctrl()
@contextmanager @contextmanager
def Case(self, value=None): def Case(self, *values):
self._check_context("Case", context="Switch") self._check_context("Case", context="Switch")
switch_data = self._get_ctrl("Switch") switch_data = self._get_ctrl("Switch")
if value is None: new_values = ()
value = "-" * len(switch_data["test"]) for value in values:
if isinstance(value, str) and len(value) != len(switch_data["test"]): if isinstance(value, str) and len(value) != len(switch_data["test"]):
raise SyntaxError("Case value '{}' must have the same width as test (which is {})" raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
.format(value, len(switch_data["test"]))) .format(value, len(switch_data["test"])))
omit_case = False
if isinstance(value, int) and bits_for(value) > len(switch_data["test"]): if isinstance(value, int) and bits_for(value) > len(switch_data["test"]):
warnings.warn("Case value '{:b}' is wider than test (which has width {}); " warnings.warn("Case value '{:b}' is wider than test (which has width {}); "
"comparison will never be true" "comparison will never be true"
.format(value, len(switch_data["test"])), SyntaxWarning, stacklevel=3) .format(value, len(switch_data["test"])),
omit_case = True SyntaxWarning, stacklevel=3)
continue
new_values = (*new_values, value)
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, []
self._ctrl_context = None self._ctrl_context = None
yield yield
self._flush_ctrl() self._flush_ctrl()
if not omit_case: # If none of the provided cases can possibly be true, omit this branch completely.
switch_data["cases"][value] = self._statements # This needs to be differentiated from no cases being provided in the first place,
# which means the branch will always match.
if not (values and not new_values):
switch_data["cases"][new_values] = self._statements
finally: finally:
self._ctrl_context = "Switch" self._ctrl_context = "Switch"
self._statements = _outer_case self._statements = _outer_case

View file

@ -297,7 +297,7 @@ class DSLTestCase(FHDLTestCase):
( (
(switch (sig w1) (switch (sig w1)
(case 0011 (eq (sig c1) (const 1'd1))) (case 0011 (eq (sig c1) (const 1'd1)))
(case ---- (eq (sig c2) (const 1'd1))) (default (eq (sig c2) (const 1'd1)))
) )
) )
""") """)