hdl._ast: add SwitchValue, reimplement ArrayProxy with it.

This commit is contained in:
Wanda 2024-04-03 09:09:18 +02:00 committed by Catherine
parent 2eb62a8b49
commit 2cf9bbf306
7 changed files with 382 additions and 140 deletions

View file

@ -18,7 +18,7 @@ from .._unused import *
__all__ = [
"SyntaxError", "SyntaxWarning",
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue",
"Array", "ArrayProxy",
"Signal", "ClockSignal", "ResetSignal",
"ValueCastable", "ValueLike",
@ -1892,6 +1892,60 @@ class Concat(Value):
return "(cat {})".format(" ".join(map(repr, self.parts)))
@final
class SwitchValue(Value):
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
if src_loc is None:
super().__init__(src_loc_at=src_loc_at)
else:
self.src_loc = src_loc
self._test = Value.cast(test)
new_cases = []
for patterns, value in cases:
if patterns is not None:
if not isinstance(patterns, tuple):
patterns = (patterns,)
new_patterns = ()
key_mask = (1 << len(self.test)) - 1
for key in _normalize_patterns(patterns, self._test.shape()):
if isinstance(key, int):
key = to_binary(key & key_mask, len(self.test))
new_patterns = (*new_patterns, key)
else:
new_patterns = None
new_cases.append((new_patterns, Value.cast(value)))
self._cases = tuple(new_cases)
@property
def test(self):
return self._test
@property
def cases(self):
return self._cases
def shape(self):
return Shape._unify(value.shape() for _patterns, value in self._cases)
def _lhs_signals(self):
return union((value._lhs_signals() for _patterns, value in self.cases), start=SignalSet())
def _rhs_signals(self):
signals = union((value._rhs_signals() for _patterns, value in self.cases), start=SignalSet())
return self.test._rhs_signals() | signals
def __repr__(self):
def case_repr(patterns, value):
if patterns is None:
return f"(default {value!r})"
elif len(patterns) == 1:
return f"(case {patterns[0]} {value!r})"
else:
return "(case ({}) {!r})".format(" ".join(patterns), value)
case_reprs = (case_repr(patterns, value) for patterns, value in self.cases)
return "(switch-value {!r} {})".format(self.test, " ".join(case_reprs))
class _SignalMeta(ABCMeta):
def __call__(cls, shape=None, src_loc_at=0, **kwargs):
signal = super().__call__(shape, **kwargs, src_loc_at=src_loc_at + 1)
@ -2356,10 +2410,17 @@ class Array(MutableSequence):
", ".join(map(repr, self._inner)))
def _proxy_value(name):
@functools.wraps(getattr(Value, name))
def inner(self, *args, **kwargs):
return getattr(Value.cast(self), name)(*args, **kwargs)
return inner
@final
class ArrayProxy(Value):
class ArrayProxy(ValueCastable):
def __init__(self, elems, index, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at)
self.src_loc = tracer.get_src_loc(1 + src_loc_at)
self._elems = elems
self._index = Value.cast(index)
@ -2385,19 +2446,73 @@ class ArrayProxy(Value):
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
return Shape._unify(elem.shape() for elem in self._iter_as_values())
def _lhs_signals(self):
signals = union((elem._lhs_signals() for elem in self._iter_as_values()),
start=SignalSet())
return signals
def as_value(self):
return SwitchValue(
self._index,
(
(index, value)
for index, value in enumerate(self._elems)
if index in range(1 << len(self._index))
),
src_loc=self.src_loc,
)
def _rhs_signals(self):
signals = union((elem._rhs_signals() for elem in self._iter_as_values()),
start=SignalSet())
return self.index._rhs_signals() | signals
def eq(self, value, *, src_loc_at=0):
return self.as_value().eq(value, src_loc_at=1 + src_loc_at)
def __repr__(self):
return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index)
as_signed = _proxy_value("as_signed")
as_unsigned = _proxy_value("as_unsigned")
__len__ = _proxy_value("__len__")
__bool__ = _proxy_value("__bool__")
bool = _proxy_value("bool")
__pos__ = _proxy_value("__pos__")
__neg__ = _proxy_value("__neg__")
__add__ = _proxy_value("__add__")
__radd__ = _proxy_value("__radd__")
__sub__ = _proxy_value("__sub__")
__rsub__ = _proxy_value("__rsub__")
__mul__ = _proxy_value("__mul__")
__rmul__ = _proxy_value("__rmul__")
__floordiv__ = _proxy_value("__floordiv__")
__rfloordiv__ = _proxy_value("__rfloordiv__")
__mod__ = _proxy_value("__mod__")
__rmod__ = _proxy_value("__rmod__")
__eq__ = _proxy_value("__eq__")
__ne__ = _proxy_value("__ne__")
__lt__ = _proxy_value("__lt__")
__le__ = _proxy_value("__le__")
__gt__ = _proxy_value("__gt__")
__ge__ = _proxy_value("__ge__")
__abs__ = _proxy_value("__abs__")
__invert__ = _proxy_value("__invert__")
__and__ = _proxy_value("__and__")
__rand__ = _proxy_value("__rand__")
__or__ = _proxy_value("__or__")
__ror__ = _proxy_value("__ror__")
__xor__ = _proxy_value("__xor__")
__rxor__ = _proxy_value("__rxor__")
any = _proxy_value("any")
all = _proxy_value("all")
xor = _proxy_value("xor")
implies = _proxy_value("implies")
__lshift__ = _proxy_value("__lshift__")
__rlshift__ = _proxy_value("__rlshift__")
__rshift__ = _proxy_value("__rshift__")
__rrshift__ = _proxy_value("__rrshift__")
shift_left = _proxy_value("shift_left")
shift_right = _proxy_value("shift_right")
rotate_left = _proxy_value("rotate_left")
rotate_right = _proxy_value("rotate_right")
__contains__ = _proxy_value("__contains__")
bit_select = _proxy_value("bit_select")
word_select = _proxy_value("word_select")
replicate = _proxy_value("replicate")
matches = _proxy_value("matches")
__format__ = _proxy_value("__format__")
@final
class Initial(Value):
@ -2772,7 +2887,7 @@ class Switch(Statement):
self.src_loc = src_loc
self._test = Value.cast(test)
self._cases = []
new_cases = []
for patterns, stmts, case_src_loc in cases:
if patterns is not None:
# Map: key -> (key,); (key...) -> (key...)
@ -2787,10 +2902,8 @@ class Switch(Statement):
new_patterns = (*new_patterns, key)
else:
new_patterns = None
if not isinstance(stmts, Iterable):
stmts = [stmts]
self._cases.append((new_patterns, Statement.cast(stmts), case_src_loc))
self._cases = tuple(self._cases)
new_cases.append((new_patterns, Statement.cast(stmts), case_src_loc))
self._cases = tuple(new_cases)
@property
def test(self):
@ -2816,7 +2929,7 @@ class Switch(Statement):
return f"(case {patterns[0]} {stmts_repr})"
else:
return "(case ({}) {})".format(" ".join(patterns), stmts_repr)
case_reprs = [case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases]
case_reprs = (case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases)
return "(switch {!r} {})".format(self.test, " ".join(case_reprs))

