hdl.xfrm: add assignment legalizer.
Co-authored-by: Wanda <wanda@phinode.net>
This commit is contained in:
		
							parent
							
								
									10117607a3
								
							
						
					
					
						commit
						78981232d9
					
				|  | @ -17,7 +17,7 @@ __all__ = ["ValueVisitor", "ValueTransformer", | ||||||
|            "TransformedElaboratable", |            "TransformedElaboratable", | ||||||
|            "DomainCollector", "DomainRenamer", "DomainLowerer", |            "DomainCollector", "DomainRenamer", "DomainLowerer", | ||||||
|            "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", |            "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", | ||||||
|            "ResetInserter", "EnableInserter"] |            "ResetInserter", "EnableInserter", "AssignmentLegalizer"] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ValueVisitor(metaclass=ABCMeta): | class ValueVisitor(metaclass=ABCMeta): | ||||||
|  | @ -670,3 +670,85 @@ class EnableInserter(_ControlInserter): | ||||||
|                 if port._domain in self.controls: |                 if port._domain in self.controls: | ||||||
|                     port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en))) |                     port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en))) | ||||||
|         return new_fragment |         return new_fragment | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class AssignmentLegalizer(FragmentTransformer, StatementTransformer): | ||||||
|  |     """Ensures all assignments in switches have one of the following on the LHS: | ||||||
|  | 
 | ||||||
|  |     - a `Signal` | ||||||
|  |     - a `Slice` with `value` that is a `Signal` | ||||||
|  |     """ | ||||||
|  |     def emit_assign(self, lhs, rhs, lhs_start=0, lhs_stop=None): | ||||||
|  |         if isinstance(lhs, ArrayProxy): | ||||||
|  |             # Lower into a switch. | ||||||
|  |             cases = {} | ||||||
|  |             for idx, val in enumerate(lhs.elems): | ||||||
|  |                 cases[idx] = self.emit_assign(val, rhs, lhs_start, lhs_stop) | ||||||
|  |             return [Switch(lhs.index, cases)] | ||||||
|  |         elif isinstance(lhs, Part): | ||||||
|  |             offset = lhs.offset | ||||||
|  |             width = lhs.width | ||||||
|  |             if lhs_start != 0: | ||||||
|  |                 width -= lhs_start | ||||||
|  |             if lhs_stop is not None: | ||||||
|  |                 width = lhs_stop - lhs_start | ||||||
|  |             cases = {} | ||||||
|  |             lhs_width = len(lhs.value) | ||||||
|  |             for idx in range(lhs_width): | ||||||
|  |                 start = lhs_start + idx * lhs.stride | ||||||
|  |                 if start >= lhs_width: | ||||||
|  |                     break | ||||||
|  |                 stop = min(start + width, lhs_width) | ||||||
|  |                 cases[idx] = self.emit_assign(lhs.value, rhs, start, stop) | ||||||
|  |             return [Switch(offset, cases)] | ||||||
|  |         elif isinstance(lhs, Slice): | ||||||
|  |             part_start = lhs_start + lhs.start | ||||||
|  |             if lhs_stop is not None: | ||||||
|  |                 part_stop = lhs_stop + lhs.start | ||||||
|  |             else: | ||||||
|  |                 part_stop = lhs_start + lhs.stop | ||||||
|  |             return self.emit_assign(lhs.value, rhs, part_start, part_stop) | ||||||
|  |         elif isinstance(lhs, Cat): | ||||||
|  |             # Split into several assignments. | ||||||
|  |             part_stop = 0 | ||||||
|  |             res = [] | ||||||
|  |             if lhs_stop is None: | ||||||
|  |                 lhs_len = len(lhs) - lhs_start | ||||||
|  |             else: | ||||||
|  |                 lhs_len = lhs_stop - lhs_start | ||||||
|  |             if len(rhs) < lhs_len: | ||||||
|  |                 rhs |= Const(0, Shape(lhs_len, signed=rhs.shape().signed)) | ||||||
|  |             for val in lhs.parts: | ||||||
|  |                 part_start = part_stop | ||||||
|  |                 part_len = len(val) | ||||||
|  |                 part_stop = part_start + part_len | ||||||
|  |                 if lhs_start >= part_stop: | ||||||
|  |                     continue | ||||||
|  |                 if lhs_start < part_start: | ||||||
|  |                     part_lhs_start = 0 | ||||||
|  |                     part_rhs_start = part_start - lhs_start | ||||||
|  |                 else: | ||||||
|  |                     part_lhs_start = lhs_start - part_start | ||||||
|  |                     part_rhs_start = 0 | ||||||
|  |                 if lhs_stop is not None and lhs_stop <= part_start: | ||||||
|  |                     continue | ||||||
|  |                 elif lhs_stop is None or lhs_stop >= part_stop: | ||||||
|  |                     part_lhs_stop = None | ||||||
|  |                 else: | ||||||
|  |                     part_lhs_stop = lhs_stop - part_start | ||||||
|  |                 res += self.emit_assign(val, rhs[part_rhs_start:], part_lhs_start, part_lhs_stop) | ||||||
|  |             return res | ||||||
|  |         elif isinstance(lhs, Signal): | ||||||
|  |             # Already ok. | ||||||
|  |             if lhs_start != 0 or lhs_stop is not None: | ||||||
|  |                 return [Assign(lhs[lhs_start:lhs_stop], rhs)] | ||||||
|  |             else: | ||||||
|  |                 return [Assign(lhs, rhs)] | ||||||
|  |         elif isinstance(lhs, Operator): | ||||||
|  |             assert lhs.operator in ('u', 's') | ||||||
|  |             return self.emit_assign(lhs.operands[0], rhs, lhs_start, lhs_stop) | ||||||
|  |         else: | ||||||
|  |             raise TypeError | ||||||
|  | 
 | ||||||
