hdl.xfrm: add assignment legalizer.

Co-authored-by: Wanda <wanda@phinode.net>
This commit is contained in:
Catherine 2023-08-21 05:22:33 +00:00
parent 10117607a3
commit 78981232d9
2 changed files with 406 additions and 1 deletions

View file

@ -17,7 +17,7 @@ __all__ = ["ValueVisitor", "ValueTransformer",
"TransformedElaboratable",
"DomainCollector", "DomainRenamer", "DomainLowerer",
"SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
"ResetInserter", "EnableInserter"]
"ResetInserter", "EnableInserter", "AssignmentLegalizer"]
class ValueVisitor(metaclass=ABCMeta):
@ -670,3 +670,85 @@ class EnableInserter(_ControlInserter):
if port._domain in self.controls:
port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en)))
return new_fragment
class AssignmentLegalizer(FragmentTransformer, StatementTransformer):
"""Ensures all assignments in switches have one of the following on the LHS:
- a `Signal`
- a `Slice` with `value` that is a `Signal`
"""
def emit_assign(self, lhs, rhs, lhs_start=0, lhs_stop=None):
if isinstance(lhs, ArrayProxy):
# Lower into a switch.
cases = {}
for idx, val in enumerate(lhs.elems):
cases[idx] = self.emit_assign(val, rhs, lhs_start, lhs_stop)
return [Switch(lhs.index, cases)]
elif isinstance(lhs, Part):
offset = lhs.offset
width = lhs.width
if lhs_start != 0:
width -= lhs_start
if lhs_stop is not None:
width = lhs_stop - lhs_start
cases = {}
lhs_width = len(lhs.value)
for idx in range(lhs_width):
start = lhs_start + idx * lhs.stride
if start >= lhs_width:
break
stop = min(start + width, lhs_width)
cases[idx] = self.emit_assign(lhs.value, rhs, start, stop)
return [Switch(offset, cases)]
elif isinstance(lhs, Slice):
part_start = lhs_start + lhs.start
if lhs_stop is not None:
part_stop = lhs_stop + lhs.start
else:
part_stop = lhs_start + lhs.stop
return self.emit_assign(lhs.value, rhs, part_start, part_stop)
elif isinstance(lhs, Cat):
# Split into several assignments.
part_stop = 0
res = []
if lhs_stop is None:
lhs_len = len(lhs) - lhs_start
else:
lhs_len = lhs_stop - lhs_start
if len(rhs) < lhs_len:
rhs |= Const(0, Shape(lhs_len, signed=rhs.shape().signed))
for val in lhs.parts:
part_start = part_stop
part_len = len(val)
part_stop = part_start + part_len
if lhs_start >= part_stop:
continue
if lhs_start < part_start:
part_lhs_start = 0
part_rhs_start = part_start - lhs_start
else:
part_lhs_start = lhs_start - part_start
part_rhs_start = 0
if lhs_stop is not None and lhs_stop <= part_start:
continue
elif lhs_stop is None or lhs_stop >= part_stop:
part_lhs_stop = None
else:
part_lhs_stop = lhs_stop - part_start
res += self.emit_assign(val, rhs[part_rhs_start:], part_lhs_start, part_lhs_stop)
return res
elif isinstance(lhs, Signal):
# Already ok.
if lhs_start != 0 or lhs_stop is not None:
return [Assign(lhs[lhs_start:lhs_stop], rhs)]
else:
return [Assign(lhs, rhs)]
elif isinstance(lhs, Operator):
assert lhs.operator in ('u', 's')
return self.emit_assign(lhs.operands[0], rhs, lhs_start, lhs_stop)
else:
raise TypeError
def on_Assign(self, stmt):
return self.emit_assign(stmt.lhs, stmt.rhs)

View file