View file

@ -885,30 +885,39 @@ class NetlistEmitter:
stride=value.stride, offset=offset, src_loc=value.src_loc)
result = self.netlist.add_value_cell(value.width, cell)
signed = False
elif isinstance(value, _ast.ArrayProxy):
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
elif isinstance(value, _ast.SwitchValue):
test, _signed = self.emit_rhs(module_idx, value.test)
conds = []
elems = []
for patterns, elem, in value.cases:
if patterns is not None:
if not patterns:
# Hack: empty pattern set cannot be supported by RTLIL.
continue
for pattern in patterns:
assert len(pattern) == len(test)
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
src_loc=value.src_loc)
net, = self.netlist.add_value_cell(1, cell)
conds.append(net)
else:
conds.append(_nir.Net.from_const(1))
elems.append(self.emit_rhs(module_idx, elem))
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
inputs=_nir.Value(conds),
src_loc=value.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
shape = _ast.Shape._unify(
_ast.Shape(len(value), signed)
for value, signed in elems
)
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
index, _signed = self.emit_rhs(module_idx, value.index)
conds = []
for case_index in range(len(elems)):
cell = _nir.Matches(module_idx, value=index,
patterns=(to_binary(case_index, len(index)),),
src_loc=value.src_loc)
subcond, = self.netlist.add_value_cell(1, cell)
conds.append(subcond)
conds = _nir.Value(conds)
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), inputs=conds, src_loc=value.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
assignments = [
_nir.Assignment(cond=cond, start=0, value=elem, src_loc=value.src_loc)
for cond, elem in zip(conds, elems)
_nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc)
for subcond, elem in zip(conds, elems)
]
cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments,
src_loc=value.src_loc)
cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width),
assignments=assignments, src_loc=value.src_loc)
result = self.netlist.add_value_cell(shape.width, cell)
signed = shape.signed
elif isinstance(value, _ast.Concat):
@ -1017,19 +1026,29 @@ class NetlistEmitter:
else:
subrhs = rhs
self.emit_assign(module_idx, cd, lhs.value, start, subrhs, subcond, src_loc=src_loc)
elif isinstance(lhs, _ast.ArrayProxy):
index, _signed = self.emit_rhs(module_idx, lhs.index)
elif isinstance(lhs, _ast.SwitchValue):
test, _signed = self.emit_rhs(module_idx, lhs.test)
conds = []
for case_index in range(len(lhs.elems)):
cell = _nir.Matches(module_idx, value=index,
patterns=(to_binary(case_index, len(index)),),
src_loc=lhs.src_loc)
subcond, = self.netlist.add_value_cell(1, cell)
conds.append(subcond)
elems = []
for patterns, elem in lhs.cases:
if patterns is not None:
if not patterns:
# Hack: empty pattern set cannot be supported by RTLIL.
continue
for pattern in patterns:
assert len(pattern) == len(test)
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
src_loc=lhs.src_loc)
net, = self.netlist.add_value_cell(1, cell)
conds.append(net)
else:
conds.append(_nir.Net.from_const(1))
elems.append(elem)
conds = _nir.Value(conds)
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc)
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
inputs=conds, src_loc=lhs.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
for subcond, val in zip(conds, lhs.elems):
for subcond, val in zip(conds, elems):
self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc)
elif isinstance(lhs, _ast.Operator):
assert lhs.operator in ('u', 's')

