diff --git a/amaranth/build/plat.py b/amaranth/build/plat.py index 12d2bd3..4880f67 100644 --- a/amaranth/build/plat.py +++ b/amaranth/build/plat.py @@ -10,7 +10,7 @@ from .. import __version__ from .._toolchain import * from ..hdl import * from ..hdl._ir import IOBufferInstance, Design -from ..hdl._xfrm import DomainLowerer, AssignmentLegalizer +from ..hdl._xfrm import DomainLowerer from ..lib.cdc import ResetSynchronizer from ..back import rtlil, verilog from .res import * @@ -165,7 +165,6 @@ class Platform(ResourceManager, metaclass=ABCMeta): add_pin_fragment(pin, self.get_diff_input_output(pin, port, attrs, invert)) ports = [(None, signal, None) for signal in self.iter_ports()] - fragment = AssignmentLegalizer()(fragment) fragment = Design(fragment, ports, hierarchy=(name,)) return self.toolchain_prepare(fragment, name, **kwargs) diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 733e276..5f63f56 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -386,7 +386,7 @@ class Fragment: return new_ports - def prepare(self, ports=(), *, hierarchy=("top",), legalize_assignments=False, missing_domain=lambda name: _cd.ClockDomain(name), propagate_domains=True): + def prepare(self, ports=(), *, hierarchy=("top",), missing_domain=lambda name: _cd.ClockDomain(name), propagate_domains=True): from ._xfrm import DomainLowerer ports = self._prepare_ports(ports) @@ -416,9 +416,6 @@ class Fragment: ] fragment = DomainLowerer()(self) - if legalize_assignments: - from ._xfrm import AssignmentLegalizer - fragment = AssignmentLegalizer()(fragment) # Create design and let it do the rest. return Design(fragment, ports, hierarchy=hierarchy) @@ -698,10 +695,7 @@ class NetlistEmitter: except KeyError: pass if isinstance(value, _ast.Const): - result = _nir.Value( - _nir.Net.from_const((value.value >> bit) & 1) - for bit in range(value.width) - ) + result = _nir.Value.from_const(value.value, value.width) signed = value.signed elif isinstance(value, _ast.Signal): result = self.emit_signal(value) @@ -887,6 +881,95 @@ class NetlistEmitter: self.netlist.connections[left] = right self.connect_src_loc[left] = src_loc + def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.Value, lhs_start: int, rhs: _nir.Value, cond: _nir.Net, *, src_loc): + # Assign rhs to lhs[lhs_start:lhs_start+len(rhs)] + if isinstance(lhs, _ast.Signal): + if lhs in self.drivers: + driver = self.drivers[lhs] + if driver.domain is not cd: + domain_name = cd.name if cd is not None else "comb" + other_domain_name = driver.domain.name if driver.domain is not None else "comb" + raise _ir.DriverConflict( + f"Signal {lhs} driven from domain {domain_name} at {src_loc} and domain " + f"{other_domain_name} at {driver.src_loc}") + if driver.module_idx != module_idx: + mod_name = ".".join(self.netlist.modules[module_idx].name or ("",)) + other_mod_name = \ + ".".join(self.netlist.modules[driver.module_idx].name or ("",)) + raise _ir.DriverConflict( + f"Signal {lhs} driven from module {mod_name} at {src_loc} and " + f"module {other_mod_name} at {driver.src_loc}") + else: + driver = NetlistDriver(module_idx, lhs, domain=cd, src_loc=src_loc) + self.drivers[lhs] = driver + driver.assignments.append(_nir.Assignment(cond=cond, start=lhs_start, value=rhs, + src_loc=src_loc)) + elif isinstance(lhs, _ast.Slice): + self.emit_assign(module_idx, cd, lhs.value, lhs_start + lhs.start, rhs, cond, src_loc=src_loc) + elif isinstance(lhs, _ast.Cat): + part_stop = 0 + for part in lhs.parts: + part_start = part_stop + part_len = len(part) + part_stop = part_start + part_len + if lhs_start >= part_stop: + continue + if lhs_start + len(rhs) <= part_start: + continue + if lhs_start < part_start: + part_lhs_start = 0 + part_rhs_start = part_start - lhs_start + else: + part_lhs_start = lhs_start - part_start + part_rhs_start = 0 + if lhs_start + len(rhs) >= part_stop: + part_rhs_stop = part_stop - lhs_start + else: + part_rhs_stop = len(rhs) + self.emit_assign(module_idx, cd, part, part_lhs_start, rhs[part_rhs_start:part_rhs_stop], cond, src_loc=src_loc) + elif isinstance(lhs, _ast.Part): + offset, _signed = self.emit_rhs(module_idx, lhs.offset) + width = len(lhs.value) + num_cases = min((width + lhs.stride - 1) // lhs.stride, 1 << len(offset)) + conds = [] + for case_index in range(num_cases): + cell = _nir.Matches(module_idx, value=offset, + patterns=(f"{case_index:0{len(offset)}b}",), + src_loc=lhs.src_loc) + subcond, = self.netlist.add_value_cell(1, cell) + conds.append(subcond) + conds = _nir.Value(conds) + cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc) + conds = self.netlist.add_value_cell(len(conds), cell) + for idx, subcond in enumerate(conds): + start = lhs_start + idx * lhs.stride + if start >= width: + continue + if start + len(rhs) >= width: + subrhs = rhs[:width - start] + else: + subrhs = rhs + self.emit_assign(module_idx, cd, lhs.value, start, subrhs, subcond, src_loc=src_loc) + elif isinstance(lhs, _ast.ArrayProxy): + index, _signed = self.emit_rhs(module_idx, lhs.index) + conds = [] + for case_index in range(len(lhs.elems)): + cell = _nir.Matches(module_idx, value=index, + patterns=(f"{case_index:0{len(index)}b}",), + src_loc=lhs.src_loc) + subcond, = self.netlist.add_value_cell(1, cell) + conds.append(subcond) + conds = _nir.Value(conds) + cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc) + conds = self.netlist.add_value_cell(len(conds), cell) + for subcond, val in zip(conds, lhs.elems): + self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc) + elif isinstance(lhs, _ast.Operator): + assert lhs.operator in ('u', 's') + self.emit_assign(module_idx, cd, lhs.operands[0], lhs_start, rhs, cond, src_loc=src_loc) + else: + assert False # :nocov: + def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str, stmt: _ast.Statement, cond: _nir.Net): if domain == "comb": @@ -894,42 +977,13 @@ class NetlistEmitter: else: cd = fragment.domains[domain] if isinstance(stmt, _ast.Assign): - if isinstance(stmt.lhs, _ast.Signal): - signal = stmt.lhs - start = 0 - width = signal.width - elif isinstance(stmt.lhs, _ast.Slice): - signal = stmt.lhs.value - start = stmt.lhs.start - width = stmt.lhs.stop - stmt.lhs.start - else: - assert False # :nocov: - assert isinstance(signal, _ast.Signal) - if signal in self.drivers: - driver = self.drivers[signal] - if driver.domain is not cd: - domain_name = cd.name if cd is not None else "comb" - other_domain_name = driver.domain.name if driver.domain is not None else "comb" - raise _ir.DriverConflict( - f"Signal {signal} driven from domain {domain_name} at {stmt.src_loc} and domain " - f"{other_domain_name} at {driver.src_loc}") - if driver.module_idx != module_idx: - mod_name = ".".join(self.netlist.modules[module_idx].name or ("",)) - other_mod_name = \ - ".".join(self.netlist.modules[driver.module_idx].name or ("",)) - raise _ir.DriverConflict( - f"Signal {signal} driven from module {mod_name} at {stmt.src_loc} and " - f"module {other_mod_name} at {driver.src_loc}") - else: - driver = NetlistDriver(module_idx, signal, domain=cd, src_loc=stmt.src_loc) - self.drivers[signal] = driver rhs, signed = self.emit_rhs(module_idx, stmt.rhs) + width = len(stmt.lhs) if len(rhs) > width: rhs = _nir.Value(rhs[:width]) if len(rhs) < width: rhs = self.extend(rhs, signed, width) - driver.assignments.append(_nir.Assignment(cond=cond, start=start, value=rhs, - src_loc=stmt.src_loc)) + self.emit_assign(module_idx, cd, stmt.lhs, 0, rhs, cond, src_loc=stmt.src_loc) elif isinstance(stmt, _ast.Property): test, _signed = self.emit_rhs(module_idx, stmt.test) if len(test) != 1: @@ -1374,7 +1428,7 @@ def build_netlist(fragment, ports=(), *, name="top", **kwargs): if isinstance(fragment, Design): design = fragment else: - design = fragment.prepare(ports=ports, hierarchy=(name,), legalize_assignments=True, **kwargs) + design = fragment.prepare(ports=ports, hierarchy=(name,), **kwargs) netlist = _nir.Netlist() _emit_netlist(netlist, design) netlist.resolve_all_nets() diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py index 9fa8e0e..264b317 100644 --- a/amaranth/hdl/_nir.py +++ b/amaranth/hdl/_nir.py @@ -88,13 +88,23 @@ class Value(tuple): return super().__new__(cls, (nets,)) return super().__new__(cls, (Net.ensure(net) for net in nets)) + @classmethod + def from_const(cls, value, width): + return cls(Net.from_const((value >> bit) & 1) for bit in range(width)) + @classmethod def zeros(cls, digits=1): - return cls(Net.from_const(0) for _ in range(digits)) + return cls.from_const(0, digits) @classmethod def ones(cls, digits=1): - return cls(Net.from_const(1) for _ in range(digits)) + return cls.from_const(-1, digits) + + def __getitem__(self, index): + if isinstance(index, slice): + return type(self)(super().__getitem__(index)) + else: + return super().__getitem__(index) def __repr__(self): pos = 0 diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index 8a02338..e60be23 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -16,7 +16,7 @@ __all__ = ["ValueVisitor", "ValueTransformer", "FragmentTransformer", "TransformedElaboratable", "DomainCollector", "DomainRenamer", "DomainLowerer", - "ResetInserter", "EnableInserter", "AssignmentLegalizer"] + "ResetInserter", "EnableInserter"] class ValueVisitor(metaclass=ABCMeta): @@ -603,85 +603,3 @@ class EnableInserter(_ControlInserter): if port._domain in self.controls: port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en))) return new_fragment - - -class AssignmentLegalizer(FragmentTransformer, StatementTransformer): - """Ensures all assignments in switches have one of the following on the LHS: - - - a `Signal` - - a `Slice` with `value` that is a `Signal` - """ - def emit_assign(self, lhs, rhs, lhs_start=0, lhs_stop=None): - if isinstance(lhs, ArrayProxy): - # Lower into a switch. - cases = {} - for idx, val in enumerate(lhs.elems): - cases[idx] = self.emit_assign(val, rhs, lhs_start, lhs_stop) - return [Switch(lhs.index, cases)] - elif isinstance(lhs, Part): - offset = lhs.offset - width = lhs.width - if lhs_start != 0: - width -= lhs_start - if lhs_stop is not None: - width = lhs_stop - lhs_start - cases = {} - lhs_width = len(lhs.value) - for idx in range(lhs_width): - start = lhs_start + idx * lhs.stride - if start >= lhs_width: - break - stop = min(start + width, lhs_width) - cases[idx] = self.emit_assign(lhs.value, rhs, start, stop) - return [Switch(offset, cases)] - elif isinstance(lhs, Slice): - part_start = lhs_start + lhs.start - if lhs_stop is not None: - part_stop = lhs_stop + lhs.start - else: - part_stop = lhs_start + lhs.stop - return self.emit_assign(lhs.value, rhs, part_start, part_stop) - elif isinstance(lhs, Cat): - # Split into several assignments. - part_stop = 0 - res = [] - if lhs_stop is None: - lhs_len = len(lhs) - lhs_start - else: - lhs_len = lhs_stop - lhs_start - if len(rhs) < lhs_len: - rhs |= Const(0, Shape(lhs_len, signed=rhs.shape().signed)) - for val in lhs.parts: - part_start = part_stop - part_len = len(val) - part_stop = part_start + part_len - if lhs_start >= part_stop: - continue - if lhs_start < part_start: - part_lhs_start = 0 - part_rhs_start = part_start - lhs_start - else: - part_lhs_start = lhs_start - part_start - part_rhs_start = 0 - if lhs_stop is not None and lhs_stop <= part_start: - continue - elif lhs_stop is None or lhs_stop >= part_stop: - part_lhs_stop = None - else: - part_lhs_stop = lhs_stop - part_start - res += self.emit_assign(val, rhs[part_rhs_start:], part_lhs_start, part_lhs_stop) - return res - elif isinstance(lhs, Signal): - # Already ok. - if lhs_start != 0 or lhs_stop is not None: - return [Assign(lhs[lhs_start:lhs_stop], rhs)] - else: - return [Assign(lhs, rhs)] - elif isinstance(lhs, Operator): - assert lhs.operator in ('u', 's') - return self.emit_assign(lhs.operands[0], rhs, lhs_start, lhs_stop) - else: - raise TypeError - - def on_Assign(self, stmt): - return self.emit_assign(stmt.lhs, stmt.rhs) diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index a89879b..b0ca970 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -1059,3 +1059,699 @@ class IOBufferTestCase(FHDLTestCase): with self.assertRaisesRegex(ValueError, r"^`oe` must not be used if `o` is not used"): IOBufferInstance(pad, oe=oe) + + +class AssignTestCase(FHDLTestCase): + def test_simple(self): + s1 = Signal(8) + s2 = Signal(8) + f = Fragment() + f.add_statements( + "comb", + s1.eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:10) + (output 's1' 0.2:10) + ) + (cell 0 0 (top + (output 's1' 0.2:10) + (input 's2' 2:10) + )) + ) + """) + + def test_simple_trunc(self): + s1 = Signal(8) + s2 = Signal(10) + f = Fragment() + f.add_statements( + "comb", + s1.eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:12) + (output 's1' 0.2:10) + ) + (cell 0 0 (top + (output 's1' 0.2:10) + (input 's2' 2:12) + )) + ) + """) + + def test_simple_zext(self): + s1 = Signal(8) + s2 = Signal(6) + f = Fragment() + f.add_statements( + "comb", + s1.eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:8) + (output 's1' (cat 0.2:8 2'd0)) + ) + (cell 0 0 (top + (output 's1' (cat 0.2:8 2'd0)) + (input 's2' 2:8) + )) + ) + """) + + def test_simple_sext(self): + s1 = Signal(8) + s2 = Signal(signed(6)) + f = Fragment() + f.add_statements( + "comb", + s1.eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:8) + (output 's1' (cat 0.2:8 0.7 0.7)) + ) + (cell 0 0 (top + (output 's1' (cat 0.2:8 0.7 0.7)) + (input 's2' 2:8) + )) + ) + """) + + def test_simple_slice(self): + s1 = Signal(8) + s2 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1[2:6].eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (output 's1' 1.0:8) + ) + (cell 0 0 (top + (output 's1' 1.0:8) + (input 's2' 2:6) + )) + (cell 1 0 (assignment_list 8'd0 (1 2:6 0.2:6))) + ) + """) + + def test_simple_part(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1.bit_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:10) + (output 's1' 10.0:8) + ) + (cell 0 0 (top + (output 's1' 10.0:8) + (input 's2' 2:6) + (input 's3' 6:10) + )) + (cell 1 0 (matches 0.6:10 0000)) + (cell 2 0 (matches 0.6:10 0001)) + (cell 3 0 (matches 0.6:10 0010)) + (cell 4 0 (matches 0.6:10 0011)) + (cell 5 0 (matches 0.6:10 0100)) + (cell 6 0 (matches 0.6:10 0101)) + (cell 7 0 (matches 0.6:10 0110)) + (cell 8 0 (matches 0.6:10 0111)) + (cell 9 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0))) + (cell 10 0 (assignment_list 8'd0 + (9.0 0:4 0.2:6) + (9.1 1:5 0.2:6) + (9.2 2:6 0.2:6) + (9.3 3:7 0.2:6) + (9.4 4:8 0.2:6) + (9.5 5:8 0.2:5) + (9.6 6:8 0.2:4) + (9.7 7:8 0.2) + )) + ) + """) + + def test_simple_part_short(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(2) + f = Fragment() + f.add_statements( + "comb", + s1.bit_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:8) + (output 's1' 6.0:8) + ) + (cell 0 0 (top + (output 's1' 6.0:8) + (input 's2' 2:6) + (input 's3' 6:8) + )) + (cell 1 0 (matches 0.6:8 00)) + (cell 2 0 (matches 0.6:8 01)) + (cell 3 0 (matches 0.6:8 10)) + (cell 4 0 (matches 0.6:8 11)) + (cell 5 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0))) + (cell 6 0 (assignment_list 8'd0 + (5.0 0:4 0.2:6) + (5.1 1:5 0.2:6) + (5.2 2:6 0.2:6) + (5.3 3:7 0.2:6) + )) + ) + """) + + def test_simple_part_word(self): + s1 = Signal(16) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1.word_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:10) + (output 's1' 6.0:16) + ) + (cell 0 0 (top + (output 's1' 6.0:16) + (input 's2' 2:6) + (input 's3' 6:10) + )) + (cell 1 0 (matches 0.6:10 0000)) + (cell 2 0 (matches 0.6:10 0001)) + (cell 3 0 (matches 0.6:10 0010)) + (cell 4 0 (matches 0.6:10 0011)) + (cell 5 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0))) + (cell 6 0 (assignment_list 16'd0 + (5.0 0:4 0.2:6) + (5.1 4:8 0.2:6) + (5.2 8:12 0.2:6) + (5.3 12:16 0.2:6) + )) + ) + """) + + def test_simple_part_word_misalign(self): + s1 = Signal(17) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1.word_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:10) + (output 's1' 7.0:17) + ) + (cell 0 0 (top + (output 's1' 7.0:17) + (input 's2' 2:6) + (input 's3' 6:10) + )) + (cell 1 0 (matches 0.6:10 0000)) + (cell 2 0 (matches 0.6:10 0001)) + (cell 3 0 (matches 0.6:10 0010)) + (cell 4 0 (matches 0.6:10 0011)) + (cell 5 0 (matches 0.6:10 0100)) + (cell 6 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0))) + (cell 7 0 (assignment_list 17'd0 + (6.0 0:4 0.2:6) + (6.1 4:8 0.2:6) + (6.2 8:12 0.2:6) + (6.3 12:16 0.2:6) + (6.4 16:17 0.2) + )) + ) + """) + + def test_simple_concat(self): + s1 = Signal(4) + s2 = Signal(4) + s3 = Signal(4) + s4 = Signal(12) + f = Fragment() + f.add_statements( + "comb", + Cat(s1, s2, s3).eq(s4) + ) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + f.add_driver(s3, "comb") + nl = build_netlist(f, ports=[s1, s2, s3, s4]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's4' 0.2:14) + (output 's1' 0.2:6) + (output 's2' 0.6:10) + (output 's3' 0.10:14) + ) + (cell 0 0 (top + (output 's1' 0.2:6) + (output 's2' 0.6:10) + (output 's3' 0.10:14) + (input 's4' 2:14) + )) + ) + """) + + def test_simple_concat_narrow(self): + s1 = Signal(4) + s2 = Signal(4) + s3 = Signal(4) + s4 = Signal(signed(6)) + f = Fragment() + f.add_statements( + "comb", + Cat(s1, s2, s3).eq(s4) + ) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + f.add_driver(s3, "comb") + nl = build_netlist(f, ports=[s1, s2, s3, s4]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's4' 0.2:8) + (output 's1' 0.2:6) + (output 's2' (cat 0.6:8 0.7 0.7)) + (output 's3' (cat 0.7 0.7 0.7 0.7)) + ) + (cell 0 0 (top + (output 's1' 0.2:6) + (output 's2' (cat 0.6:8 0.7 0.7)) + (output 's3' (cat 0.7 0.7 0.7 0.7)) + (input 's4' 2:8) + )) + ) + """) + + def test_simple_operator(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + f = Fragment() + f.add_statements("comb", [ + s1.as_signed().eq(s3), + s2.as_unsigned().eq(s3), + ]) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's3' 0.2:10) + (output 's1' 0.2:10) + (output 's2' 0.2:10) + ) + (cell 0 0 (top + (output 's1' 0.2:10) + (output 's2' 0.2:10) + (input 's3' 2:10) + )) + ) + """) + + def test_simple_array(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + s4 = Signal(8) + s5 = Signal(8) + f = Fragment() + f.add_statements("comb", [ + Array([s1, s2, s3])[s4].eq(s5), + ]) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + f.add_driver(s3, "comb") + nl = build_netlist(f, ports=[s1, s2, s3, s4, s5]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's4' 0.2:10) + (input 's5' 0.10:18) + (output 's1' 5.0:8) + (output 's2' 6.0:8) + (output 's3' 7.0:8) + ) + (cell 0 0 (top + (output 's1' 5.0:8) + (output 's2' 6.0:8) + (output 's3' 7.0:8) + (input 's4' 2:10) + (input 's5' 10:18) + )) + (cell 1 0 (matches 0.2:10 00000000)) + (cell 2 0 (matches 0.2:10 00000001)) + (cell 3 0 (matches 0.2:10 00000010)) + (cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0))) + (cell 5 0 (assignment_list 8'd0 (4.0 0:8 0.10:18))) + (cell 6 0 (assignment_list 8'd0 (4.1 0:8 0.10:18))) + (cell 7 0 (assignment_list 8'd0 (4.2 0:8 0.10:18))) + ) + """) + + def test_sliced_slice(self): + s1 = Signal(12) + s2 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1[1:11][2:6].eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (output 's1' 1.0:12) + ) + (cell 0 0 (top + (output 's1' 1.0:12) + (input 's2' 2:6) + )) + (cell 1 0 (assignment_list 12'd0 (1 3:7 0.2:6))) + ) + """) + + def test_sliced_concat(self): + s1 = Signal(4) + s2 = Signal(4) + s3 = Signal(4) + s4 = Signal(4) + s5 = Signal(4) + s6 = Signal(8) + f = Fragment() + f.add_statements( + "comb", + Cat(s1, s2, s3, s4, s5)[5:14].eq(s6) + ) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + f.add_driver(s3, "comb") + f.add_driver(s4, "comb") + f.add_driver(s5, "comb") + nl = build_netlist(f, ports=[s1, s2, s3, s4, s5, s6]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's1' 0.2:6) + (input 's5' 0.6:10) + (input 's6' 0.10:18) + (output 's2' 1.0:4) + (output 's3' 0.13:17) + (output 's4' 2.0:4) + ) + (cell 0 0 (top + (output 's2' 1.0:4) + (output 's3' 0.13:17) + (output 's4' 2.0:4) + (input 's1' 2:6) + (input 's5' 6:10) + (input 's6' 10:18) + )) + (cell 1 0 (assignment_list 4'd0 (1 1:4 0.10:13))) + (cell 2 0 (assignment_list 4'd0 (1 0:2 (cat 0.17 1'd0)))) + ) + """) + + def test_sliced_part(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1.bit_select(s3, 6)[2:4].eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:10) + (output 's1' 10.0:8) + ) + (cell 0 0 (top + (output 's1' 10.0:8) + (input 's2' 2:6) + (input 's3' 6:10) + )) + (cell 1 0 (matches 0.6:10 0000)) + (cell 2 0 (matches 0.6:10 0001)) + (cell 3 0 (matches 0.6:10 0010)) + (cell 4 0 (matches 0.6:10 0011)) + (cell 5 0 (matches 0.6:10 0100)) + (cell 6 0 (matches 0.6:10 0101)) + (cell 7 0 (matches 0.6:10 0110)) + (cell 8 0 (matches 0.6:10 0111)) + (cell 9 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0))) + (cell 10 0 (assignment_list 8'd0 + (9.0 2:4 0.2:4) + (9.1 3:5 0.2:4) + (9.2 4:6 0.2:4) + (9.3 5:7 0.2:4) + (9.4 6:8 0.2:4) + (9.5 7:8 0.2) + )) + ) + """) + + def test_sliced_part_word(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1.word_select(s3, 4)[1:3].eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:10) + (output 's1' 4.0:8) + ) + (cell 0 0 (top + (output 's1' 4.0:8) + (input 's2' 2:6) + (input 's3' 6:10) + )) + (cell 1 0 (matches 0.6:10 0000)) + (cell 2 0 (matches 0.6:10 0001)) + (cell 3 0 (priority_match 1 (cat 1.0 2.0))) + (cell 4 0 (assignment_list 8'd0 + (3.0 1:3 0.2:4) + (3.1 5:7 0.2:4) + )) + ) + """) + + def test_sliced_array(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + s4 = Signal(8) + s5 = Signal(8) + f = Fragment() + f.add_statements("comb", [ + Array([s1, s2, s3])[s4][2:7].eq(s5), + ]) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + f.add_driver(s3, "comb") + nl = build_netlist(f, ports=[s1, s2, s3, s4, s5]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's4' 0.2:10) + (input 's5' 0.10:18) + (output 's1' 5.0:8) + (output 's2' 6.0:8) + (output 's3' 7.0:8) + ) + (cell 0 0 (top + (output 's1' 5.0:8) + (output 's2' 6.0:8) + (output 's3' 7.0:8) + (input 's4' 2:10) + (input 's5' 10:18) + )) + (cell 1 0 (matches 0.2:10 00000000)) + (cell 2 0 (matches 0.2:10 00000001)) + (cell 3 0 (matches 0.2:10 00000010)) + (cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0))) + (cell 5 0 (assignment_list 8'd0 (4.0 2:7 0.10:15))) + (cell 6 0 (assignment_list 8'd0 (4.1 2:7 0.10:15))) + (cell 7 0 (assignment_list 8'd0 (4.2 2:7 0.10:15))) + ) + """) + + def test_part_slice(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1[1:7].bit_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:10) + (output 's1' 8.0:8) + ) + (cell 0 0 (top + (output 's1' 8.0:8) + (input 's2' 2:6) + (input 's3' 6:10) + )) + (cell 1 0 (matches 0.6:10 0000)) + (cell 2 0 (matches 0.6:10 0001)) + (cell 3 0 (matches 0.6:10 0010)) + (cell 4 0 (matches 0.6:10 0011)) + (cell 5 0 (matches 0.6:10 0100)) + (cell 6 0 (matches 0.6:10 0101)) + (cell 7 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0))) + (cell 8 0 (assignment_list 8'd0 + (7.0 1:5 0.2:6) + (7.1 2:6 0.2:6) + (7.2 3:7 0.2:6) + (7.3 4:7 0.2:5) + (7.4 5:7 0.2:4) + (7.5 6:7 0.2) + )) + ) + """) + + def test_sliced_part_slice(self): + s1 = Signal(12) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "comb", + s1[3:9].bit_select(s3, 4)[1:3].eq(s2) + ) + f.add_driver(s1, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's2' 0.2:6) + (input 's3' 0.6:10) + (output 's1' 8.0:12) + ) + (cell 0 0 (top + (output 's1' 8.0:12) + (input 's2' 2:6) + (input 's3' 6:10) + )) + (cell 1 0 (matches 0.6:10 0000)) + (cell 2 0 (matches 0.6:10 0001)) + (cell 3 0 (matches 0.6:10 0010)) + (cell 4 0 (matches 0.6:10 0011)) + (cell 5 0 (matches 0.6:10 0100)) + (cell 6 0 (matches 0.6:10 0101)) + (cell 7 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0))) + (cell 8 0 (assignment_list 12'd0 + (7.0 4:6 0.2:4) + (7.1 5:7 0.2:4) + (7.2 6:8 0.2:4) + (7.3 7:9 0.2:4) + (7.4 8:9 0.2) + )) + ) + """) + + def test_sliced_operator(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + f = Fragment() + f.add_statements("comb", [ + s1.as_signed()[2:7].eq(s3), + s2.as_unsigned()[2:7].eq(s3), + ]) + f.add_driver(s1, "comb") + f.add_driver(s2, "comb") + nl = build_netlist(f, ports=[s1, s2, s3]) + self.assertRepr(nl, """ + ( + (module 0 None ('top') + (input 's3' 0.2:10) + (output 's1' 1.0:8) + (output 's2' 2.0:8) + ) + (cell 0 0 (top + (output 's1' 1.0:8) + (output 's2' 2.0:8) + (input 's3' 2:10) + )) + (cell 1 0 (assignment_list 8'd0 (1 2:7 0.2:7))) + (cell 2 0 (assignment_list 8'd0 (1 2:7 0.2:7))) + ) + """) diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 0fc7fb7..b13fad2 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -1,7 +1,5 @@ # amaranth: UnusedElaboratable=no -import warnings - from amaranth.hdl._ast import * from amaranth.hdl._cd import * from amaranth.hdl._dsl import * @@ -420,329 +418,6 @@ class EnableInserterTestCase(FHDLTestCase): """) -class AssignmentLegalizerTestCase(FHDLTestCase): - def test_simple(self): - s1 = Signal(8) - s2 = Signal(8) - f = Fragment() - f.add_statements( - "sync", - s1.eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((eq (sig s1) (sig s2))) - """) - - def test_simple_slice(self): - s1 = Signal(8) - s2 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1[2:6].eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((eq (slice (sig s1) 2:6) (sig s2))) - """) - - def test_simple_part(self): - s1 = Signal(8) - s2 = Signal(4) - s3 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1.bit_select(s3, 4).eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s3) - (case 0000 (eq (slice (sig s1) 0:4) (sig s2))) - (case 0001 (eq (slice (sig s1) 1:5) (sig s2))) - (case 0010 (eq (slice (sig s1) 2:6) (sig s2))) - (case 0011 (eq (slice (sig s1) 3:7) (sig s2))) - (case 0100 (eq (slice (sig s1) 4:8) (sig s2))) - (case 0101 (eq (slice (sig s1) 5:8) (sig s2))) - (case 0110 (eq (slice (sig s1) 6:8) (sig s2))) - (case 0111 (eq (slice (sig s1) 7:8) (sig s2))) - )) - """) - - def test_simple_part_word(self): - s1 = Signal(8) - s2 = Signal(4) - s3 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1.word_select(s3, 4).eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s3) - (case 0000 (eq (slice (sig s1) 0:4) (sig s2))) - (case 0001 (eq (slice (sig s1) 4:8) (sig s2))) - )) - """) - - def test_simple_concat(self): - s1 = Signal(4) - s2 = Signal(4) - s3 = Signal(4) - s4 = Signal(12) - f = Fragment() - f.add_statements( - "sync", - Cat(s1, s2, s3).eq(s4) - ) - f.add_driver(s1, "sync") - f.add_driver(s2, "sync") - f.add_driver(s3, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ( - (eq (sig s1) (slice (sig s4) 0:12)) - (eq (sig s2) (slice (sig s4) 4:12)) - (eq (sig s3) (slice (sig s4) 8:12)) - ) - """) - - def test_simple_concat_narrow(self): - s1 = Signal(4) - s2 = Signal(4) - s3 = Signal(4) - s4 = Signal(signed(6)) - f = Fragment() - f.add_statements( - "sync", - Cat(s1, s2, s3).eq(s4) - ) - f.add_driver(s1, "sync") - f.add_driver(s2, "sync") - f.add_driver(s3, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ( - (eq (sig s1) (slice (| (sig s4) (const 12'sd0)) 0:12)) - (eq (sig s2) (slice (| (sig s4) (const 12'sd0)) 4:12)) - (eq (sig s3) (slice (| (sig s4) (const 12'sd0)) 8:12)) - ) - """) - - def test_simple_operator(self): - s1 = Signal(8) - s2 = Signal(8) - s3 = Signal(8) - f = Fragment() - f.add_statements("sync", [ - s1.as_signed().eq(s3), - s2.as_unsigned().eq(s3), - ]) - f.add_driver(s1, "sync") - f.add_driver(s2, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ( - (eq (sig s1) (sig s3)) - (eq (sig s2) (sig s3)) - ) - """) - - def test_simple_array(self): - s1 = Signal(8) - s2 = Signal(8) - s3 = Signal(8) - s4 = Signal(8) - s5 = Signal(8) - f = Fragment() - f.add_statements("sync", [ - Array([s1, s2, s3])[s4].eq(s5), - ]) - f.add_driver(s1, "sync") - f.add_driver(s2, "sync") - f.add_driver(s3, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s4) - (case 00000000 (eq (sig s1) (sig s5))) - (case 00000001 (eq (sig s2) (sig s5))) - (case 00000010 (eq (sig s3) (sig s5))) - )) - """) - - def test_sliced_slice(self): - s1 = Signal(12) - s2 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1[1:11][2:6].eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((eq (slice (sig s1) 3:7) (sig s2))) - """) - - def test_sliced_concat(self): - s1 = Signal(4) - s2 = Signal(4) - s3 = Signal(4) - s4 = Signal(4) - s5 = Signal(4) - s6 = Signal(8) - f = Fragment() - f.add_statements( - "sync", - Cat(s1, s2, s3, s4, s5)[5:14].eq(s6) - ) - f.add_driver(s1, "sync") - f.add_driver(s2, "sync") - f.add_driver(s3, "sync") - f.add_driver(s4, "sync") - f.add_driver(s5, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ( - (eq (slice (sig s2) 1:4) (slice (| (sig s6) (const 9'd0)) 0:9)) - (eq (sig s3) (slice (| (sig s6) (const 9'd0)) 3:9)) - (eq (slice (sig s4) 0:2) (slice (| (sig s6) (const 9'd0)) 7:9)) - ) - """) - - def test_sliced_part(self): - s1 = Signal(8) - s2 = Signal(4) - s3 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1.bit_select(s3, 4)[1:3].eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s3) - (case 0000 (eq (slice (sig s1) 1:3) (sig s2))) - (case 0001 (eq (slice (sig s1) 2:4) (sig s2))) - (case 0010 (eq (slice (sig s1) 3:5) (sig s2))) - (case 0011 (eq (slice (sig s1) 4:6) (sig s2))) - (case 0100 (eq (slice (sig s1) 5:7) (sig s2))) - (case 0101 (eq (slice (sig s1) 6:8) (sig s2))) - (case 0110 (eq (slice (sig s1) 7:8) (sig s2))) - )) - """) - - def test_sliced_part_word(self): - s1 = Signal(8) - s2 = Signal(4) - s3 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1.word_select(s3, 4)[1:3].eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s3) - (case 0000 (eq (slice (sig s1) 1:3) (sig s2))) - (case 0001 (eq (slice (sig s1) 5:7) (sig s2))) - )) - """) - - def test_sliced_array(self): - s1 = Signal(8) - s2 = Signal(8) - s3 = Signal(8) - s4 = Signal(8) - s5 = Signal(8) - f = Fragment() - f.add_statements("sync", [ - Array([s1, s2, s3])[s4][2:7].eq(s5), - ]) - f.add_driver(s1, "sync") - f.add_driver(s2, "sync") - f.add_driver(s3, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s4) - (case 00000000 (eq (slice (sig s1) 2:7) (sig s5))) - (case 00000001 (eq (slice (sig s2) 2:7) (sig s5))) - (case 00000010 (eq (slice (sig s3) 2:7) (sig s5))) - )) - """) - - def test_part_slice(self): - s1 = Signal(8) - s2 = Signal(4) - s3 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1[1:7].bit_select(s3, 4).eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s3) - (case 0000 (eq (slice (sig s1) 1:5) (sig s2))) - (case 0001 (eq (slice (sig s1) 2:6) (sig s2))) - (case 0010 (eq (slice (sig s1) 3:7) (sig s2))) - (case 0011 (eq (slice (sig s1) 4:7) (sig s2))) - (case 0100 (eq (slice (sig s1) 5:7) (sig s2))) - (case 0101 (eq (slice (sig s1) 6:7) (sig s2))) - )) - """) - - def test_sliced_part_slice(self): - s1 = Signal(12) - s2 = Signal(4) - s3 = Signal(4) - f = Fragment() - f.add_statements( - "sync", - s1[3:9].bit_select(s3, 4)[1:3].eq(s2) - ) - f.add_driver(s1, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ((switch (sig s3) - (case 0000 (eq (slice (sig s1) 4:6) (sig s2))) - (case 0001 (eq (slice (sig s1) 5:7) (sig s2))) - (case 0010 (eq (slice (sig s1) 6:8) (sig s2))) - (case 0011 (eq (slice (sig s1) 7:9) (sig s2))) - (case 0100 (eq (slice (sig s1) 8:9) (sig s2))) - )) - """) - - - def test_sliced_operator(self): - s1 = Signal(8) - s2 = Signal(8) - s3 = Signal(8) - f = Fragment() - f.add_statements("sync", [ - s1.as_signed()[2:7].eq(s3), - s2.as_unsigned()[2:7].eq(s3), - ]) - f.add_driver(s1, "sync") - f.add_driver(s2, "sync") - f = AssignmentLegalizer()(f) - self.assertRepr(f.statements["sync"], """ - ( - (eq (slice (sig s1) 2:7) (sig s3)) - (eq (slice (sig s2) 2:7) (sig s3)) - ) - """) - - class _MockElaboratable(Elaboratable): def __init__(self): self.s1 = Signal()