@ -547,6 +547,329 @@ class EnableInserterTestCase(FHDLTestCase):
""")
class AssignmentLegalizerTestCase(FHDLTestCase):
def test_simple(self):
s1 = Signal(8)
s2 = Signal(8)
f = Fragment()
f.add_statements(
"sync",
s1.eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((eq (sig s1) (sig s2)))
""")
def test_simple_slice(self):
s1 = Signal(8)
s2 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[2:6].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((eq (slice (sig s1) 2:6) (sig s2)))
""")
def test_simple_part(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.bit_select(s3, 4).eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 0:4) (sig s2)))
(case 0001 (eq (slice (sig s1) 1:5) (sig s2)))
(case 0010 (eq (slice (sig s1) 2:6) (sig s2)))
(case 0011 (eq (slice (sig s1) 3:7) (sig s2)))
(case 0100 (eq (slice (sig s1) 4:8) (sig s2)))
(case 0101 (eq (slice (sig s1) 5:8) (sig s2)))
(case 0110 (eq (slice (sig s1) 6:8) (sig s2)))
(case 0111 (eq (slice (sig s1) 7:8) (sig s2)))
))
""")
def test_simple_part_word(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.word_select(s3, 4).eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 0:4) (sig s2)))
(case 0001 (eq (slice (sig s1) 4:8) (sig s2)))
))
""")
def test_simple_concat(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(12)
f = Fragment()
f.add_statements(
"sync",
Cat(s1, s2, s3).eq(s4)
)
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (sig s1) (slice (sig s4) 0:12))
(eq (sig s2) (slice (sig s4) 4:12))
(eq (sig s3) (slice (sig s4) 8:12))
)
""")
def test_simple_concat_narrow(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(signed(6))
f = Fragment()
f.add_statements(
"sync",
Cat(s1, s2, s3).eq(s4)
)
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (sig s1) (slice (| (sig s4) (const 12'sd0)) 0:12))
(eq (sig s2) (slice (| (sig s4) (const 12'sd0)) 4:12))
(eq (sig s3) (slice (| (sig s4) (const 12'sd0)) 8:12))
)
""")
def test_simple_operator(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
f = Fragment()
f.add_statements("sync", [
s1.as_signed().eq(s3),
s2.as_unsigned().eq(s3),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (sig s1) (sig s3))
(eq (sig s2) (sig s3))
)
""")
def test_simple_array(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
s4 = Signal(8)
s5 = Signal(8)
f = Fragment()
f.add_statements("sync", [
Array([s1, s2, s3])[s4].eq(s5),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s4)
(case 00000000 (eq (sig s1) (sig s5)))
(case 00000001 (eq (sig s2) (sig s5)))
(case 00000010 (eq (sig s3) (sig s5)))
))
""")
def test_sliced_slice(self):
s1 = Signal(12)
s2 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[1:11][2:6].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((eq (slice (sig s1) 3:7) (sig s2)))
""")
def test_sliced_concat(self):
s1 = Signal(4)
s2 = Signal(4)
s3 = Signal(4)
s4 = Signal(4)
s5 = Signal(4)
s6 = Signal(8)
f = Fragment()
f.add_statements(
"sync",
Cat(s1, s2, s3, s4, s5)[5:14].eq(s6)
)
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f.add_driver(s4, "sync")
f.add_driver(s5, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (slice (sig s2) 1:4) (slice (| (sig s6) (const 9'd0)) 0:9))
(eq (sig s3) (slice (| (sig s6) (const 9'd0)) 3:9))
(eq (slice (sig s4) 0:2) (slice (| (sig s6) (const 9'd0)) 7:9))
)
""")
def test_sliced_part(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.bit_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 1:3) (sig s2)))
(case 0001 (eq (slice (sig s1) 2:4) (sig s2)))
(case 0010 (eq (slice (sig s1) 3:5) (sig s2)))
(case 0011 (eq (slice (sig s1) 4:6) (sig s2)))
(case 0100 (eq (slice (sig s1) 5:7) (sig s2)))
(case 0101 (eq (slice (sig s1) 6:8) (sig s2)))
(case 0110 (eq (slice (sig s1) 7:8) (sig s2)))
))
""")
def test_sliced_part_word(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1.word_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 1:3) (sig s2)))
(case 0001 (eq (slice (sig s1) 5:7) (sig s2)))
))
""")
def test_sliced_array(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
s4 = Signal(8)
s5 = Signal(8)
f = Fragment()
f.add_statements("sync", [
Array([s1, s2, s3])[s4][2:7].eq(s5),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f.add_driver(s3, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s4)
(case 00000000 (eq (slice (sig s1) 2:7) (sig s5)))
(case 00000001 (eq (slice (sig s2) 2:7) (sig s5)))
(case 00000010 (eq (slice (sig s3) 2:7) (sig s5)))
))
""")
def test_part_slice(self):
s1 = Signal(8)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[1:7].bit_select(s3, 4).eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 1:5) (sig s2)))
(case 0001 (eq (slice (sig s1) 2:6) (sig s2)))
(case 0010 (eq (slice (sig s1) 3:7) (sig s2)))
(case 0011 (eq (slice (sig s1) 4:7) (sig s2)))
(case 0100 (eq (slice (sig s1) 5:7) (sig s2)))
(case 0101 (eq (slice (sig s1) 6:7) (sig s2)))
))
""")
def test_sliced_part_slice(self):
s1 = Signal(12)
s2 = Signal(4)
s3 = Signal(4)
f = Fragment()
f.add_statements(
"sync",
s1[3:9].bit_select(s3, 4)[1:3].eq(s2)
)
f.add_driver(s1, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
((switch (sig s3)
(case 0000 (eq (slice (sig s1) 4:6) (sig s2)))
(case 0001 (eq (slice (sig s1) 5:7) (sig s2)))
(case 0010 (eq (slice (sig s1) 6:8) (sig s2)))
(case 0011 (eq (slice (sig s1) 7:9) (sig s2)))
(case 0100 (eq (slice (sig s1) 8:9) (sig s2)))
))
""")
def test_sliced_operator(self):
s1 = Signal(8)
s2 = Signal(8)
s3 = Signal(8)
f = Fragment()
f.add_statements("sync", [
s1.as_signed()[2:7].eq(s3),
s2.as_unsigned()[2:7].eq(s3),
])
f.add_driver(s1, "sync")
f.add_driver(s2, "sync")
f = AssignmentLegalizer()(f)
self.assertRepr(f.statements["sync"], """
(
(eq (slice (sig s1) 2:7) (sig s3))
(eq (slice (sig s2) 2:7) (sig s3))
)
""")
class _MockElaboratable(Elaboratable):
def __init__(self):
self.s1 = Signal()