View file

@ -58,7 +58,7 @@ class ValueVisitor(metaclass=ABCMeta):
pass # :nocov:
@abstractmethod
def on_ArrayProxy(self, value):
def on_SwitchValue(self, value):
pass # :nocov:
@abstractmethod
@ -90,8 +90,8 @@ class ValueVisitor(metaclass=ABCMeta):
new_value = self.on_Part(value)
elif type(value) is Concat:
new_value = self.on_Concat(value)
elif type(value) is ArrayProxy:
new_value = self.on_ArrayProxy(value)
elif type(value) is SwitchValue:
new_value = self.on_SwitchValue(value)
elif type(value) is Initial:
new_value = self.on_Initial(value)
else:
@ -133,9 +133,8 @@ class ValueTransformer(ValueVisitor):
def on_Concat(self, value):
return Concat(self.on_value(o) for o in value.parts)
def on_ArrayProxy(self, value):
return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
self.on_value(value.index))
def on_SwitchValue(self, value):
return SwitchValue(self.on_value(value.test), [(patterns, self.on_value(val)) for patterns, val in value.cases])
def on_Initial(self, value):
return value
@ -399,10 +398,10 @@ class DomainCollector(ValueVisitor, StatementVisitor):
for o in value.parts:
self.on_value(o)
def on_ArrayProxy(self, value):
for elem in value._iter_as_values():
self.on_value(elem)
self.on_value(value.index)
def on_SwitchValue(self, value):
self.on_value(value.test)
for patterns, val in value.cases:
self.on_value(val)
def on_Initial(self, value):
pass

View file