|  |     def on_Assign(self, stmt): | ||||||
|  |         return self.emit_assign(stmt.lhs, stmt.rhs) | ||||||
|  |  | ||||||
|  | @ -547,6 +547,329 @@ class EnableInserterTestCase(FHDLTestCase): | ||||||
|         """) |         """) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class AssignmentLegalizerTestCase(FHDLTestCase): | ||||||
|  |     def test_simple(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(8) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1.eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((eq (sig s1) (sig s2))) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_simple_slice(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1[2:6].eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((eq (slice (sig s1) 2:6) (sig s2))) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_simple_part(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1.bit_select(s3, 4).eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s3) | ||||||
|  |                 (case 0000 (eq (slice (sig s1) 0:4) (sig s2))) | ||||||
|  |                 (case 0001 (eq (slice (sig s1) 1:5) (sig s2))) | ||||||
|  |                 (case 0010 (eq (slice (sig s1) 2:6) (sig s2))) | ||||||
|  |                 (case 0011 (eq (slice (sig s1) 3:7) (sig s2))) | ||||||
|  |                 (case 0100 (eq (slice (sig s1) 4:8) (sig s2))) | ||||||
|  |                 (case 0101 (eq (slice (sig s1) 5:8) (sig s2))) | ||||||
|  |                 (case 0110 (eq (slice (sig s1) 6:8) (sig s2))) | ||||||
|  |                 (case 0111 (eq (slice (sig s1) 7:8) (sig s2))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_simple_part_word(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1.word_select(s3, 4).eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s3) | ||||||
|  |                 (case 0000 (eq (slice (sig s1) 0:4) (sig s2))) | ||||||
|  |                 (case 0001 (eq (slice (sig s1) 4:8) (sig s2))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_simple_concat(self): | ||||||
|  |         s1 = Signal(4) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         s4 = Signal(12) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             Cat(s1, s2, s3).eq(s4) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f.add_driver(s2, "sync") | ||||||
|  |         f.add_driver(s3, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ( | ||||||
|  |                 (eq (sig s1) (slice (sig s4) 0:12)) | ||||||
|  |                 (eq (sig s2) (slice (sig s4) 4:12)) | ||||||
|  |                 (eq (sig s3) (slice (sig s4) 8:12)) | ||||||
|  |             ) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_simple_concat_narrow(self): | ||||||
|  |         s1 = Signal(4) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         s4 = Signal(signed(6)) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             Cat(s1, s2, s3).eq(s4) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f.add_driver(s2, "sync") | ||||||
|  |         f.add_driver(s3, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ( | ||||||
|  |                 (eq (sig s1) (slice (| (sig s4) (const 12'sd0)) 0:12)) | ||||||
|  |                 (eq (sig s2) (slice (| (sig s4) (const 12'sd0)) 4:12)) | ||||||
|  |                 (eq (sig s3) (slice (| (sig s4) (const 12'sd0)) 8:12)) | ||||||
|  |             ) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_simple_operator(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(8) | ||||||
|  |         s3 = Signal(8) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements("sync", [ | ||||||
|  |             s1.as_signed().eq(s3), | ||||||
|  |             s2.as_unsigned().eq(s3), | ||||||
|  |         ]) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f.add_driver(s2, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ( | ||||||
|  |                 (eq (sig s1) (sig s3)) | ||||||
|  |                 (eq (sig s2) (sig s3)) | ||||||
|  |             ) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_simple_array(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(8) | ||||||
|  |         s3 = Signal(8) | ||||||
|  |         s4 = Signal(8) | ||||||
|  |         s5 = Signal(8) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements("sync", [ | ||||||
|  |             Array([s1, s2, s3])[s4].eq(s5), | ||||||
|  |         ]) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f.add_driver(s2, "sync") | ||||||
|  |         f.add_driver(s3, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s4) | ||||||
|  |                 (case 00000000 (eq (sig s1) (sig s5))) | ||||||
|  |                 (case 00000001 (eq (sig s2) (sig s5))) | ||||||
|  |                 (case 00000010 (eq (sig s3) (sig s5))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_sliced_slice(self): | ||||||
|  |         s1 = Signal(12) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1[1:11][2:6].eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((eq (slice (sig s1) 3:7) (sig s2))) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_sliced_concat(self): | ||||||
|  |         s1 = Signal(4) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         s4 = Signal(4) | ||||||
|  |         s5 = Signal(4) | ||||||
|  |         s6 = Signal(8) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             Cat(s1, s2, s3, s4, s5)[5:14].eq(s6) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f.add_driver(s2, "sync") | ||||||
|  |         f.add_driver(s3, "sync") | ||||||
|  |         f.add_driver(s4, "sync") | ||||||
|  |         f.add_driver(s5, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ( | ||||||
|  |                 (eq (slice (sig s2) 1:4) (slice (| (sig s6) (const 9'd0)) 0:9)) | ||||||
|  |                 (eq (sig s3)             (slice (| (sig s6) (const 9'd0)) 3:9)) | ||||||
|  |                 (eq (slice (sig s4) 0:2) (slice (| (sig s6) (const 9'd0)) 7:9)) | ||||||
|  |             ) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_sliced_part(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1.bit_select(s3, 4)[1:3].eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s3) | ||||||
|  |                 (case 0000 (eq (slice (sig s1) 1:3) (sig s2))) | ||||||
|  |                 (case 0001 (eq (slice (sig s1) 2:4) (sig s2))) | ||||||
|  |                 (case 0010 (eq (slice (sig s1) 3:5) (sig s2))) | ||||||
|  |                 (case 0011 (eq (slice (sig s1) 4:6) (sig s2))) | ||||||
|  |                 (case 0100 (eq (slice (sig s1) 5:7) (sig s2))) | ||||||
|  |                 (case 0101 (eq (slice (sig s1) 6:8) (sig s2))) | ||||||
|  |                 (case 0110 (eq (slice (sig s1) 7:8) (sig s2))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_sliced_part_word(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1.word_select(s3, 4)[1:3].eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s3) | ||||||
|  |                 (case 0000 (eq (slice (sig s1) 1:3) (sig s2))) | ||||||
|  |                 (case 0001 (eq (slice (sig s1) 5:7) (sig s2))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_sliced_array(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(8) | ||||||
|  |         s3 = Signal(8) | ||||||
|  |         s4 = Signal(8) | ||||||
|  |         s5 = Signal(8) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements("sync", [ | ||||||
|  |             Array([s1, s2, s3])[s4][2:7].eq(s5), | ||||||
|  |         ]) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f.add_driver(s2, "sync") | ||||||
|  |         f.add_driver(s3, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s4) | ||||||
|  |                 (case 00000000 (eq (slice (sig s1) 2:7) (sig s5))) | ||||||
|  |                 (case 00000001 (eq (slice (sig s2) 2:7) (sig s5))) | ||||||
|  |                 (case 00000010 (eq (slice (sig s3) 2:7) (sig s5))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_part_slice(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1[1:7].bit_select(s3, 4).eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s3) | ||||||
|  |                 (case 0000 (eq (slice (sig s1) 1:5) (sig s2))) | ||||||
|  |                 (case 0001 (eq (slice (sig s1) 2:6) (sig s2))) | ||||||
|  |                 (case 0010 (eq (slice (sig s1) 3:7) (sig s2))) | ||||||
|  |                 (case 0011 (eq (slice (sig s1) 4:7) (sig s2))) | ||||||
|  |                 (case 0100 (eq (slice (sig s1) 5:7) (sig s2))) | ||||||
|  |                 (case 0101 (eq (slice (sig s1) 6:7) (sig s2))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  |     def test_sliced_part_slice(self): | ||||||
|  |         s1 = Signal(12) | ||||||
|  |         s2 = Signal(4) | ||||||
|  |         s3 = Signal(4) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements( | ||||||
|  |             "sync", | ||||||
|  |             s1[3:9].bit_select(s3, 4)[1:3].eq(s2) | ||||||
|  |         ) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ((switch (sig s3) | ||||||
|  |                 (case 0000 (eq (slice (sig s1) 4:6) (sig s2))) | ||||||
|  |                 (case 0001 (eq (slice (sig s1) 5:7) (sig s2))) | ||||||
|  |                 (case 0010 (eq (slice (sig s1) 6:8) (sig s2))) | ||||||
|  |                 (case 0011 (eq (slice (sig s1) 7:9) (sig s2))) | ||||||
|  |                 (case 0100 (eq (slice (sig s1) 8:9) (sig s2))) | ||||||
|  |             )) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     def test_sliced_operator(self): | ||||||
|  |         s1 = Signal(8) | ||||||
|  |         s2 = Signal(8) | ||||||
|  |         s3 = Signal(8) | ||||||
|  |         f = Fragment() | ||||||
|  |         f.add_statements("sync", [ | ||||||
|  |             s1.as_signed()[2:7].eq(s3), | ||||||
|  |             s2.as_unsigned()[2:7].eq(s3), | ||||||
|  |         ]) | ||||||
|  |         f.add_driver(s1, "sync") | ||||||
|  |         f.add_driver(s2, "sync") | ||||||
|  |         f = AssignmentLegalizer()(f) | ||||||
|  |         self.assertRepr(f.statements["sync"], """ | ||||||
|  |             ( | ||||||
|  |                 (eq (slice (sig s1) 2:7) (sig s3)) | ||||||
|  |                 (eq (slice (sig s2) 2:7) (sig s3)) | ||||||
|  |             ) | ||||||
|  |         """) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class _MockElaboratable(Elaboratable): | class _MockElaboratable(Elaboratable): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.s1 = Signal() |         self.s1 = Signal() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue
	
	 Catherine
						Catherine