hdl._ir: Inline AssignmentLegalizer into netlist building.

Fixes #1150.
This commit is contained in:
Wanda 2024-02-28 15:45:08 +01:00 committed by Catherine
parent 3a1f0a7c32
commit f8e2d26b8f
6 changed files with 804 additions and 452 deletions

View file

@ -10,7 +10,7 @@ from .. import __version__
from .._toolchain import *
from ..hdl import *
from ..hdl._ir import IOBufferInstance, Design
from ..hdl._xfrm import DomainLowerer, AssignmentLegalizer
from ..hdl._xfrm import DomainLowerer
from ..lib.cdc import ResetSynchronizer
from ..back import rtlil, verilog
from .res import *
@ -165,7 +165,6 @@ class Platform(ResourceManager, metaclass=ABCMeta):
add_pin_fragment(pin, self.get_diff_input_output(pin, port, attrs, invert))
ports = [(None, signal, None) for signal in self.iter_ports()]
fragment = AssignmentLegalizer()(fragment)
fragment = Design(fragment, ports, hierarchy=(name,))
return self.toolchain_prepare(fragment, name, **kwargs)

View file

@ -386,7 +386,7 @@ class Fragment:
return new_ports
def prepare(self, ports=(), *, hierarchy=("top",), legalize_assignments=False, missing_domain=lambda name: _cd.ClockDomain(name), propagate_domains=True):
def prepare(self, ports=(), *, hierarchy=("top",), missing_domain=lambda name: _cd.ClockDomain(name), propagate_domains=True):
from ._xfrm import DomainLowerer
ports = self._prepare_ports(ports)
@ -416,9 +416,6 @@ class Fragment:
]
fragment = DomainLowerer()(self)
if legalize_assignments:
from ._xfrm import AssignmentLegalizer
fragment = AssignmentLegalizer()(fragment)
# Create design and let it do the rest.
return Design(fragment, ports, hierarchy=hierarchy)
@ -698,10 +695,7 @@ class NetlistEmitter:
except KeyError:
pass
if isinstance(value, _ast.Const):
result = _nir.Value(
_nir.Net.from_const((value.value >> bit) & 1)
for bit in range(value.width)
)
result = _nir.Value.from_const(value.value, value.width)
signed = value.signed
elif isinstance(value, _ast.Signal):
result = self.emit_signal(value)
@ -887,6 +881,95 @@ class NetlistEmitter:
self.netlist.connections[left] = right
self.connect_src_loc[left] = src_loc
def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.Value, lhs_start: int, rhs: _nir.Value, cond: _nir.Net, *, src_loc):
# Assign rhs to lhs[lhs_start:lhs_start+len(rhs)]
if isinstance(lhs, _ast.Signal):
if lhs in self.drivers:
driver = self.drivers[lhs]
if driver.domain is not cd:
domain_name = cd.name if cd is not None else "comb"
other_domain_name = driver.domain.name if driver.domain is not None else "comb"
raise _ir.DriverConflict(
f"Signal {lhs} driven from domain {domain_name} at {src_loc} and domain "
f"{other_domain_name} at {driver.src_loc}")
if driver.module_idx != module_idx:
mod_name = ".".join(self.netlist.modules[module_idx].name or ("<toplevel>",))
other_mod_name = \
".".join(self.netlist.modules[driver.module_idx].name or ("<toplevel>",))
raise _ir.DriverConflict(
f"Signal {lhs} driven from module {mod_name} at {src_loc} and "
f"module {other_mod_name} at {driver.src_loc}")
else:
driver = NetlistDriver(module_idx, lhs, domain=cd, src_loc=src_loc)
self.drivers[lhs] = driver
driver.assignments.append(_nir.Assignment(cond=cond, start=lhs_start, value=rhs,
src_loc=src_loc))
elif isinstance(lhs, _ast.Slice):
self.emit_assign(module_idx, cd, lhs.value, lhs_start + lhs.start, rhs, cond, src_loc=src_loc)
elif isinstance(lhs, _ast.Cat):
part_stop = 0
for part in lhs.parts:
part_start = part_stop
part_len = len(part)
part_stop = part_start + part_len
if lhs_start >= part_stop:
continue
if lhs_start + len(rhs) <= part_start:
continue
if lhs_start < part_start:
part_lhs_start = 0
part_rhs_start = part_start - lhs_start
else:
part_lhs_start = lhs_start - part_start
part_rhs_start = 0
if lhs_start + len(rhs) >= part_stop:
part_rhs_stop = part_stop - lhs_start
else:
part_rhs_stop = len(rhs)
self.emit_assign(module_idx, cd, part, part_lhs_start, rhs[part_rhs_start:part_rhs_stop], cond, src_loc=src_loc)
elif isinstance(lhs, _ast.Part):
offset, _signed = self.emit_rhs(module_idx, lhs.offset)
width = len(lhs.value)
num_cases = min((width + lhs.stride - 1) // lhs.stride, 1 << len(offset))
conds = []
for case_index in range(num_cases):
cell = _nir.Matches(module_idx, value=offset,
patterns=(f"{case_index:0{len(offset)}b}",),
src_loc=lhs.src_loc)
subcond, = self.netlist.add_value_cell(1, cell)
conds.append(subcond)
conds = _nir.Value(conds)
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
for idx, subcond in enumerate(conds):
start = lhs_start + idx * lhs.stride
if start >= width:
continue
if start + len(rhs) >= width:
subrhs = rhs[:width - start]
else:
subrhs = rhs
self.emit_assign(module_idx, cd, lhs.value, start, subrhs, subcond, src_loc=src_loc)
elif isinstance(lhs, _ast.ArrayProxy):
index, _signed = self.emit_rhs(module_idx, lhs.index)
conds = []
for case_index in range(len(lhs.elems)):
cell = _nir.Matches(module_idx, value=index,
patterns=(f"{case_index:0{len(index)}b}",),
src_loc=lhs.src_loc)
subcond, = self.netlist.add_value_cell(1, cell)
conds.append(subcond)
conds = _nir.Value(conds)
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
for subcond, val in zip(conds, lhs.elems):
self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc)
elif isinstance(lhs, _ast.Operator):
assert lhs.operator in ('u', 's')
self.emit_assign(module_idx, cd, lhs.operands[0], lhs_start, rhs, cond, src_loc=src_loc)
else:
assert False # :nocov:
def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str,
stmt: _ast.Statement, cond: _nir.Net):
if domain == "comb":
@ -894,42 +977,13 @@ class NetlistEmitter:
else:
cd = fragment.domains[domain]
if isinstance(stmt, _ast.Assign):
if isinstance(stmt.lhs, _ast.Signal):
signal = stmt.lhs
start = 0
width = signal.width
elif isinstance(stmt.lhs, _ast.Slice):
signal = stmt.lhs.value
start = stmt.lhs.start
width = stmt.lhs.stop - stmt.lhs.start
else:
assert False # :nocov:
assert isinstance(signal, _ast.Signal)
if signal in self.drivers:
driver = self.drivers[signal]
if driver.domain is not cd:
domain_name = cd.name if cd is not None else "comb"
other_domain_name = driver.domain.name if driver.domain is not None else "comb"
raise _ir.DriverConflict(
f"Signal {signal} driven from domain {domain_name} at {stmt.src_loc} and domain "
f"{other_domain_name} at {driver.src_loc}")
if driver.module_idx != module_idx:
mod_name = ".".join(self.netlist.modules[module_idx].name or ("<toplevel>",))
other_mod_name = \
".".join(self.netlist.modules[driver.module_idx].name or ("<toplevel>",))
raise _ir.DriverConflict(
f"Signal {signal} driven from module {mod_name} at {stmt.src_loc} and "
f"module {other_mod_name} at {driver.src_loc}")
else:
driver = NetlistDriver(module_idx, signal, domain=cd, src_loc=stmt.src_loc)
self.drivers[signal] = driver
rhs, signed = self.emit_rhs(module_idx, stmt.rhs)
width = len(stmt.lhs)
if len(rhs) > width:
rhs = _nir.Value(rhs[:width])
if len(rhs) < width:
rhs = self.extend(rhs, signed, width)
driver.assignments.append(_nir.Assignment(cond=cond, start=start, value=rhs,
src_loc=stmt.src_loc))
self.emit_assign(module_idx, cd, stmt.lhs, 0, rhs, cond, src_loc=stmt.src_loc)
elif isinstance(stmt, _ast.Property):
test, _signed = self.emit_rhs(module_idx, stmt.test)
if len(test) != 1:
@ -1374,7 +1428,7 @@ def build_netlist(fragment, ports=(), *, name="top", **kwargs):
if isinstance(fragment, Design):
design = fragment
else:
design = fragment.prepare(ports=ports, hierarchy=(name,), legalize_assignments=True, **kwargs)
design = fragment.prepare(ports=ports, hierarchy=(name,), **kwargs)
netlist = _nir.Netlist()
_emit_netlist(netlist, design)
netlist.resolve_all_nets()

