hdl.{ast,dsl}, back.{pysim,rtlil}: allow multiple case values.
This means that instead of: with m.Case(0b00): <body> with m.Case(0b01): <body> it is legal to write: with m.Case(0b00, 0b01): <body> with no change in semantics, and slightly nicer RTLIL or Verilog output. Fixes #103.
This commit is contained in:
parent
48d4ee4031
commit
32446831b4
|
@ -318,20 +318,22 @@ class _StatementCompiler(StatementVisitor):
|
||||||
def on_Switch(self, stmt):
|
def on_Switch(self, stmt):
|
||||||
test = self.rrhs_compiler(stmt.test)
|
test = self.rrhs_compiler(stmt.test)
|
||||||
cases = []
|
cases = []
|
||||||
for value, stmts in stmt.cases.items():
|
for values, stmts in stmt.cases.items():
|
||||||
if value is None:
|
if values == ():
|
||||||
check = lambda test: True
|
check = lambda test: True
|
||||||
else:
|
else:
|
||||||
if "-" in value:
|
check = lambda test: False
|
||||||
mask = "".join("0" if b == "-" else "1" for b in value)
|
def make_check(mask, value, prev_check):
|
||||||
value = "".join("0" if b == "-" else b for b in value)
|
return lambda test: prev_check(test) or test & mask == value
|
||||||
else:
|
for value in values:
|
||||||
mask = "1" * len(value)
|
if "-" in value:
|
||||||
mask = int(mask, 2)
|
mask = "".join("0" if b == "-" else "1" for b in value)
|
||||||
value = int(value, 2)
|
value = "".join("0" if b == "-" else b for b in value)
|
||||||
def make_check(mask, value):
|
else:
|
||||||
return lambda test: test & mask == value
|
mask = "1" * len(value)
|
||||||
check = make_check(mask, value)
|
mask = int(mask, 2)
|
||||||
|
value = int(value, 2)
|
||||||
|
check = make_check(mask, value, check)
|
||||||
cases.append((check, self.on_statements(stmts)))
|
cases.append((check, self.on_statements(stmts)))
|
||||||
def run(state):
|
def run(state):
|
||||||
test_value = test(state)
|
test_value = test(state)
|
||||||
|
|
|
@ -188,12 +188,12 @@ class _SwitchBuilder:
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
self.rtlil._append("{}end\n", " " * self.indent)
|
self.rtlil._append("{}end\n", " " * self.indent)
|
||||||
|
|
||||||
def case(self, value=None):
|
def case(self, *values):
|
||||||
if value is None:
|
if values == ():
|
||||||
self.rtlil._append("{}case\n", " " * (self.indent + 1))
|
self.rtlil._append("{}case\n", " " * (self.indent + 1))
|
||||||
else:
|
else:
|
||||||
self.rtlil._append("{}case {}'{}\n", " " * (self.indent + 1),
|
self.rtlil._append("{}case {}\n", " " * (self.indent + 1),
|
||||||
len(value), value)
|
", ".join("{}'{}".format(len(value), value) for value in values))
|
||||||
return _CaseBuilder(self.rtlil, self.indent + 2)
|
return _CaseBuilder(self.rtlil, self.indent + 2)
|
||||||
|
|
||||||
|
|
||||||
|
@ -590,10 +590,10 @@ class _StatementCompiler(xfrm.StatementVisitor):
|
||||||
self._has_rhs = False
|
self._has_rhs = False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def case(self, switch, value):
|
def case(self, switch, values):
|
||||||
try:
|
try:
|
||||||
old_case = self._case
|
old_case = self._case
|
||||||
with switch.case(value) as self._case:
|
with switch.case(*values) as self._case:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
self._case = old_case
|
self._case = old_case
|
||||||
|
@ -645,8 +645,8 @@ class _StatementCompiler(xfrm.StatementVisitor):
|
||||||
test_sigspec = self._test_cache[stmt]
|
test_sigspec = self._test_cache[stmt]
|
||||||
|
|
||||||
with self._case.switch(test_sigspec) as switch:
|
with self._case.switch(test_sigspec) as switch:
|
||||||
for value, stmts in stmt.cases.items():
|
for values, stmts in stmt.cases.items():
|
||||||
with self.case(switch, value):
|
with self.case(switch, values):
|
||||||
self.on_statements(stmts)
|
self.on_statements(stmts)
|
||||||
|
|
||||||
def on_statement(self, stmt):
|
def on_statement(self, stmt):
|
||||||
|
|
|
@ -106,12 +106,12 @@ class Case(ast.Switch):
|
||||||
or choice > key):
|
or choice > key):
|
||||||
key = choice
|
key = choice
|
||||||
elif isinstance(key, str) and key == "default":
|
elif isinstance(key, str) and key == "default":
|
||||||
key = None
|
key = ()
|
||||||
else:
|
else:
|
||||||
key = "{:0{}b}".format(wrap(key).value, len(self.test))
|
key = ("{:0{}b}".format(wrap(key).value, len(self.test)),)
|
||||||
stmts = self.cases[key]
|
stmts = self.cases[key]
|
||||||
del self.cases[key]
|
del self.cases[key]
|
||||||
self.cases[None] = stmts
|
self.cases[()] = stmts
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1019,20 +1019,27 @@ class Switch(Statement):
|
||||||
def __init__(self, test, cases):
|
def __init__(self, test, cases):
|
||||||
self.test = Value.wrap(test)
|
self.test = Value.wrap(test)
|
||||||
self.cases = OrderedDict()
|
self.cases = OrderedDict()
|
||||||
for key, stmts in cases.items():
|
for keys, stmts in cases.items():
|
||||||
if isinstance(key, (bool, int)):
|
# Map: None -> (); key -> (key,); (key...) -> (key...)
|
||||||
key = "{:0{}b}".format(key, len(self.test))
|
if keys is None:
|
||||||
elif isinstance(key, str):
|
keys = ()
|
||||||
pass
|
if not isinstance(keys, tuple):
|
||||||
elif key is None:
|
keys = (keys,)
|
||||||
pass
|
# Map: 2 -> "0010"; "0010" -> "0010"
|
||||||
else:
|
new_keys = ()
|
||||||
raise TypeError("Object '{!r}' cannot be used as a switch key"
|
for key in keys:
|
||||||
.format(key))
|
if isinstance(key, (bool, int)):
|
||||||
assert key is None or len(key) == len(self.test)
|
key = "{:0{}b}".format(key, len(self.test))
|
||||||
|
elif isinstance(key, str):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise TypeError("Object '{!r}' cannot be used as a switch key"
|
||||||
|
.format(key))
|
||||||
|
assert len(key) == len(self.test)
|
||||||
|
new_keys = (*new_keys, key)
|
||||||
if not isinstance(stmts, Iterable):
|
if not isinstance(stmts, Iterable):
|
||||||
stmts = [stmts]
|
stmts = [stmts]
|
||||||
self.cases[key] = Statement.wrap(stmts)
|
self.cases[new_keys] = Statement.wrap(stmts)
|
||||||
|
|
||||||
def _lhs_signals(self):
|
def _lhs_signals(self):
|
||||||
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
|
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
|
||||||
|
@ -1045,11 +1052,16 @@ class Switch(Statement):
|
||||||
return self.test._rhs_signals() | signals
|
return self.test._rhs_signals() | signals
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
cases = ["(default {})".format(" ".join(map(repr, stmts)))
|
def case_repr(keys, stmts):
|
||||||
if key is None else
|
stmts_repr = " ".join(map(repr, stmts))
|
||||||
"(case {} {})".format(key, " ".join(map(repr, stmts)))
|
if keys == ():
|
||||||
for key, stmts in self.cases.items()]
|
return "(default {})".format(stmts_repr)
|
||||||
return "(switch {!r} {})".format(self.test, " ".join(cases))
|
elif len(keys) == 1:
|
||||||
|
return "(case {} {})".format(keys[0], stmts_repr)
|
||||||
|
else:
|
||||||
|
return "(case ({}) {})".format(" ".join(keys), stmts_repr)
|
||||||
|
case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()]
|
||||||
|
return "(switch {!r} {})".format(self.test, " ".join(case_reprs))
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
|
|
|
@ -214,27 +214,31 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
||||||
self._pop_ctrl()
|
self._pop_ctrl()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def Case(self, value=None):
|
def Case(self, *values):
|
||||||
self._check_context("Case", context="Switch")
|
self._check_context("Case", context="Switch")
|
||||||
switch_data = self._get_ctrl("Switch")
|
switch_data = self._get_ctrl("Switch")
|
||||||
if value is None:
|
new_values = ()
|
||||||
value = "-" * len(switch_data["test"])
|
for value in values:
|
||||||
if isinstance(value, str) and len(value) != len(switch_data["test"]):
|
if isinstance(value, str) and len(value) != len(switch_data["test"]):
|
||||||
raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
|
raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
|
||||||
.format(value, len(switch_data["test"])))
|
.format(value, len(switch_data["test"])))
|
||||||
omit_case = False
|
if isinstance(value, int) and bits_for(value) > len(switch_data["test"]):
|
||||||
if isinstance(value, int) and bits_for(value) > len(switch_data["test"]):
|
warnings.warn("Case value '{:b}' is wider than test (which has width {}); "
|
||||||
warnings.warn("Case value '{:b}' is wider than test (which has width {}); "
|
"comparison will never be true"
|
||||||
"comparison will never be true"
|
.format(value, len(switch_data["test"])),
|
||||||
.format(value, len(switch_data["test"])), SyntaxWarning, stacklevel=3)
|
SyntaxWarning, stacklevel=3)
|
||||||
omit_case = True
|
continue
|
||||||
|
new_values = (*new_values, value)
|
||||||
try:
|
try:
|
||||||
_outer_case, self._statements = self._statements, []
|
_outer_case, self._statements = self._statements, []
|
||||||
self._ctrl_context = None
|
self._ctrl_context = None
|
||||||
yield
|
yield
|
||||||
self._flush_ctrl()
|
self._flush_ctrl()
|
||||||
if not omit_case:
|
# If none of the provided cases can possibly be true, omit this branch completely.
|
||||||
switch_data["cases"][value] = self._statements
|
# This needs to be differentiated from no cases being provided in the first place,
|
||||||
|
# which means the branch will always match.
|
||||||
|
if not (values and not new_values):
|
||||||
|
switch_data["cases"][new_values] = self._statements
|
||||||
finally:
|
finally:
|
||||||
self._ctrl_context = "Switch"
|
self._ctrl_context = "Switch"
|
||||||
self._statements = _outer_case
|
self._statements = _outer_case
|
||||||
|
|
|
@ -297,7 +297,7 @@ class DSLTestCase(FHDLTestCase):
|
||||||
(
|
(
|
||||||
(switch (sig w1)
|
(switch (sig w1)
|
||||||
(case 0011 (eq (sig c1) (const 1'd1)))
|
(case 0011 (eq (sig c1) (const 1'd1)))
|
||||||
(case ---- (eq (sig c2) (const 1'd1)))
|
(default (eq (sig c2) (const 1'd1)))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
Loading…
Reference in a new issue