@ -67,6 +67,53 @@ class _Compiler:
self.state = state
self.emitter = emitter
def _emit_switch(self, test, cases, case_handler):
if not cases:
return
use_match = _USE_PATTERN_MATCHING
for patterns, *_ in cases:
if patterns is None:
continue
for pattern in patterns:
if "-" in pattern:
use_match = False
if use_match:
self.emitter.append(f"match {test}:")
with self.emitter.indent():
for case in cases:
patterns = case[0]
if patterns is None:
self.emitter.append(f"case _:")
elif not patterns:
self.emitter.append(f"case _ if False:")
else:
self.emitter.append(f"case {' | '.join(f'0b0{pattern}' for pattern in patterns)}:")
with self.emitter.indent():
case_handler(*case)
else:
for index, case in enumerate(cases):
patterns = case[0]
gen_checks = []
if patterns is None:
gen_checks.append(f"True")
elif not patterns:
gen_checks.append(f"False")
else:
for pattern in patterns:
if "-" in pattern:
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
value = int("".join("0" if b == "-" else b for b in pattern), 2)
gen_checks.append(f"{value} == ({mask} & {test})")
else:
value = int(pattern or "0", 2)
gen_checks.append(f"{value} == {test}")
if index == 0:
self.emitter.append(f"if {' or '.join(gen_checks)}:")
else:
self.emitter.append(f"elif {' or '.join(gen_checks)}:")
with self.emitter.indent():
case_handler(*case)
class _ValueCompiler(ValueVisitor, _Compiler):
helpers = {
@ -223,36 +270,13 @@ class _RHSValueCompiler(_ValueCompiler):
return f"({' | '.join(gen_parts)})"
return f"0"
def on_ArrayProxy(self, value):
index_mask = (1 << len(value.index)) - 1
gen_index = self.emitter.def_var("rhs_index", f"{index_mask:#x} & {self(value.index)}")
gen_value = self.emitter.gen_var("rhs_proxy")
if value.elems:
if _USE_PATTERN_MATCHING:
self.emitter.append(f"match {gen_index}:")
with self.emitter.indent():
for index, elem in enumerate(value.elems):
self.emitter.append(f"case {index}:")
with self.emitter.indent():
self.emitter.append(f"{gen_value} = {self(elem)}")
self.emitter.append("case _:")
with self.emitter.indent():
self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
else:
for index, elem in enumerate(value.elems):
if index == 0:
self.emitter.append(f"if {index} == {gen_index}:")
else:
self.emitter.append(f"elif {index} == {gen_index}:")
with self.emitter.indent():
self.emitter.append(f"{gen_value} = {self(elem)}")
self.emitter.append(f"else:")
with self.emitter.indent():
self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
return gen_value
else:
return f"0"
def on_SwitchValue(self, value):
gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self(value.test)}")
gen_value = self.emitter.def_var("rhs_switch", "0")
def case_handler(patterns, elem):
self.emitter.append(f"{gen_value} = {self(elem)}")
self._emit_switch(gen_test, value.cases, case_handler)
return gen_value
@classmethod
def compile(cls, state, value, *, mode):
@ -323,34 +347,12 @@ class _LHSValueCompiler(_ValueCompiler):
offset += len(part)
return gen
def on_ArrayProxy(self, value):
def on_SwitchValue(self, value):
def gen(arg):
index_mask = (1 << len(value.index)) - 1
gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask:#x}")
if value.elems:
if _USE_PATTERN_MATCHING:
self.emitter.append(f"match {gen_index}:")
with self.emitter.indent():
for index, elem in enumerate(value.elems):
self.emitter.append(f"case {index}:")
with self.emitter.indent():
self(elem)(arg)
self.emitter.append("case _:")
with self.emitter.indent():
self(value.elems[-1])(arg)
else:
for index, elem in enumerate(value.elems):
if index == 0:
self.emitter.append(f"if {index} == {gen_index}:")
else:
self.emitter.append(f"elif {index} == {gen_index}:")
with self.emitter.indent():
self(elem)(arg)
self.emitter.append(f"else:")
with self.emitter.indent():
self(value.elems[-1])(arg)
else:
self.emitter.append(f"pass")
gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self.rrhs(value.test)}")
def case_handler(patterns, elem):
self(elem)(arg)
self._emit_switch(gen_test, value.cases, case_handler)
return gen
@ -396,27 +398,9 @@ class _StatementCompiler(StatementVisitor, _Compiler):
def on_Switch(self, stmt):
gen_test_value = self.rhs(stmt.test) # check for oversized value before generating mask
gen_test = self.emitter.def_var("test", f"{(1 << len(stmt.test)) - 1:#x} & {gen_test_value}")
for index, (patterns, stmts, _src_loc) in enumerate(stmt.cases):
gen_checks = []
if patterns is None:
gen_checks.append(f"True")
elif not patterns:
gen_checks.append(f"False")
else:
for pattern in patterns:
if "-" in pattern:
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
value = int("".join("0" if b == "-" else b for b in pattern), 2)
gen_checks.append(f"{value} == ({mask} & {gen_test})")
else:
value = int(pattern or "0", 2)
gen_checks.append(f"{value} == {gen_test}")
if index == 0:
self.emitter.append(f"if {' or '.join(gen_checks)}:")
else:
self.emitter.append(f"elif {' or '.join(gen_checks)}:")
with self.emitter.indent():
self(stmts)
def case_handler(pattern, stmt, src_loc):
self(stmt)
self._emit_switch(gen_test, stmt.cases, case_handler)
def emit_format(self, format):
format_string = []