View file

@ -88,13 +88,23 @@ class Value(tuple):
return super().__new__(cls, (nets,))
return super().__new__(cls, (Net.ensure(net) for net in nets))
@classmethod
def from_const(cls, value, width):
return cls(Net.from_const((value >> bit) & 1) for bit in range(width))
@classmethod
def zeros(cls, digits=1):
return cls(Net.from_const(0) for _ in range(digits))
return cls.from_const(0, digits)
@classmethod
def ones(cls, digits=1):
return cls(Net.from_const(1) for _ in range(digits))
return cls.from_const(-1, digits)
def __getitem__(self, index):
if isinstance(index, slice):
return type(self)(super().__getitem__(index))
else:
return super().__getitem__(index)
def __repr__(self):
pos = 0

View file

@ -16,7 +16,7 @@ __all__ = ["ValueVisitor", "ValueTransformer",
"FragmentTransformer",
"TransformedElaboratable",
"DomainCollector", "DomainRenamer", "DomainLowerer",
"ResetInserter", "EnableInserter", "AssignmentLegalizer"]
"ResetInserter", "EnableInserter"]
class ValueVisitor(metaclass=ABCMeta):
@ -603,85 +603,3 @@ class EnableInserter(_ControlInserter):
if port._domain in self.controls:
port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en)))
return new_fragment
class AssignmentLegalizer(FragmentTransformer, StatementTransformer):
"""Ensures all assignments in switches have one of the following on the LHS:
- a `Signal`
- a `Slice` with `value` that is a `Signal`
"""
def emit_assign(self, lhs, rhs, lhs_start=0, lhs_stop=None):
if isinstance(lhs, ArrayProxy):
# Lower into a switch.
cases = {}
for idx, val in enumerate(lhs.elems):
cases[idx] = self.emit_assign(val, rhs, lhs_start, lhs_stop)
return [Switch(lhs.index, cases)]
elif isinstance(lhs, Part):
offset = lhs.offset
width = lhs.width
if lhs_start != 0:
width -= lhs_start
if lhs_stop is not None:
width = lhs_stop - lhs_start
cases = {}
lhs_width = len(lhs.value)
for idx in range(lhs_width):
start = lhs_start + idx * lhs.stride
if start >= lhs_width:
break
stop = min(start + width, lhs_width)
cases[idx] = self.emit_assign(lhs.value, rhs, start, stop)
return [Switch(offset, cases)]
elif isinstance(lhs, Slice):
part_start = lhs_start + lhs.start
if lhs_stop is not None:
part_stop = lhs_stop + lhs.start
else:
part_stop = lhs_start + lhs.stop
return self.emit_assign(lhs.value, rhs, part_start, part_stop)
elif isinstance(lhs, Cat):
# Split into several assignments.
part_stop = 0
res = []
if lhs_stop is None:
lhs_len = len(lhs) - lhs_start
else:
lhs_len = lhs_stop - lhs_start
if len(rhs) < lhs_len:
rhs |= Const(0, Shape(lhs_len, signed=rhs.shape().signed))
for val in lhs.parts:
part_start = part_stop
part_len = len(val)
part_stop = part_start + part_len
if lhs_start >= part_stop:
continue
if lhs_start < part_start:
part_lhs_start = 0
part_rhs_start = part_start - lhs_start
else:
part_lhs_start = lhs_start - part_start
part_rhs_start = 0
if lhs_stop is not None and lhs_stop <= part_start:
continue
elif lhs_stop is None or lhs_stop >= part_stop:
part_lhs_stop = None
else:
part_lhs_stop = lhs_stop - part_start
res += self.emit_assign(val, rhs[part_rhs_start:], part_lhs_start, part_lhs_stop)
return res
elif isinstance(lhs, Signal):
# Already ok.
if lhs_start != 0 or lhs_stop is not None:
return [Assign(lhs[lhs_start:lhs_stop], rhs)]
else:
return [Assign(lhs, rhs)]
elif isinstance(lhs, Operator):
assert lhs.operator in ('u', 's')
return self.emit_assign(lhs.operands[0], rhs, lhs_start, lhs_stop)
else:
raise TypeError
def on_Assign(self, stmt):
return self.emit_assign(stmt.lhs, stmt.rhs)

