hdl._ast: Implement Mux in terms of SwitchValue.

Fixes #1075.
This commit is contained in:
Wanda 2024-04-03 12:52:01 +02:00 committed by Catherine
parent 466536efcf
commit 606ebcd7a9
6 changed files with 56 additions and 64 deletions

View file

@ -1681,10 +1681,6 @@ class Operator(Value):
if self.operator == ">>": if self.operator == ">>":
assert not b_shape.signed assert not b_shape.signed
return Shape(a_shape.width, a_shape.signed) return Shape(a_shape.width, a_shape.signed)
elif len(op_shapes) == 3:
if self.operator == "m":
s_shape, a_shape, b_shape = op_shapes
return Shape._unify((a_shape, b_shape))
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
def _lhs_signals(self): def _lhs_signals(self):
@ -1715,7 +1711,7 @@ def Mux(sel, val1, val0):
Value, out Value, out
Output ``Value``. If ``sel`` is asserted, the Mux returns ``val1``, else ``val0``. Output ``Value``. If ``sel`` is asserted, the Mux returns ``val1``, else ``val0``.
""" """
return Operator("m", [sel, val1, val0]) return SwitchValue(sel, ((0, val0), (None, val1)), src_loc_at=1)
@final @final

View file

@ -874,18 +874,6 @@ class NetlistEmitter:
signed = False signed = False
else: else:
assert False # :nocov: assert False # :nocov:
elif len(value.operands) == 3:
assert value.operator == 'm'
operand_s, signed_s = self.emit_rhs(module_idx, value.operands[0])
operand_a, signed_a = self.emit_rhs(module_idx, value.operands[1])
operand_b, signed_b = self.emit_rhs(module_idx, value.operands[2])
if len(operand_s) != 1:
operand_s = self.emit_operator(module_idx, 'b', operand_s,
src_loc=value.src_loc)
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
result = self.emit_operator(module_idx, 'm', operand_s, operand_a, operand_b,
src_loc=value.src_loc)
else: else:
assert False # :nocov: assert False # :nocov:
elif isinstance(value, _ast.Slice): elif isinstance(value, _ast.Slice):
@ -901,39 +889,51 @@ class NetlistEmitter:
signed = False signed = False
elif isinstance(value, _ast.SwitchValue): elif isinstance(value, _ast.SwitchValue):
test, _signed = self.emit_rhs(module_idx, value.test) test, _signed = self.emit_rhs(module_idx, value.test)
conds = [] if (len(value.cases) == 2 and
elems = [] value.cases[0][0] == ("0" * len(test),) and
for patterns, elem, in value.cases: value.cases[1][0] is None):
if patterns is not None: operand_a, signed_a = self.emit_rhs(module_idx, value.cases[1][1])
if not patterns: operand_b, signed_b = self.emit_rhs(module_idx, value.cases[0][1])
# Hack: empty pattern set cannot be supported by RTLIL. if len(test) != 1:
continue test = self.emit_operator(module_idx, 'b', test, src_loc=value.src_loc)
for pattern in patterns: operand_a, operand_b, signed = \
assert len(pattern) == len(test) self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
cell = _nir.Matches(module_idx, value=test, patterns=patterns, result = self.emit_operator(module_idx, 'm', test, operand_a, operand_b,
src_loc=value.src_loc) src_loc=value.src_loc)
net, = self.netlist.add_value_cell(1, cell) else:
conds.append(net) conds = []
else: elems = []
conds.append(_nir.Net.from_const(1)) for patterns, elem, in value.cases:
elems.append(self.emit_rhs(module_idx, elem)) if patterns is not None:
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), if not patterns:
inputs=_nir.Value(conds), # Hack: empty pattern set cannot be supported by RTLIL.
src_loc=value.src_loc) continue
conds = self.netlist.add_value_cell(len(conds), cell) for pattern in patterns:
shape = _ast.Shape._unify( assert len(pattern) == len(test)
_ast.Shape(len(value), signed) cell = _nir.Matches(module_idx, value=test, patterns=patterns,
for value, signed in elems src_loc=value.src_loc)
) net, = self.netlist.add_value_cell(1, cell)
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems) conds.append(net)
assignments = [ else:
_nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc) conds.append(_nir.Net.from_const(1))
for subcond, elem in zip(conds, elems) elems.append(self.emit_rhs(module_idx, elem))
] cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width), inputs=_nir.Value(conds),
assignments=assignments, src_loc=value.src_loc) src_loc=value.src_loc)
result = self.netlist.add_value_cell(shape.width, cell) conds = self.netlist.add_value_cell(len(conds), cell)
signed = shape.signed shape = _ast.Shape._unify(
_ast.Shape(len(value), signed)
for value, signed in elems
)
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
assignments = [
_nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc)
for subcond, elem in zip(conds, elems)
]
cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width),
assignments=assignments, src_loc=value.src_loc)
result = self.netlist.add_value_cell(shape.width, cell)
signed = shape.signed
elif isinstance(value, _ast.Concat): elif isinstance(value, _ast.Concat):
nets = [] nets = []
for val in value.parts: for val in value.parts:

View file

@ -559,8 +559,8 @@ class Operator(Cell):
The ternary operators are: The ternary operators are:
- 'm': like AST, first input needs to have width of 1, second and third operand need to have the same - 'm': multiplexer, first input needs to have width of 1, second and third operand need to have
width as output the same width as output; implements arg0 ? arg1 : arg2
Attributes Attributes
---------- ----------

View file

@ -244,10 +244,6 @@ class _RHSValueCompiler(_ValueCompiler):
return f"({sign(lhs)} > {sign(rhs)})" return f"({sign(lhs)} > {sign(rhs)})"
if value.operator == ">=": if value.operator == ">=":
return f"({sign(lhs)} >= {sign(rhs)})" return f"({sign(lhs)} >= {sign(rhs)})"
elif len(value.operands) == 3:
if value.operator == "m":
sel, val1, val0 = value.operands
return f"({sign(val1)} if {mask(sel)} else {sign(val0)})"
raise NotImplementedError(f"Operator '{value.operator}' not implemented") # :nocov: raise NotImplementedError(f"Operator '{value.operator}' not implemented") # :nocov:
def on_Slice(self, value): def on_Slice(self, value):
@ -274,7 +270,7 @@ class _RHSValueCompiler(_ValueCompiler):
gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self(value.test)}") gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self(value.test)}")
gen_value = self.emitter.def_var("rhs_switch", "0") gen_value = self.emitter.def_var("rhs_switch", "0")
def case_handler(patterns, elem): def case_handler(patterns, elem):
self.emitter.append(f"{gen_value} = {self(elem)}") self.emitter.append(f"{gen_value} = {self.sign(elem)}")
self._emit_switch(gen_test, value.cases, case_handler) self._emit_switch(gen_test, value.cases, case_handler)
return gen_value return gen_value

View file

@ -746,7 +746,7 @@ class OperatorTestCase(FHDLTestCase):
def test_mux(self): def test_mux(self):
s = Const(0) s = Const(0)
v1 = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6))) v1 = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6)))
self.assertEqual(repr(v1), "(m (const 1'd0) (const 4'd0) (const 6'd0))") self.assertEqual(repr(v1), "(switch-value (const 1'd0) (case 0 (const 6'd0)) (default (const 4'd0)))")
self.assertEqual(v1.shape(), unsigned(6)) self.assertEqual(v1.shape(), unsigned(6))
v2 = Mux(s, Const(0, signed(4)), Const(0, signed(6))) v2 = Mux(s, Const(0, signed(4)), Const(0, signed(6)))
self.assertEqual(v2.shape(), signed(6)) self.assertEqual(v2.shape(), signed(6))
@ -758,11 +758,11 @@ class OperatorTestCase(FHDLTestCase):
def test_mux_wide(self): def test_mux_wide(self):
s = Const(0b100) s = Const(0b100)
v = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6))) v = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6)))
self.assertEqual(repr(v), "(m (const 3'd4) (const 4'd0) (const 6'd0))") self.assertEqual(repr(v), "(switch-value (const 3'd4) (case 000 (const 6'd0)) (default (const 4'd0)))")
def test_mux_bool(self): def test_mux_bool(self):
v = Mux(True, Const(0), Const(0)) v = Mux(True, Const(0), Const(0))
self.assertEqual(repr(v), "(m (const 1'd1) (const 1'd0) (const 1'd0))") self.assertEqual(repr(v), "(switch-value (const 1'd1) (case 0 (const 1'd0)) (default (const 1'd0)))")
def test_any(self): def test_any(self):
v = Const(0b101).any() v = Const(0b101).any()
@ -842,7 +842,7 @@ class OperatorTestCase(FHDLTestCase):
""") """)
s = Signal(signed(4)) s = Signal(signed(4))
self.assertRepr(abs(s), """ self.assertRepr(abs(s), """
(slice (m (>= (sig s) (const 1'd0)) (sig s) (- (sig s))) 0:4) (slice (switch-value (>= (sig s) (const 1'd0)) (case 0 (- (sig s))) (default (sig s))) 0:4)
""") """)
self.assertEqual(abs(s).shape(), unsigned(4)) self.assertEqual(abs(s).shape(), unsigned(4))

View file

@ -406,10 +406,10 @@ class EnableInserterTestCase(FHDLTestCase):
mem.write_port(granularity=2) mem.write_port(granularity=2)
f = EnableInserter(self.c1)(mem).elaborate(platform=None) f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f._write_ports[0]._en, """ self.assertRepr(f._write_ports[0]._en, """
(m (switch-value
(sig c1) (sig c1)
(sig mem_w_en) (case 0 (const 4'd0))
(const 4'd0) (default (sig mem_w_en))
) )
""") """)