diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 57f102e..4bfa39a 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -1681,10 +1681,6 @@ class Operator(Value): if self.operator == ">>": assert not b_shape.signed return Shape(a_shape.width, a_shape.signed) - elif len(op_shapes) == 3: - if self.operator == "m": - s_shape, a_shape, b_shape = op_shapes - return Shape._unify((a_shape, b_shape)) raise NotImplementedError # :nocov: def _lhs_signals(self): @@ -1715,7 +1711,7 @@ def Mux(sel, val1, val0): Value, out Output ``Value``. If ``sel`` is asserted, the Mux returns ``val1``, else ``val0``. """ - return Operator("m", [sel, val1, val0]) + return SwitchValue(sel, ((0, val0), (None, val1)), src_loc_at=1) @final diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index b4c3801..5046232 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -874,18 +874,6 @@ class NetlistEmitter: signed = False else: assert False # :nocov: - elif len(value.operands) == 3: - assert value.operator == 'm' - operand_s, signed_s = self.emit_rhs(module_idx, value.operands[0]) - operand_a, signed_a = self.emit_rhs(module_idx, value.operands[1]) - operand_b, signed_b = self.emit_rhs(module_idx, value.operands[2]) - if len(operand_s) != 1: - operand_s = self.emit_operator(module_idx, 'b', operand_s, - src_loc=value.src_loc) - operand_a, operand_b, signed = \ - self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) - result = self.emit_operator(module_idx, 'm', operand_s, operand_a, operand_b, - src_loc=value.src_loc) else: assert False # :nocov: elif isinstance(value, _ast.Slice): @@ -901,39 +889,51 @@ class NetlistEmitter: signed = False elif isinstance(value, _ast.SwitchValue): test, _signed = self.emit_rhs(module_idx, value.test) - conds = [] - elems = [] - for patterns, elem, in value.cases: - if patterns is not None: - if not patterns: - # Hack: empty pattern set cannot be supported by RTLIL. - continue - for pattern in patterns: - assert len(pattern) == len(test) - cell = _nir.Matches(module_idx, value=test, patterns=patterns, - src_loc=value.src_loc) - net, = self.netlist.add_value_cell(1, cell) - conds.append(net) - else: - conds.append(_nir.Net.from_const(1)) - elems.append(self.emit_rhs(module_idx, elem)) - cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), - inputs=_nir.Value(conds), - src_loc=value.src_loc) - conds = self.netlist.add_value_cell(len(conds), cell) - shape = _ast.Shape._unify( - _ast.Shape(len(value), signed) - for value, signed in elems - ) - elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems) - assignments = [ - _nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc) - for subcond, elem in zip(conds, elems) - ] - cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width), - assignments=assignments, src_loc=value.src_loc) - result = self.netlist.add_value_cell(shape.width, cell) - signed = shape.signed + if (len(value.cases) == 2 and + value.cases[0][0] == ("0" * len(test),) and + value.cases[1][0] is None): + operand_a, signed_a = self.emit_rhs(module_idx, value.cases[1][1]) + operand_b, signed_b = self.emit_rhs(module_idx, value.cases[0][1]) + if len(test) != 1: + test = self.emit_operator(module_idx, 'b', test, src_loc=value.src_loc) + operand_a, operand_b, signed = \ + self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b) + result = self.emit_operator(module_idx, 'm', test, operand_a, operand_b, + src_loc=value.src_loc) + else: + conds = [] + elems = [] + for patterns, elem, in value.cases: + if patterns is not None: + if not patterns: + # Hack: empty pattern set cannot be supported by RTLIL. + continue + for pattern in patterns: + assert len(pattern) == len(test) + cell = _nir.Matches(module_idx, value=test, patterns=patterns, + src_loc=value.src_loc) + net, = self.netlist.add_value_cell(1, cell) + conds.append(net) + else: + conds.append(_nir.Net.from_const(1)) + elems.append(self.emit_rhs(module_idx, elem)) + cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), + inputs=_nir.Value(conds), + src_loc=value.src_loc) + conds = self.netlist.add_value_cell(len(conds), cell) + shape = _ast.Shape._unify( + _ast.Shape(len(value), signed) + for value, signed in elems + ) + elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems) + assignments = [ + _nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc) + for subcond, elem in zip(conds, elems) + ] + cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width), + assignments=assignments, src_loc=value.src_loc) + result = self.netlist.add_value_cell(shape.width, cell) + signed = shape.signed elif isinstance(value, _ast.Concat): nets = [] for val in value.parts: diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py index 045c1b8..72803b9 100644 --- a/amaranth/hdl/_nir.py +++ b/amaranth/hdl/_nir.py @@ -559,8 +559,8 @@ class Operator(Cell): The ternary operators are: - - 'm': like AST, first input needs to have width of 1, second and third operand need to have the same - width as output + - 'm': multiplexer, first input needs to have width of 1, second and third operand need to have + the same width as output; implements arg0 ? arg1 : arg2 Attributes ---------- diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 14e8b01..f6c4943 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -244,10 +244,6 @@ class _RHSValueCompiler(_ValueCompiler): return f"({sign(lhs)} > {sign(rhs)})" if value.operator == ">=": return f"({sign(lhs)} >= {sign(rhs)})" - elif len(value.operands) == 3: - if value.operator == "m": - sel, val1, val0 = value.operands - return f"({sign(val1)} if {mask(sel)} else {sign(val0)})" raise NotImplementedError(f"Operator '{value.operator}' not implemented") # :nocov: def on_Slice(self, value): @@ -274,7 +270,7 @@ class _RHSValueCompiler(_ValueCompiler): gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self(value.test)}") gen_value = self.emitter.def_var("rhs_switch", "0") def case_handler(patterns, elem): - self.emitter.append(f"{gen_value} = {self(elem)}") + self.emitter.append(f"{gen_value} = {self.sign(elem)}") self._emit_switch(gen_test, value.cases, case_handler) return gen_value diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index e17268a..583de3e 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -746,7 +746,7 @@ class OperatorTestCase(FHDLTestCase): def test_mux(self): s = Const(0) v1 = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6))) - self.assertEqual(repr(v1), "(m (const 1'd0) (const 4'd0) (const 6'd0))") + self.assertEqual(repr(v1), "(switch-value (const 1'd0) (case 0 (const 6'd0)) (default (const 4'd0)))") self.assertEqual(v1.shape(), unsigned(6)) v2 = Mux(s, Const(0, signed(4)), Const(0, signed(6))) self.assertEqual(v2.shape(), signed(6)) @@ -758,11 +758,11 @@ class OperatorTestCase(FHDLTestCase): def test_mux_wide(self): s = Const(0b100) v = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6))) - self.assertEqual(repr(v), "(m (const 3'd4) (const 4'd0) (const 6'd0))") + self.assertEqual(repr(v), "(switch-value (const 3'd4) (case 000 (const 6'd0)) (default (const 4'd0)))") def test_mux_bool(self): v = Mux(True, Const(0), Const(0)) - self.assertEqual(repr(v), "(m (const 1'd1) (const 1'd0) (const 1'd0))") + self.assertEqual(repr(v), "(switch-value (const 1'd1) (case 0 (const 1'd0)) (default (const 1'd0)))") def test_any(self): v = Const(0b101).any() @@ -842,7 +842,7 @@ class OperatorTestCase(FHDLTestCase): """) s = Signal(signed(4)) self.assertRepr(abs(s), """ - (slice (m (>= (sig s) (const 1'd0)) (sig s) (- (sig s))) 0:4) + (slice (switch-value (>= (sig s) (const 1'd0)) (case 0 (- (sig s))) (default (sig s))) 0:4) """) self.assertEqual(abs(s).shape(), unsigned(4)) diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 450939b..c913628 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -406,10 +406,10 @@ class EnableInserterTestCase(FHDLTestCase): mem.write_port(granularity=2) f = EnableInserter(self.c1)(mem).elaborate(platform=None) self.assertRepr(f._write_ports[0]._en, """ - (m + (switch-value (sig c1) - (sig mem_w_en) - (const 4'd0) + (case 0 (const 4'd0)) + (default (sig mem_w_en)) ) """)