diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 0b7f15e..57f102e 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -18,7 +18,7 @@ from .._unused import * __all__ = [ "SyntaxError", "SyntaxWarning", "Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike", - "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", + "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue", "Array", "ArrayProxy", "Signal", "ClockSignal", "ResetSignal", "ValueCastable", "ValueLike", @@ -1892,6 +1892,60 @@ class Concat(Value): return "(cat {})".format(" ".join(map(repr, self.parts))) +@final +class SwitchValue(Value): + def __init__(self, test, cases, *, src_loc=None, src_loc_at=0): + if src_loc is None: + super().__init__(src_loc_at=src_loc_at) + else: + self.src_loc = src_loc + self._test = Value.cast(test) + new_cases = [] + for patterns, value in cases: + if patterns is not None: + if not isinstance(patterns, tuple): + patterns = (patterns,) + new_patterns = () + key_mask = (1 << len(self.test)) - 1 + for key in _normalize_patterns(patterns, self._test.shape()): + if isinstance(key, int): + key = to_binary(key & key_mask, len(self.test)) + new_patterns = (*new_patterns, key) + else: + new_patterns = None + new_cases.append((new_patterns, Value.cast(value))) + self._cases = tuple(new_cases) + + @property + def test(self): + return self._test + + @property + def cases(self): + return self._cases + + def shape(self): + return Shape._unify(value.shape() for _patterns, value in self._cases) + + def _lhs_signals(self): + return union((value._lhs_signals() for _patterns, value in self.cases), start=SignalSet()) + + def _rhs_signals(self): + signals = union((value._rhs_signals() for _patterns, value in self.cases), start=SignalSet()) + return self.test._rhs_signals() | signals + + def __repr__(self): + def case_repr(patterns, value): + if patterns is None: + return f"(default {value!r})" + elif len(patterns) == 1: + return f"(case {patterns[0]} {value!r})" + else: + return "(case ({}) {!r})".format(" ".join(patterns), value) + case_reprs = (case_repr(patterns, value) for patterns, value in self.cases) + return "(switch-value {!r} {})".format(self.test, " ".join(case_reprs)) + + class _SignalMeta(ABCMeta): def __call__(cls, shape=None, src_loc_at=0, **kwargs): signal = super().__call__(shape, **kwargs, src_loc_at=src_loc_at + 1) @@ -2356,10 +2410,17 @@ class Array(MutableSequence): ", ".join(map(repr, self._inner))) +def _proxy_value(name): + @functools.wraps(getattr(Value, name)) + def inner(self, *args, **kwargs): + return getattr(Value.cast(self), name)(*args, **kwargs) + return inner + + @final -class ArrayProxy(Value): +class ArrayProxy(ValueCastable): def __init__(self, elems, index, *, src_loc_at=0): - super().__init__(src_loc_at=1 + src_loc_at) + self.src_loc = tracer.get_src_loc(1 + src_loc_at) self._elems = elems self._index = Value.cast(index) @@ -2385,19 +2446,73 @@ class ArrayProxy(Value): # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree. return Shape._unify(elem.shape() for elem in self._iter_as_values()) - def _lhs_signals(self): - signals = union((elem._lhs_signals() for elem in self._iter_as_values()), - start=SignalSet()) - return signals + def as_value(self): + return SwitchValue( + self._index, + ( + (index, value) + for index, value in enumerate(self._elems) + if index in range(1 << len(self._index)) + ), + src_loc=self.src_loc, + ) - def _rhs_signals(self): - signals = union((elem._rhs_signals() for elem in self._iter_as_values()), - start=SignalSet()) - return self.index._rhs_signals() | signals + def eq(self, value, *, src_loc_at=0): + return self.as_value().eq(value, src_loc_at=1 + src_loc_at) def __repr__(self): return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index) + as_signed = _proxy_value("as_signed") + as_unsigned = _proxy_value("as_unsigned") + __len__ = _proxy_value("__len__") + __bool__ = _proxy_value("__bool__") + bool = _proxy_value("bool") + __pos__ = _proxy_value("__pos__") + __neg__ = _proxy_value("__neg__") + __add__ = _proxy_value("__add__") + __radd__ = _proxy_value("__radd__") + __sub__ = _proxy_value("__sub__") + __rsub__ = _proxy_value("__rsub__") + __mul__ = _proxy_value("__mul__") + __rmul__ = _proxy_value("__rmul__") + __floordiv__ = _proxy_value("__floordiv__") + __rfloordiv__ = _proxy_value("__rfloordiv__") + __mod__ = _proxy_value("__mod__") + __rmod__ = _proxy_value("__rmod__") + __eq__ = _proxy_value("__eq__") + __ne__ = _proxy_value("__ne__") + __lt__ = _proxy_value("__lt__") + __le__ = _proxy_value("__le__") + __gt__ = _proxy_value("__gt__") + __ge__ = _proxy_value("__ge__") + __abs__ = _proxy_value("__abs__") + __invert__ = _proxy_value("__invert__") + __and__ = _proxy_value("__and__") + __rand__ = _proxy_value("__rand__") + __or__ = _proxy_value("__or__") + __ror__ = _proxy_value("__ror__") + __xor__ = _proxy_value("__xor__") + __rxor__ = _proxy_value("__rxor__") + any = _proxy_value("any") + all = _proxy_value("all") + xor = _proxy_value("xor") + implies = _proxy_value("implies") + __lshift__ = _proxy_value("__lshift__") + __rlshift__ = _proxy_value("__rlshift__") + __rshift__ = _proxy_value("__rshift__") + __rrshift__ = _proxy_value("__rrshift__") + shift_left = _proxy_value("shift_left") + shift_right = _proxy_value("shift_right") + rotate_left = _proxy_value("rotate_left") + rotate_right = _proxy_value("rotate_right") + __contains__ = _proxy_value("__contains__") + bit_select = _proxy_value("bit_select") + word_select = _proxy_value("word_select") + replicate = _proxy_value("replicate") + matches = _proxy_value("matches") + __format__ = _proxy_value("__format__") + @final class Initial(Value): @@ -2772,7 +2887,7 @@ class Switch(Statement): self.src_loc = src_loc self._test = Value.cast(test) - self._cases = [] + new_cases = [] for patterns, stmts, case_src_loc in cases: if patterns is not None: # Map: key -> (key,); (key...) -> (key...) @@ -2787,10 +2902,8 @@ class Switch(Statement): new_patterns = (*new_patterns, key) else: new_patterns = None - if not isinstance(stmts, Iterable): - stmts = [stmts] - self._cases.append((new_patterns, Statement.cast(stmts), case_src_loc)) - self._cases = tuple(self._cases) + new_cases.append((new_patterns, Statement.cast(stmts), case_src_loc)) + self._cases = tuple(new_cases) @property def test(self): @@ -2816,7 +2929,7 @@ class Switch(Statement): return f"(case {patterns[0]} {stmts_repr})" else: return "(case ({}) {})".format(" ".join(patterns), stmts_repr) - case_reprs = [case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases] + case_reprs = (case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases) return "(switch {!r} {})".format(self.test, " ".join(case_reprs)) diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 717db2e..7f8a998 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -885,30 +885,39 @@ class NetlistEmitter: stride=value.stride, offset=offset, src_loc=value.src_loc) result = self.netlist.add_value_cell(value.width, cell) signed = False - elif isinstance(value, _ast.ArrayProxy): - elems = [self.emit_rhs(module_idx, elem) for elem in value.elems] + elif isinstance(value, _ast.SwitchValue): + test, _signed = self.emit_rhs(module_idx, value.test) + conds = [] + elems = [] + for patterns, elem, in value.cases: + if patterns is not None: + if not patterns: + # Hack: empty pattern set cannot be supported by RTLIL. + continue + for pattern in patterns: + assert len(pattern) == len(test) + cell = _nir.Matches(module_idx, value=test, patterns=patterns, + src_loc=value.src_loc) + net, = self.netlist.add_value_cell(1, cell) + conds.append(net) + else: + conds.append(_nir.Net.from_const(1)) + elems.append(self.emit_rhs(module_idx, elem)) + cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), + inputs=_nir.Value(conds), + src_loc=value.src_loc) + conds = self.netlist.add_value_cell(len(conds), cell) shape = _ast.Shape._unify( _ast.Shape(len(value), signed) for value, signed in elems ) elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems) - index, _signed = self.emit_rhs(module_idx, value.index) - conds = [] - for case_index in range(len(elems)): - cell = _nir.Matches(module_idx, value=index, - patterns=(to_binary(case_index, len(index)),), - src_loc=value.src_loc) - subcond, = self.netlist.add_value_cell(1, cell) - conds.append(subcond) - conds = _nir.Value(conds) - cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), inputs=conds, src_loc=value.src_loc) - conds = self.netlist.add_value_cell(len(conds), cell) assignments = [ - _nir.Assignment(cond=cond, start=0, value=elem, src_loc=value.src_loc) - for cond, elem in zip(conds, elems) + _nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc) + for subcond, elem in zip(conds, elems) ] - cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments, - src_loc=value.src_loc) + cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width), + assignments=assignments, src_loc=value.src_loc) result = self.netlist.add_value_cell(shape.width, cell) signed = shape.signed elif isinstance(value, _ast.Concat): @@ -1017,19 +1026,29 @@ class NetlistEmitter: else: subrhs = rhs self.emit_assign(module_idx, cd, lhs.value, start, subrhs, subcond, src_loc=src_loc) - elif isinstance(lhs, _ast.ArrayProxy): - index, _signed = self.emit_rhs(module_idx, lhs.index) + elif isinstance(lhs, _ast.SwitchValue): + test, _signed = self.emit_rhs(module_idx, lhs.test) conds = [] - for case_index in range(len(lhs.elems)): - cell = _nir.Matches(module_idx, value=index, - patterns=(to_binary(case_index, len(index)),), - src_loc=lhs.src_loc) - subcond, = self.netlist.add_value_cell(1, cell) - conds.append(subcond) + elems = [] + for patterns, elem in lhs.cases: + if patterns is not None: + if not patterns: + # Hack: empty pattern set cannot be supported by RTLIL. + continue + for pattern in patterns: + assert len(pattern) == len(test) + cell = _nir.Matches(module_idx, value=test, patterns=patterns, + src_loc=lhs.src_loc) + net, = self.netlist.add_value_cell(1, cell) + conds.append(net) + else: + conds.append(_nir.Net.from_const(1)) + elems.append(elem) conds = _nir.Value(conds) - cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc) + cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), + inputs=conds, src_loc=lhs.src_loc) conds = self.netlist.add_value_cell(len(conds), cell) - for subcond, val in zip(conds, lhs.elems): + for subcond, val in zip(conds, elems): self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc) elif isinstance(lhs, _ast.Operator): assert lhs.operator in ('u', 's') diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index 01b4fa1..dde48d5 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -58,7 +58,7 @@ class ValueVisitor(metaclass=ABCMeta): pass # :nocov: @abstractmethod - def on_ArrayProxy(self, value): + def on_SwitchValue(self, value): pass # :nocov: @abstractmethod @@ -90,8 +90,8 @@ class ValueVisitor(metaclass=ABCMeta): new_value = self.on_Part(value) elif type(value) is Concat: new_value = self.on_Concat(value) - elif type(value) is ArrayProxy: - new_value = self.on_ArrayProxy(value) + elif type(value) is SwitchValue: + new_value = self.on_SwitchValue(value) elif type(value) is Initial: new_value = self.on_Initial(value) else: @@ -133,9 +133,8 @@ class ValueTransformer(ValueVisitor): def on_Concat(self, value): return Concat(self.on_value(o) for o in value.parts) - def on_ArrayProxy(self, value): - return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()], - self.on_value(value.index)) + def on_SwitchValue(self, value): + return SwitchValue(self.on_value(value.test), [(patterns, self.on_value(val)) for patterns, val in value.cases]) def on_Initial(self, value): return value @@ -399,10 +398,10 @@ class DomainCollector(ValueVisitor, StatementVisitor): for o in value.parts: self.on_value(o) - def on_ArrayProxy(self, value): - for elem in value._iter_as_values(): - self.on_value(elem) - self.on_value(value.index) + def on_SwitchValue(self, value): + self.on_value(value.test) + for patterns, val in value.cases: + self.on_value(val) def on_Initial(self, value): pass diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 3f41f34..14e8b01 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -67,6 +67,53 @@ class _Compiler: self.state = state self.emitter = emitter + def _emit_switch(self, test, cases, case_handler): + if not cases: + return + use_match = _USE_PATTERN_MATCHING + for patterns, *_ in cases: + if patterns is None: + continue + for pattern in patterns: + if "-" in pattern: + use_match = False + if use_match: + self.emitter.append(f"match {test}:") + with self.emitter.indent(): + for case in cases: + patterns = case[0] + if patterns is None: + self.emitter.append(f"case _:") + elif not patterns: + self.emitter.append(f"case _ if False:") + else: + self.emitter.append(f"case {' | '.join(f'0b0{pattern}' for pattern in patterns)}:") + with self.emitter.indent(): + case_handler(*case) + else: + for index, case in enumerate(cases): + patterns = case[0] + gen_checks = [] + if patterns is None: + gen_checks.append(f"True") + elif not patterns: + gen_checks.append(f"False") + else: + for pattern in patterns: + if "-" in pattern: + mask = int("".join("0" if b == "-" else "1" for b in pattern), 2) + value = int("".join("0" if b == "-" else b for b in pattern), 2) + gen_checks.append(f"{value} == ({mask} & {test})") + else: + value = int(pattern or "0", 2) + gen_checks.append(f"{value} == {test}") + if index == 0: + self.emitter.append(f"if {' or '.join(gen_checks)}:") + else: + self.emitter.append(f"elif {' or '.join(gen_checks)}:") + with self.emitter.indent(): + case_handler(*case) + class _ValueCompiler(ValueVisitor, _Compiler): helpers = { @@ -223,36 +270,13 @@ class _RHSValueCompiler(_ValueCompiler): return f"({' | '.join(gen_parts)})" return f"0" - def on_ArrayProxy(self, value): - index_mask = (1 << len(value.index)) - 1 - gen_index = self.emitter.def_var("rhs_index", f"{index_mask:#x} & {self(value.index)}") - gen_value = self.emitter.gen_var("rhs_proxy") - if value.elems: - if _USE_PATTERN_MATCHING: - self.emitter.append(f"match {gen_index}:") - with self.emitter.indent(): - for index, elem in enumerate(value.elems): - self.emitter.append(f"case {index}:") - with self.emitter.indent(): - self.emitter.append(f"{gen_value} = {self(elem)}") - self.emitter.append("case _:") - with self.emitter.indent(): - self.emitter.append(f"{gen_value} = {self(value.elems[-1])}") - else: - for index, elem in enumerate(value.elems): - if index == 0: - self.emitter.append(f"if {index} == {gen_index}:") - else: - self.emitter.append(f"elif {index} == {gen_index}:") - with self.emitter.indent(): - self.emitter.append(f"{gen_value} = {self(elem)}") - self.emitter.append(f"else:") - with self.emitter.indent(): - self.emitter.append(f"{gen_value} = {self(value.elems[-1])}") - - return gen_value - else: - return f"0" + def on_SwitchValue(self, value): + gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self(value.test)}") + gen_value = self.emitter.def_var("rhs_switch", "0") + def case_handler(patterns, elem): + self.emitter.append(f"{gen_value} = {self(elem)}") + self._emit_switch(gen_test, value.cases, case_handler) + return gen_value @classmethod def compile(cls, state, value, *, mode): @@ -323,34 +347,12 @@ class _LHSValueCompiler(_ValueCompiler): offset += len(part) return gen - def on_ArrayProxy(self, value): + def on_SwitchValue(self, value): def gen(arg): - index_mask = (1 << len(value.index)) - 1 - gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask:#x}") - if value.elems: - if _USE_PATTERN_MATCHING: - self.emitter.append(f"match {gen_index}:") - with self.emitter.indent(): - for index, elem in enumerate(value.elems): - self.emitter.append(f"case {index}:") - with self.emitter.indent(): - self(elem)(arg) - self.emitter.append("case _:") - with self.emitter.indent(): - self(value.elems[-1])(arg) - else: - for index, elem in enumerate(value.elems): - if index == 0: - self.emitter.append(f"if {index} == {gen_index}:") - else: - self.emitter.append(f"elif {index} == {gen_index}:") - with self.emitter.indent(): - self(elem)(arg) - self.emitter.append(f"else:") - with self.emitter.indent(): - self(value.elems[-1])(arg) - else: - self.emitter.append(f"pass") + gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self.rrhs(value.test)}") + def case_handler(patterns, elem): + self(elem)(arg) + self._emit_switch(gen_test, value.cases, case_handler) return gen @@ -396,27 +398,9 @@ class _StatementCompiler(StatementVisitor, _Compiler): def on_Switch(self, stmt): gen_test_value = self.rhs(stmt.test) # check for oversized value before generating mask gen_test = self.emitter.def_var("test", f"{(1 << len(stmt.test)) - 1:#x} & {gen_test_value}") - for index, (patterns, stmts, _src_loc) in enumerate(stmt.cases): - gen_checks = [] - if patterns is None: - gen_checks.append(f"True") - elif not patterns: - gen_checks.append(f"False") - else: - for pattern in patterns: - if "-" in pattern: - mask = int("".join("0" if b == "-" else "1" for b in pattern), 2) - value = int("".join("0" if b == "-" else b for b in pattern), 2) - gen_checks.append(f"{value} == ({mask} & {gen_test})") - else: - value = int(pattern or "0", 2) - gen_checks.append(f"{value} == {gen_test}") - if index == 0: - self.emitter.append(f"if {' or '.join(gen_checks)}:") - else: - self.emitter.append(f"elif {' or '.join(gen_checks)}:") - with self.emitter.indent(): - self(stmts) + def case_handler(pattern, stmt, src_loc): + self(stmt) + self._emit_switch(gen_test, stmt.cases, case_handler) def emit_format(self, format): format_string = [] diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 23e8e76..e17268a 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -1144,6 +1144,7 @@ class ArrayProxyTestCase(FHDLTestCase): s = Signal(range(3)) v = a[s] self.assertEqual(repr(v), "(proxy (array [1, 2, 3]) (sig s))") + self.assertEqual(repr(v.as_value()), "(switch-value (sig s) (case 00 (const 1'd1)) (case 01 (const 2'd2)) (case 10 (const 2'd3)))") class SignalTestCase(FHDLTestCase): diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index 07464e3..8161bd2 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -1623,6 +1623,69 @@ class AssignTestCase(FHDLTestCase): ) """) + def test_switchvalue(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + s4 = Signal(8) + s5 = Signal(8) + s6 = Signal(8) + s7 = Signal(8) + f = Fragment() + f.add_statements("comb", [ + SwitchValue(s5[:4], [ + (1, s1), + ((2, 3), s2), + ((), s3), + ('11--', s4), + ]).eq(s6), + SwitchValue(s5[4:], [ + (4, s1), + (5, s2), + (6, s3), + (None, s4), + ]).eq(s7), + ]) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + f.add_driver(s3, "comb") + f.add_driver(s4, "comb") + nl = build_netlist(f, ports=[s1, s2, s3, s4, s5, s6, s7]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's5' 0.2:10) + (input 's6' 0.10:18) + (input 's7' 0.18:26) + (output 's1' 9.0:8) + (output 's2' 10.0:8) + (output 's3' 12.0:8) + (output 's4' 11.0:8) + ) + (cell 0 0 (top + (input 's5' 2:10) + (input 's6' 10:18) + (input 's7' 18:26) + (output 's1' 9.0:8) + (output 's2' 10.0:8) + (output 's3' 12.0:8) + (output 's4' 11.0:8) + )) + (cell 1 0 (matches 0.2:6 0001)) + (cell 2 0 (matches 0.2:6 0010 0011)) + (cell 3 0 (matches 0.2:6 11--)) + (cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0))) + (cell 5 0 (matches 0.6:10 0100)) + (cell 6 0 (matches 0.6:10 0101)) + (cell 7 0 (matches 0.6:10 0110)) + (cell 8 0 (priority_match 1 (cat 5.0 6.0 7.0 1'd1))) + (cell 9 0 (assignment_list 8'd0 (4.0 0:8 0.10:18) (8.0 0:8 0.18:26))) + (cell 10 0 (assignment_list 8'd0 (4.1 0:8 0.10:18) (8.1 0:8 0.18:26))) + (cell 11 0 (assignment_list 8'd0 (4.2 0:8 0.10:18) (8.3 0:8 0.18:26))) + (cell 12 0 (assignment_list 8'd0 (8.2 0:8 0.18:26))) + ) + """) + def test_sliced_slice(self): s1 = Signal(12) s2 = Signal(4) @@ -2953,7 +3016,7 @@ class RhsTestCase(FHDLTestCase): (cell 2 0 (matches 0.50:54 0001)) (cell 3 0 (matches 0.50:54 0010)) (cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0))) - (cell 5 0 (assignment_list 0.2:10 + (cell 5 0 (assignment_list 8'd0 (4.0 0:8 0.2:10) (4.1 0:8 0.10:18) (4.2 0:8 0.18:26) @@ -2962,7 +3025,7 @@ class RhsTestCase(FHDLTestCase): (cell 7 0 (matches 0.50:54 0001)) (cell 8 0 (matches 0.50:54 0010)) (cell 9 0 (priority_match 1 (cat 6.0 7.0 8.0))) - (cell 10 0 (assignment_list (cat 0.2:10 1'd0) + (cell 10 0 (assignment_list 9'd0 (9.0 0:9 (cat 0.2:10 1'd0)) (9.1 0:9 (cat 0.10:18 1'd0)) (9.2 0:9 (cat 0.42:50 0.49)) @@ -2971,7 +3034,7 @@ class RhsTestCase(FHDLTestCase): (cell 12 0 (matches 0.50:54 0001)) (cell 13 0 (matches 0.50:54 0010)) (cell 14 0 (priority_match 1 (cat 11.0 12.0 13.0))) - (cell 15 0 (assignment_list 0.26:34 + (cell 15 0 (assignment_list 8'd0 (14.0 0:8 0.26:34) (14.1 0:8 0.34:42) (14.2 0:8 0.42:50) @@ -2980,7 +3043,7 @@ class RhsTestCase(FHDLTestCase): (cell 17 0 (matches 0.50:54 0001)) (cell 18 0 (matches 0.50:54 0010)) (cell 19 0 (priority_match 1 (cat 16.0 17.0 18.0))) - (cell 20 0 (assignment_list 0.26:34 + (cell 20 0 (assignment_list 8'd0 (19.0 0:8 0.26:34) (19.1 0:8 0.34:42) (19.2 0:8 (cat 0.50:54 4'd0)) @@ -2988,6 +3051,67 @@ class RhsTestCase(FHDLTestCase): ) """) + def test_switchvalue(self): + i8ua = Signal(8) + i8ub = Signal(8) + i8uc = Signal(8) + i8ud = Signal(8) + i4 = Signal(4) + o1 = Signal(10) + o2 = Signal(10) + m = Module() + m.d.comb += o1.eq(SwitchValue(i4, [ + (1, i8ua), + ((2, 3), i8ub), + ('11--', i8uc), + ])) + m.d.comb += o2.eq(SwitchValue(i4, [ + ((4, 5), i8ua), + ((), i8ub), + ((6, 7), i8uc), + (None, i8ud), + ])) + nl = build_netlist(Fragment.get(m, None), [i8ua, i8ub, i8uc, i8ud, i4, o1, o2]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 'i8ua' 0.2:10) + (input 'i8ub' 0.10:18) + (input 'i8uc' 0.18:26) + (input 'i8ud' 0.26:34) + (input 'i4' 0.34:38) + (output 'o1' (cat 5.0:8 2'd0)) + (output 'o2' (cat 9.0:8 2'd0)) + ) + (cell 0 0 (top + (input 'i8ua' 2:10) + (input 'i8ub' 10:18) + (input 'i8uc' 18:26) + (input 'i8ud' 26:34) + (input 'i4' 34:38) + (output 'o1' (cat 5.0:8 2'd0)) + (output 'o2' (cat 9.0:8 2'd0)) + )) + (cell 1 0 (matches 0.34:38 0001)) + (cell 2 0 (matches 0.34:38 0010 0011)) + (cell 3 0 (matches 0.34:38 11--)) + (cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0))) + (cell 5 0 (assignment_list 8'd0 + (4.0 0:8 0.2:10) + (4.1 0:8 0.10:18) + (4.2 0:8 0.18:26) + )) + (cell 6 0 (matches 0.34:38 0100 0101)) + (cell 7 0 (matches 0.34:38 0110 0111)) + (cell 8 0 (priority_match 1 (cat 6.0 7.0 1'd1))) + (cell 9 0 (assignment_list 8'd0 + (8.0 0:8 0.2:10) + (8.1 0:8 0.18:26) + (8.2 0:8 0.26:34) + )) + ) + """) + def test_anyvalue(self): o1 = Signal(12) o2 = Signal(12) diff --git a/tests/test_sim.py b/tests/test_sim.py index 9be4299..bb5d495 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -314,8 +314,8 @@ class SimulatorUnitTestCase(FHDLTestCase): def test_array_oob(self): array = Array([1, 4, 10]) stmt = lambda y, a: y.eq(array[a]) - self.assertStatement(stmt, [C(3)], C(10)) - self.assertStatement(stmt, [C(4)], C(10)) + self.assertStatement(stmt, [C(3)], C(0)) + self.assertStatement(stmt, [C(4)], C(0)) def test_array_lhs(self): l = Signal(3, init=1) @@ -333,8 +333,8 @@ class SimulatorUnitTestCase(FHDLTestCase): n = Signal(3) array = Array([l, m, n]) stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))] - self.assertStatement(stmt, [C(3), C(0b001)], C(0b001000000)) - self.assertStatement(stmt, [C(4), C(0b010)], C(0b010000000)) + self.assertStatement(stmt, [C(3), C(0b001)], C(0)) + self.assertStatement(stmt, [C(4), C(0b010)], C(0)) def test_array_index(self): array = Array(Array(x * y for y in range(10)) for x in range(10)) @@ -513,6 +513,8 @@ class SimulatorIntegrationTestCase(FHDLTestCase): with self.m.Switch(self.s): with self.m.Case(0): self.m.d.sync += self.o.eq(self.a + self.b) + with self.m.Case(): + self.m.d.sync += self.o.eq(self.a * self.b) with self.m.Case(1): self.m.d.sync += self.o.eq(self.a - self.b) with self.m.Default():