View file

@ -1144,6 +1144,7 @@ class ArrayProxyTestCase(FHDLTestCase):
s = Signal(range(3))
v = a[s]
self.assertEqual(repr(v), "(proxy (array [1, 2, 3]) (sig s))")
self.assertEqual(repr(v.as_value()), "(switch-value (sig s) (case 00 (const 1'd1)) (case 01 (const 2'd2)) (case 10 (const 2'd3)))")
class SignalTestCase(FHDLTestCase):

View file

@ -1623,6 +1623,69 @@ class AssignTestCase(FHDLTestCase):
)
""")
def test_switchvalue(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
s4 = Signal(8)
s5 = Signal(8)
s6 = Signal(8)
s7 = Signal(8)
f = Fragment()
f.add_statements("comb", [
SwitchValue(s5[:4], [
(1, s1),
((2, 3), s2),
((), s3),
('11--', s4),
]).eq(s6),
SwitchValue(s5[4:], [
(4, s1),
(5, s2),
(6, s3),
(None, s4),
]).eq(s7),
])
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
f.add_driver(s3, "comb")
f.add_driver(s4, "comb")
nl = build_netlist(f, ports=[s1, s2, s3, s4, s5, s6, s7])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's5' 0.2:10)
(input 's6' 0.10:18)
(input 's7' 0.18:26)
(output 's1' 9.0:8)
(output 's2' 10.0:8)
(output 's3' 12.0:8)
(output 's4' 11.0:8)
)
(cell 0 0 (top
(input 's5' 2:10)
(input 's6' 10:18)
(input 's7' 18:26)
(output 's1' 9.0:8)
(output 's2' 10.0:8)
(output 's3' 12.0:8)
(output 's4' 11.0:8)
))
(cell 1 0 (matches 0.2:6 0001))
(cell 2 0 (matches 0.2:6 0010 0011))
(cell 3 0 (matches 0.2:6 11--))
(cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0)))
(cell 5 0 (matches 0.6:10 0100))
(cell 6 0 (matches 0.6:10 0101))
(cell 7 0 (matches 0.6:10 0110))
(cell 8 0 (priority_match 1 (cat 5.0 6.0 7.0 1'd1)))
(cell 9 0 (assignment_list 8'd0 (4.0 0:8 0.10:18) (8.0 0:8 0.18:26)))
(cell 10 0 (assignment_list 8'd0 (4.1 0:8 0.10:18) (8.1 0:8 0.18:26)))
(cell 11 0 (assignment_list 8'd0 (4.2 0:8 0.10:18) (8.3 0:8 0.18:26)))
(cell 12 0 (assignment_list 8'd0 (8.2 0:8 0.18:26)))
)
""")
def test_sliced_slice(self):
s1 = Signal(12)
s2 = Signal(4)
@ -2953,7 +3016,7 @@ class RhsTestCase(FHDLTestCase):
(cell 2 0 (matches 0.50:54 0001))
(cell 3 0 (matches 0.50:54 0010))
(cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0)))
(cell 5 0 (assignment_list 0.2:10
(cell 5 0 (assignment_list 8'd0
(4.0 0:8 0.2:10)
(4.1 0:8 0.10:18)
(4.2 0:8 0.18:26)
@ -2962,7 +3025,7 @@ class RhsTestCase(FHDLTestCase):
(cell 7 0 (matches 0.50:54 0001))
(cell 8 0 (matches 0.50:54 0010))
(cell 9 0 (priority_match 1 (cat 6.0 7.0 8.0)))
(cell 10 0 (assignment_list (cat 0.2:10 1'd0)
(cell 10 0 (assignment_list 9'd0
(9.0 0:9 (cat 0.2:10 1'd0))
(9.1 0:9 (cat 0.10:18 1'd0))
(9.2 0:9 (cat 0.42:50 0.49))
@ -2971,7 +3034,7 @@ class RhsTestCase(FHDLTestCase):
(cell 12 0 (matches 0.50:54 0001))
(cell 13 0 (matches 0.50:54 0010))
(cell 14 0 (priority_match 1 (cat 11.0 12.0 13.0)))
(cell 15 0 (assignment_list 0.26:34
(cell 15 0 (assignment_list 8'd0
(14.0 0:8 0.26:34)
(14.1 0:8 0.34:42)
(14.2 0:8 0.42:50)
@ -2980,7 +3043,7 @@ class RhsTestCase(FHDLTestCase):
(cell 17 0 (matches 0.50:54 0001))
(cell 18 0 (matches 0.50:54 0010))
(cell 19 0 (priority_match 1 (cat 16.0 17.0 18.0)))
(cell 20 0 (assignment_list 0.26:34
(cell 20 0 (assignment_list 8'd0
(19.0 0:8 0.26:34)
(19.1 0:8 0.34:42)
(19.2 0:8 (cat 0.50:54 4'd0))
@ -2988,6 +3051,67 @@ class RhsTestCase(FHDLTestCase):
)
""")
def test_switchvalue(self):
i8ua = Signal(8)
i8ub = Signal(8)
i8uc = Signal(8)
i8ud = Signal(8)
i4 = Signal(4)
o1 = Signal(10)
o2 = Signal(10)
m = Module()
m.d.comb += o1.eq(SwitchValue(i4, [
(1, i8ua),
((2, 3), i8ub),
('11--', i8uc),
]))
m.d.comb += o2.eq(SwitchValue(i4, [
((4, 5), i8ua),
((), i8ub),
((6, 7), i8uc),
(None, i8ud),
]))
nl = build_netlist(Fragment.get(m, None), [i8ua, i8ub, i8uc, i8ud, i4, o1, o2])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 'i8ua' 0.2:10)
(input 'i8ub' 0.10:18)
(input 'i8uc' 0.18:26)
(input 'i8ud' 0.26:34)
(input 'i4' 0.34:38)
(output 'o1' (cat 5.0:8 2'd0))
(output 'o2' (cat 9.0:8 2'd0))
)
(cell 0 0 (top
(input 'i8ua' 2:10)
(input 'i8ub' 10:18)
(input 'i8uc' 18:26)
(input 'i8ud' 26:34)
(input 'i4' 34:38)
(output 'o1' (cat 5.0:8 2'd0))
(output 'o2' (cat 9.0:8 2'd0))
))
(cell 1 0 (matches 0.34:38 0001))
(cell 2 0 (matches 0.34:38 0010 0011))
(cell 3 0 (matches 0.34:38 11--))
(cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0)))
(cell 5 0 (assignment_list 8'd0
(4.0 0:8 0.2:10)
(4.1 0:8 0.10:18)
(4.2 0:8 0.18:26)
))
(cell 6 0 (matches 0.34:38 0100 0101))
(cell 7 0 (matches 0.34:38 0110 0111))
(cell 8 0 (priority_match 1 (cat 6.0 7.0 1'd1)))
(cell 9 0 (assignment_list 8'd0
(8.0 0:8 0.2:10)
(8.1 0:8 0.18:26)
(8.2 0:8 0.26:34)
))
)
""")
def test_anyvalue(self):
o1 = Signal(12)
o2 = Signal(12)

View file

@ -314,8 +314,8 @@ class SimulatorUnitTestCase(FHDLTestCase):
def test_array_oob(self):
array = Array([1, 4, 10])
stmt = lambda y, a: y.eq(array[a])
self.assertStatement(stmt, [C(3)], C(10))
self.assertStatement(stmt, [C(4)], C(10))
self.assertStatement(stmt, [C(3)], C(0))
self.assertStatement(stmt, [C(4)], C(0))
def test_array_lhs(self):
l = Signal(3, init=1)
@ -333,8 +333,8 @@ class SimulatorUnitTestCase(FHDLTestCase):
n = Signal(3)
array = Array([l, m, n])
stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))]
self.assertStatement(stmt, [C(3), C(0b001)], C(0b001000000))
self.assertStatement(stmt, [C(4), C(0b010)], C(0b010000000))
self.assertStatement(stmt, [C(3), C(0b001)], C(0))
self.assertStatement(stmt, [C(4), C(0b010)], C(0))
def test_array_index(self):
array = Array(Array(x * y for y in range(10)) for x in range(10))
@ -513,6 +513,8 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
with self.m.Switch(self.s):
with self.m.Case(0):
self.m.d.sync += self.o.eq(self.a + self.b)
with self.m.Case():
self.m.d.sync += self.o.eq(self.a * self.b)
with self.m.Case(1):
self.m.d.sync += self.o.eq(self.a - self.b)
with self.m.Default():