View file

@ -1059,3 +1059,699 @@ class IOBufferTestCase(FHDLTestCase):
with self.assertRaisesRegex(ValueError,
r"^`oe` must not be used if `o` is not used"):
IOBufferInstance(pad, oe=oe)
class AssignTestCase(FHDLTestCase):
def test_simple(self):
s1 = Signal(8)
s2 = Signal(8)
f = Fragment()
f.add_statements(
"comb",
s1.eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:10)
(output 's1' 0.2:10)
)
(cell 0 0 (top
(output 's1' 0.2:10)
(input 's2' 2:10)
))
)
""")
def test_simple_trunc(self):
s1 = Signal(8)
s2 = Signal(10)
f = Fragment()
f.add_statements(
"comb",
s1.eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:12)
(output 's1' 0.2:10)
)
(cell 0 0 (top
(output 's1' 0.2:10)
(input 's2' 2:12)
))
)
""")
def test_simple_zext(self):
s1 = Signal(8)
s2 = Signal(6)
f = Fragment()
f.add_statements(
"comb",
s1.eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:8)
(output 's1' (cat 0.2:8 2'd0))
)
(cell 0 0 (top
(output 's1' (cat 0.2:8 2'd0))
(input 's2' 2:8)
))
)
""")
def test_simple_sext(self):
s1 = Signal(8)
s2 = Signal(signed(6))
f = Fragment()
f.add_statements(
"comb",
s1.eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:8)
(output 's1' (cat 0.2:8 0.7 0.7))
)
(cell 0 0 (top
(output 's1' (cat 0.2:8 0.7 0.7))
(input 's2' 2:8)
))
)
""")
def test_simple_slice(self):
s1 = Signal(8)
s2 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1[2:6].eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(output 's1' 1.0:8)
)
(cell 0 0 (top
(output 's1' 1.0:8)
(input 's2' 2:6)
))
(cell 1 0 (assignment_list 8'd0 (1 2:6 0.2:6)))
)
""")
def test_simple_part(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1.bit_select(s3, 4).eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:10)
(output 's1' 10.0:8)
)
(cell 0 0 (top
(output 's1' 10.0:8)
(input 's2' 2:6)
(input 's3' 6:10)
))
(cell 1 0 (matches 0.6:10 0000))
(cell 2 0 (matches 0.6:10 0001))
(cell 3 0 (matches 0.6:10 0010))
(cell 4 0 (matches 0.6:10 0011))
(cell 5 0 (matches 0.6:10 0100))
(cell 6 0 (matches 0.6:10 0101))
(cell 7 0 (matches 0.6:10 0110))
(cell 8 0 (matches 0.6:10 0111))
(cell 9 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0)))
(cell 10 0 (assignment_list 8'd0
(9.0 0:4 0.2:6)
(9.1 1:5 0.2:6)
(9.2 2:6 0.2:6)
(9.3 3:7 0.2:6)
(9.4 4:8 0.2:6)
(9.5 5:8 0.2:5)
(9.6 6:8 0.2:4)
(9.7 7:8 0.2)
))
)
""")
def test_simple_part_short(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(2)
f = Fragment()
f.add_statements(
"comb",
s1.bit_select(s3, 4).eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:8)
(output 's1' 6.0:8)
)
(cell 0 0 (top
(output 's1' 6.0:8)
(input 's2' 2:6)
(input 's3' 6:8)
))
(cell 1 0 (matches 0.6:8 00))
(cell 2 0 (matches 0.6:8 01))
(cell 3 0 (matches 0.6:8 10))
(cell 4 0 (matches 0.6:8 11))
(cell 5 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0)))
(cell 6 0 (assignment_list 8'd0
(5.0 0:4 0.2:6)
(5.1 1:5 0.2:6)
(5.2 2:6 0.2:6)
(5.3 3:7 0.2:6)
))
)
""")
def test_simple_part_word(self):
s1 = Signal(16)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1.word_select(s3, 4).eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:10)
(output 's1' 6.0:16)
)
(cell 0 0 (top
(output 's1' 6.0:16)
(input 's2' 2:6)
(input 's3' 6:10)
))
(cell 1 0 (matches 0.6:10 0000))
(cell 2 0 (matches 0.6:10 0001))
(cell 3 0 (matches 0.6:10 0010))
(cell 4 0 (matches 0.6:10 0011))
(cell 5 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0)))
(cell 6 0 (assignment_list 16'd0
(5.0 0:4 0.2:6)
(5.1 4:8 0.2:6)
(5.2 8:12 0.2:6)
(5.3 12:16 0.2:6)
))
)
""")
def test_simple_part_word_misalign(self):
s1 = Signal(17)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1.word_select(s3, 4).eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:10)
(output 's1' 7.0:17)
)
(cell 0 0 (top
(output 's1' 7.0:17)
(input 's2' 2:6)
(input 's3' 6:10)
))
(cell 1 0 (matches 0.6:10 0000))
(cell 2 0 (matches 0.6:10 0001))
(cell 3 0 (matches 0.6:10 0010))
(cell 4 0 (matches 0.6:10 0011))
(cell 5 0 (matches 0.6:10 0100))
(cell 6 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0)))
(cell 7 0 (assignment_list 17'd0
(6.0 0:4 0.2:6)
(6.1 4:8 0.2:6)
(6.2 8:12 0.2:6)
(6.3 12:16 0.2:6)
(6.4 16:17 0.2)
))
)
""")
def test_simple_concat(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(12)
f = Fragment()
f.add_statements(
"comb",
Cat(s1, s2, s3).eq(s4)
)
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
f.add_driver(s3, "comb")
nl = build_netlist(f, ports=[s1, s2, s3, s4])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's4' 0.2:14)
(output 's1' 0.2:6)
(output 's2' 0.6:10)
(output 's3' 0.10:14)
)
(cell 0 0 (top
(output 's1' 0.2:6)
(output 's2' 0.6:10)
(output 's3' 0.10:14)
(input 's4' 2:14)
))
)
""")
def test_simple_concat_narrow(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(signed(6))
f = Fragment()
f.add_statements(
"comb",
Cat(s1, s2, s3).eq(s4)
)
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
f.add_driver(s3, "comb")
nl = build_netlist(f, ports=[s1, s2, s3, s4])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's4' 0.2:8)
(output 's1' 0.2:6)
(output 's2' (cat 0.6:8 0.7 0.7))
(output 's3' (cat 0.7 0.7 0.7 0.7))
)
(cell 0 0 (top
(output 's1' 0.2:6)
(output 's2' (cat 0.6:8 0.7 0.7))
(output 's3' (cat 0.7 0.7 0.7 0.7))
(input 's4' 2:8)
))
)
""")
def test_simple_operator(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
f = Fragment()
f.add_statements("comb", [
s1.as_signed().eq(s3),
s2.as_unsigned().eq(s3),
])
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's3' 0.2:10)
(output 's1' 0.2:10)
(output 's2' 0.2:10)
)
(cell 0 0 (top
(output 's1' 0.2:10)
(output 's2' 0.2:10)
(input 's3' 2:10)
))
)
""")
def test_simple_array(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
s4 = Signal(8)
s5 = Signal(8)
f = Fragment()
f.add_statements("comb", [
Array([s1, s2, s3])[s4].eq(s5),
])
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
f.add_driver(s3, "comb")
nl = build_netlist(f, ports=[s1, s2, s3, s4, s5])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's4' 0.2:10)
(input 's5' 0.10:18)
(output 's1' 5.0:8)
(output 's2' 6.0:8)
(output 's3' 7.0:8)
)
(cell 0 0 (top
(output 's1' 5.0:8)
(output 's2' 6.0:8)
(output 's3' 7.0:8)
(input 's4' 2:10)
(input 's5' 10:18)
))
(cell 1 0 (matches 0.2:10 00000000))
(cell 2 0 (matches 0.2:10 00000001))
(cell 3 0 (matches 0.2:10 00000010))
(cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0)))
(cell 5 0 (assignment_list 8'd0 (4.0 0:8 0.10:18)))
(cell 6 0 (assignment_list 8'd0 (4.1 0:8 0.10:18)))
(cell 7 0 (assignment_list 8'd0 (4.2 0:8 0.10:18)))
)
""")
def test_sliced_slice(self):
s1 = Signal(12)
s2 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1[1:11][2:6].eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(output 's1' 1.0:12)
)
(cell 0 0 (top
(output 's1' 1.0:12)
(input 's2' 2:6)
))
(cell 1 0 (assignment_list 12'd0 (1 3:7 0.2:6)))
)
""")
def test_sliced_concat(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(4)
s5 = Signal(4)
s6 = Signal(8)
f = Fragment()
f.add_statements(
"comb",
Cat(s1, s2, s3, s4, s5)[5:14].eq(s6)
)
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
f.add_driver(s3, "comb")
f.add_driver(s4, "comb")
f.add_driver(s5, "comb")
nl = build_netlist(f, ports=[s1, s2, s3, s4, s5, s6])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's1' 0.2:6)
(input 's5' 0.6:10)
(input 's6' 0.10:18)
(output 's2' 1.0:4)
(output 's3' 0.13:17)
(output 's4' 2.0:4)
)
(cell 0 0 (top
(output 's2' 1.0:4)
(output 's3' 0.13:17)
(output 's4' 2.0:4)
(input 's1' 2:6)
(input 's5' 6:10)
(input 's6' 10:18)
))
(cell 1 0 (assignment_list 4'd0 (1 1:4 0.10:13)))
(cell 2 0 (assignment_list 4'd0 (1 0:2 (cat 0.17 1'd0))))
)
""")
def test_sliced_part(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1.bit_select(s3, 6)[2:4].eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:10)
(output 's1' 10.0:8)
)
(cell 0 0 (top
(output 's1' 10.0:8)
(input 's2' 2:6)
(input 's3' 6:10)
))
(cell 1 0 (matches 0.6:10 0000))
(cell 2 0 (matches 0.6:10 0001))
(cell 3 0 (matches 0.6:10 0010))
(cell 4 0 (matches 0.6:10 0011))
(cell 5 0 (matches 0.6:10 0100))
(cell 6 0 (matches 0.6:10 0101))
(cell 7 0 (matches 0.6:10 0110))
(cell 8 0 (matches 0.6:10 0111))
(cell 9 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0)))
(cell 10 0 (assignment_list 8'd0
(9.0 2:4 0.2:4)
(9.1 3:5 0.2:4)
(9.2 4:6 0.2:4)
(9.3 5:7 0.2:4)
(9.4 6:8 0.2:4)
(9.5 7:8 0.2)
))
)
""")
def test_sliced_part_word(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1.word_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:10)
(output 's1' 4.0:8)
)
(cell 0 0 (top
(output 's1' 4.0:8)
(input 's2' 2:6)
(input 's3' 6:10)
))
(cell 1 0 (matches 0.6:10 0000))
(cell 2 0 (matches 0.6:10 0001))
(cell 3 0 (priority_match 1 (cat 1.0 2.0)))
(cell 4 0 (assignment_list 8'd0
(3.0 1:3 0.2:4)
(3.1 5:7 0.2:4)
))
)
""")
def test_sliced_array(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
s4 = Signal(8)
s5 = Signal(8)
f = Fragment()
f.add_statements("comb", [
Array([s1, s2, s3])[s4][2:7].eq(s5),
])
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
f.add_driver(s3, "comb")
nl = build_netlist(f, ports=[s1, s2, s3, s4, s5])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's4' 0.2:10)
(input 's5' 0.10:18)
(output 's1' 5.0:8)
(output 's2' 6.0:8)
(output 's3' 7.0:8)
)
(cell 0 0 (top
(output 's1' 5.0:8)
(output 's2' 6.0:8)
(output 's3' 7.0:8)
(input 's4' 2:10)
(input 's5' 10:18)
))
(cell 1 0 (matches 0.2:10 00000000))
(cell 2 0 (matches 0.2:10 00000001))
(cell 3 0 (matches 0.2:10 00000010))
(cell 4 0 (priority_match 1 (cat 1.0 2.0 3.0)))
(cell 5 0 (assignment_list 8'd0 (4.0 2:7 0.10:15)))
(cell 6 0 (assignment_list 8'd0 (4.1 2:7 0.10:15)))
(cell 7 0 (assignment_list 8'd0 (4.2 2:7 0.10:15)))
)
""")
def test_part_slice(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1[1:7].bit_select(s3, 4).eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:10)
(output 's1' 8.0:8)
)
(cell 0 0 (top
(output 's1' 8.0:8)
(input 's2' 2:6)
(input 's3' 6:10)
))
(cell 1 0 (matches 0.6:10 0000))
(cell 2 0 (matches 0.6:10 0001))
(cell 3 0 (matches 0.6:10 0010))
(cell 4 0 (matches 0.6:10 0011))
(cell 5 0 (matches 0.6:10 0100))
(cell 6 0 (matches 0.6:10 0101))
(cell 7 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0)))
(cell 8 0 (assignment_list 8'd0
(7.0 1:5 0.2:6)
(7.1 2:6 0.2:6)
(7.2 3:7 0.2:6)
(7.3 4:7 0.2:5)
(7.4 5:7 0.2:4)
(7.5 6:7 0.2)
))
)
""")
def test_sliced_part_slice(self):
s1 = Signal(12)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"comb",
s1[3:9].bit_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's2' 0.2:6)
(input 's3' 0.6:10)
(output 's1' 8.0:12)
)
(cell 0 0 (top
(output 's1' 8.0:12)
(input 's2' 2:6)
(input 's3' 6:10)
))
(cell 1 0 (matches 0.6:10 0000))
(cell 2 0 (matches 0.6:10 0001))
(cell 3 0 (matches 0.6:10 0010))
(cell 4 0 (matches 0.6:10 0011))
(cell 5 0 (matches 0.6:10 0100))
(cell 6 0 (matches 0.6:10 0101))
(cell 7 0 (priority_match 1 (cat 1.0 2.0 3.0 4.0 5.0 6.0)))
(cell 8 0 (assignment_list 12'd0
(7.0 4:6 0.2:4)
(7.1 5:7 0.2:4)
(7.2 6:8 0.2:4)
(7.3 7:9 0.2:4)
(7.4 8:9 0.2)
))
)
""")
def test_sliced_operator(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
f = Fragment()
f.add_statements("comb", [
s1.as_signed()[2:7].eq(s3),
s2.as_unsigned()[2:7].eq(s3),
])
f.add_driver(s1, "comb")
f.add_driver(s2, "comb")
nl = build_netlist(f, ports=[s1, s2, s3])
self.assertRepr(nl, """
(
(module 0 None ('top')
(input 's3' 0.2:10)
(output 's1' 1.0:8)
(output 's2' 2.0:8)
)
(cell 0 0 (top
(output 's1' 1.0:8)
(output 's2' 2.0:8)
(input 's3' 2:10)
))
(cell 1 0 (assignment_list 8'd0 (1 2:7 0.2:7)))
(cell 2 0 (assignment_list 8'd0 (1 2:7 0.2:7)))
)
""")

View file

@ -1,7 +1,5 @@
# amaranth: UnusedElaboratable=no
import warnings
from amaranth.hdl._ast import *
from amaranth.hdl._cd import *
from amaranth.hdl._dsl import *
@ -420,329 +418,6 @@ class EnableInserterTestCase(FHDLTestCase):
""")
class AssignmentLegalizerTestCase(FHDLTestCase):
def test_simple(self):
s1 = Signal(8)
s2 = Signal(8)
f = Fragment()
f.add_statements(
"sync",
s1.eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((eq (sig s1) (sig s2)))
""")
def test_simple_slice(self):
s1 = Signal(8)
s2 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[2:6].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((eq (slice (sig s1) 2:6) (sig s2)))
""")
def test_simple_part(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.bit_select(s3, 4).eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 0:4) (sig s2)))
(case 0001 (eq (slice (sig s1) 1:5) (sig s2)))
(case 0010 (eq (slice (sig s1) 2:6) (sig s2)))
(case 0011 (eq (slice (sig s1) 3:7) (sig s2)))
(case 0100 (eq (slice (sig s1) 4:8) (sig s2)))
(case 0101 (eq (slice (sig s1) 5:8) (sig s2)))
(case 0110 (eq (slice (sig s1) 6:8) (sig s2)))
(case 0111 (eq (slice (sig s1) 7:8) (sig s2)))
))
""")
def test_simple_part_word(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.word_select(s3, 4).eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 0:4) (sig s2)))
(case 0001 (eq (slice (sig s1) 4:8) (sig s2)))
))
""")
def test_simple_concat(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(12)
f = Fragment()
f.add_statements(
"sync",
Cat(s1, s2, s3).eq(s4)
)
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (sig s1) (slice (sig s4) 0:12))
(eq (sig s2) (slice (sig s4) 4:12))
(eq (sig s3) (slice (sig s4) 8:12))
)
""")
def test_simple_concat_narrow(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(signed(6))
f = Fragment()
f.add_statements(
"sync",
Cat(s1, s2, s3).eq(s4)
)
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (sig s1) (slice (| (sig s4) (const 12'sd0)) 0:12))
(eq (sig s2) (slice (| (sig s4) (const 12'sd0)) 4:12))
(eq (sig s3) (slice (| (sig s4) (const 12'sd0)) 8:12))
)
""")
def test_simple_operator(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
f = Fragment()
f.add_statements("sync", [
s1.as_signed().eq(s3),
s2.as_unsigned().eq(s3),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (sig s1) (sig s3))
(eq (sig s2) (sig s3))
)
""")
def test_simple_array(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
s4 = Signal(8)
s5 = Signal(8)
f = Fragment()
f.add_statements("sync", [
Array([s1, s2, s3])[s4].eq(s5),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s4)
(case 00000000 (eq (sig s1) (sig s5)))
(case 00000001 (eq (sig s2) (sig s5)))
(case 00000010 (eq (sig s3) (sig s5)))
))
""")
def test_sliced_slice(self):
s1 = Signal(12)
s2 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[1:11][2:6].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((eq (slice (sig s1) 3:7) (sig s2)))
""")
def test_sliced_concat(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(4)
s5 = Signal(4)
s6 = Signal(8)
f = Fragment()
f.add_statements(
"sync",
Cat(s1, s2, s3, s4, s5)[5:14].eq(s6)
)
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f.add_driver(s4, "sync")
f.add_driver(s5, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (slice (sig s2) 1:4) (slice (| (sig s6) (const 9'd0)) 0:9))
(eq (sig s3) (slice (| (sig s6) (const 9'd0)) 3:9))
(eq (slice (sig s4) 0:2) (slice (| (sig s6) (const 9'd0)) 7:9))
)
""")
def test_sliced_part(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.bit_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 1:3) (sig s2)))
(case 0001 (eq (slice (sig s1) 2:4) (sig s2)))
(case 0010 (eq (slice (sig s1) 3:5) (sig s2)))
(case 0011 (eq (slice (sig s1) 4:6) (sig s2)))
(case 0100 (eq (slice (sig s1) 5:7) (sig s2)))
(case 0101 (eq (slice (sig s1) 6:8) (sig s2)))
(case 0110 (eq (slice (sig s1) 7:8) (sig s2)))
))
""")
def test_sliced_part_word(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.word_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 1:3) (sig s2)))
(case 0001 (eq (slice (sig s1) 5:7) (sig s2)))
))
""")
def test_sliced_array(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
s4 = Signal(8)
s5 = Signal(8)
f = Fragment()
f.add_statements("sync", [
Array([s1, s2, s3])[s4][2:7].eq(s5),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s4)
(case 00000000 (eq (slice (sig s1) 2:7) (sig s5)))
(case 00000001 (eq (slice (sig s2) 2:7) (sig s5)))
(case 00000010 (eq (slice (sig s3) 2:7) (sig s5)))
))
""")
def test_part_slice(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[1:7].bit_select(s3, 4).eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 1:5) (sig s2)))
(case 0001 (eq (slice (sig s1) 2:6) (sig s2)))
(case 0010 (eq (slice (sig s1) 3:7) (sig s2)))
(case 0011 (eq (slice (sig s1) 4:7) (sig s2)))
(case 0100 (eq (slice (sig s1) 5:7) (sig s2)))
(case 0101 (eq (slice (sig s1) 6:7) (sig s2)))
))
""")
def test_sliced_part_slice(self):
s1 = Signal(12)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[3:9].bit_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 4:6) (sig s2)))
(case 0001 (eq (slice (sig s1) 5:7) (sig s2)))
(case 0010 (eq (slice (sig s1) 6:8) (sig s2)))
(case 0011 (eq (slice (sig s1) 7:9) (sig s2)))
(case 0100 (eq (slice (sig s1) 8:9) (sig s2)))
))
""")
def test_sliced_operator(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
f = Fragment()
f.add_statements("sync", [
s1.as_signed()[2:7].eq(s3),
s2.as_unsigned()[2:7].eq(s3),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (slice (sig s1) 2:7) (sig s3))
(eq (slice (sig s2) 2:7) (sig s3))
)
""")
class _MockElaboratable(Elaboratable):
def __init__(self):
self.s1 = Signal()