diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index 2ba1fa3..a8cf574 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -17,7 +17,7 @@ __all__ = ["ValueVisitor", "ValueTransformer", "TransformedElaboratable", "DomainCollector", "DomainRenamer", "DomainLowerer", "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", - "ResetInserter", "EnableInserter"] + "ResetInserter", "EnableInserter", "AssignmentLegalizer"] class ValueVisitor(metaclass=ABCMeta): @@ -670,3 +670,85 @@ class EnableInserter(_ControlInserter): if port._domain in self.controls: port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en))) return new_fragment + + +class AssignmentLegalizer(FragmentTransformer, StatementTransformer): + """Ensures all assignments in switches have one of the following on the LHS: + + - a `Signal` + - a `Slice` with `value` that is a `Signal` + """ + def emit_assign(self, lhs, rhs, lhs_start=0, lhs_stop=None): + if isinstance(lhs, ArrayProxy): + # Lower into a switch. + cases = {} + for idx, val in enumerate(lhs.elems): + cases[idx] = self.emit_assign(val, rhs, lhs_start, lhs_stop) + return [Switch(lhs.index, cases)] + elif isinstance(lhs, Part): + offset = lhs.offset + width = lhs.width + if lhs_start != 0: + width -= lhs_start + if lhs_stop is not None: + width = lhs_stop - lhs_start + cases = {} + lhs_width = len(lhs.value) + for idx in range(lhs_width): + start = lhs_start + idx * lhs.stride + if start >= lhs_width: + break + stop = min(start + width, lhs_width) + cases[idx] = self.emit_assign(lhs.value, rhs, start, stop) + return [Switch(offset, cases)] + elif isinstance(lhs, Slice): + part_start = lhs_start + lhs.start + if lhs_stop is not None: + part_stop = lhs_stop + lhs.start + else: + part_stop = lhs_start + lhs.stop + return self.emit_assign(lhs.value, rhs, part_start, part_stop) + elif isinstance(lhs, Cat): + # Split into several assignments. + part_stop = 0 + res = [] + if lhs_stop is None: + lhs_len = len(lhs) - lhs_start + else: + lhs_len = lhs_stop - lhs_start + if len(rhs) < lhs_len: + rhs |= Const(0, Shape(lhs_len, signed=rhs.shape().signed)) + for val in lhs.parts: + part_start = part_stop + part_len = len(val) + part_stop = part_start + part_len + if lhs_start >= part_stop: + continue + if lhs_start < part_start: + part_lhs_start = 0 + part_rhs_start = part_start - lhs_start + else: + part_lhs_start = lhs_start - part_start + part_rhs_start = 0 + if lhs_stop is not None and lhs_stop <= part_start: + continue + elif lhs_stop is None or lhs_stop >= part_stop: + part_lhs_stop = None + else: + part_lhs_stop = lhs_stop - part_start + res += self.emit_assign(val, rhs[part_rhs_start:], part_lhs_start, part_lhs_stop) + return res + elif isinstance(lhs, Signal): + # Already ok. + if lhs_start != 0 or lhs_stop is not None: + return [Assign(lhs[lhs_start:lhs_stop], rhs)] + else: + return [Assign(lhs, rhs)] + elif isinstance(lhs, Operator): + assert lhs.operator in ('u', 's') + return self.emit_assign(lhs.operands[0], rhs, lhs_start, lhs_stop) + else: + raise TypeError + + def on_Assign(self, stmt): + return self.emit_assign(stmt.lhs, stmt.rhs) diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 7709ec3..94c50ec 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -547,6 +547,329 @@ class EnableInserterTestCase(FHDLTestCase): """) +class AssignmentLegalizerTestCase(FHDLTestCase): + def test_simple(self): + s1 = Signal(8) + s2 = Signal(8) + f = Fragment() + f.add_statements( + "sync", + s1.eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((eq (sig s1) (sig s2))) + """) + + def test_simple_slice(self): + s1 = Signal(8) + s2 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1[2:6].eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((eq (slice (sig s1) 2:6) (sig s2))) + """) + + def test_simple_part(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1.bit_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s3) + (case 0000 (eq (slice (sig s1) 0:4) (sig s2))) + (case 0001 (eq (slice (sig s1) 1:5) (sig s2))) + (case 0010 (eq (slice (sig s1) 2:6) (sig s2))) + (case 0011 (eq (slice (sig s1) 3:7) (sig s2))) + (case 0100 (eq (slice (sig s1) 4:8) (sig s2))) + (case 0101 (eq (slice (sig s1) 5:8) (sig s2))) + (case 0110 (eq (slice (sig s1) 6:8) (sig s2))) + (case 0111 (eq (slice (sig s1) 7:8) (sig s2))) + )) + """) + + def test_simple_part_word(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1.word_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s3) + (case 0000 (eq (slice (sig s1) 0:4) (sig s2))) + (case 0001 (eq (slice (sig s1) 4:8) (sig s2))) + )) + """) + + def test_simple_concat(self): + s1 = Signal(4) + s2 = Signal(4) + s3 = Signal(4) + s4 = Signal(12) + f = Fragment() + f.add_statements( + "sync", + Cat(s1, s2, s3).eq(s4) + ) + f.add_driver(s1, "sync") + f.add_driver(s2, "sync") + f.add_driver(s3, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ( + (eq (sig s1) (slice (sig s4) 0:12)) + (eq (sig s2) (slice (sig s4) 4:12)) + (eq (sig s3) (slice (sig s4) 8:12)) + ) + """) + + def test_simple_concat_narrow(self): + s1 = Signal(4) + s2 = Signal(4) + s3 = Signal(4) + s4 = Signal(signed(6)) + f = Fragment() + f.add_statements( + "sync", + Cat(s1, s2, s3).eq(s4) + ) + f.add_driver(s1, "sync") + f.add_driver(s2, "sync") + f.add_driver(s3, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ( + (eq (sig s1) (slice (| (sig s4) (const 12'sd0)) 0:12)) + (eq (sig s2) (slice (| (sig s4) (const 12'sd0)) 4:12)) + (eq (sig s3) (slice (| (sig s4) (const 12'sd0)) 8:12)) + ) + """) + + def test_simple_operator(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + f = Fragment() + f.add_statements("sync", [ + s1.as_signed().eq(s3), + s2.as_unsigned().eq(s3), + ]) + f.add_driver(s1, "sync") + f.add_driver(s2, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ( + (eq (sig s1) (sig s3)) + (eq (sig s2) (sig s3)) + ) + """) + + def test_simple_array(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + s4 = Signal(8) + s5 = Signal(8) + f = Fragment() + f.add_statements("sync", [ + Array([s1, s2, s3])[s4].eq(s5), + ]) + f.add_driver(s1, "sync") + f.add_driver(s2, "sync") + f.add_driver(s3, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s4) + (case 00000000 (eq (sig s1) (sig s5))) + (case 00000001 (eq (sig s2) (sig s5))) + (case 00000010 (eq (sig s3) (sig s5))) + )) + """) + + def test_sliced_slice(self): + s1 = Signal(12) + s2 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1[1:11][2:6].eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((eq (slice (sig s1) 3:7) (sig s2))) + """) + + def test_sliced_concat(self): + s1 = Signal(4) + s2 = Signal(4) + s3 = Signal(4) + s4 = Signal(4) + s5 = Signal(4) + s6 = Signal(8) + f = Fragment() + f.add_statements( + "sync", + Cat(s1, s2, s3, s4, s5)[5:14].eq(s6) + ) + f.add_driver(s1, "sync") + f.add_driver(s2, "sync") + f.add_driver(s3, "sync") + f.add_driver(s4, "sync") + f.add_driver(s5, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ( + (eq (slice (sig s2) 1:4) (slice (| (sig s6) (const 9'd0)) 0:9)) + (eq (sig s3) (slice (| (sig s6) (const 9'd0)) 3:9)) + (eq (slice (sig s4) 0:2) (slice (| (sig s6) (const 9'd0)) 7:9)) + ) + """) + + def test_sliced_part(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1.bit_select(s3, 4)[1:3].eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s3) + (case 0000 (eq (slice (sig s1) 1:3) (sig s2))) + (case 0001 (eq (slice (sig s1) 2:4) (sig s2))) + (case 0010 (eq (slice (sig s1) 3:5) (sig s2))) + (case 0011 (eq (slice (sig s1) 4:6) (sig s2))) + (case 0100 (eq (slice (sig s1) 5:7) (sig s2))) + (case 0101 (eq (slice (sig s1) 6:8) (sig s2))) + (case 0110 (eq (slice (sig s1) 7:8) (sig s2))) + )) + """) + + def test_sliced_part_word(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1.word_select(s3, 4)[1:3].eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s3) + (case 0000 (eq (slice (sig s1) 1:3) (sig s2))) + (case 0001 (eq (slice (sig s1) 5:7) (sig s2))) + )) + """) + + def test_sliced_array(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + s4 = Signal(8) + s5 = Signal(8) + f = Fragment() + f.add_statements("sync", [ + Array([s1, s2, s3])[s4][2:7].eq(s5), + ]) + f.add_driver(s1, "sync") + f.add_driver(s2, "sync") + f.add_driver(s3, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s4) + (case 00000000 (eq (slice (sig s1) 2:7) (sig s5))) + (case 00000001 (eq (slice (sig s2) 2:7) (sig s5))) + (case 00000010 (eq (slice (sig s3) 2:7) (sig s5))) + )) + """) + + def test_part_slice(self): + s1 = Signal(8) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1[1:7].bit_select(s3, 4).eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s3) + (case 0000 (eq (slice (sig s1) 1:5) (sig s2))) + (case 0001 (eq (slice (sig s1) 2:6) (sig s2))) + (case 0010 (eq (slice (sig s1) 3:7) (sig s2))) + (case 0011 (eq (slice (sig s1) 4:7) (sig s2))) + (case 0100 (eq (slice (sig s1) 5:7) (sig s2))) + (case 0101 (eq (slice (sig s1) 6:7) (sig s2))) + )) + """) + + def test_sliced_part_slice(self): + s1 = Signal(12) + s2 = Signal(4) + s3 = Signal(4) + f = Fragment() + f.add_statements( + "sync", + s1[3:9].bit_select(s3, 4)[1:3].eq(s2) + ) + f.add_driver(s1, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ((switch (sig s3) + (case 0000 (eq (slice (sig s1) 4:6) (sig s2))) + (case 0001 (eq (slice (sig s1) 5:7) (sig s2))) + (case 0010 (eq (slice (sig s1) 6:8) (sig s2))) + (case 0011 (eq (slice (sig s1) 7:9) (sig s2))) + (case 0100 (eq (slice (sig s1) 8:9) (sig s2))) + )) + """) + + + def test_sliced_operator(self): + s1 = Signal(8) + s2 = Signal(8) + s3 = Signal(8) + f = Fragment() + f.add_statements("sync", [ + s1.as_signed()[2:7].eq(s3), + s2.as_unsigned()[2:7].eq(s3), + ]) + f.add_driver(s1, "sync") + f.add_driver(s2, "sync") + f = AssignmentLegalizer()(f) + self.assertRepr(f.statements["sync"], """ + ( + (eq (slice (sig s1) 2:7) (sig s3)) + (eq (slice (sig s2) 2:7) (sig s3)) + ) + """) + + class _MockElaboratable(Elaboratable): def __init__(self): self.s1 = Signal()