hdl.xfrm: add assignment legalizer.

Co-authored-by: Wanda <wanda@phinode.net>
This commit is contained in:
Catherine 2023-08-21 05:22:33 +00:00
parent 10117607a3
commit 78981232d9
2 changed files with 406 additions and 1 deletions

View file

@ -17,7 +17,7 @@ __all__ = ["ValueVisitor", "ValueTransformer",
"TransformedElaboratable",
"DomainCollector", "DomainRenamer", "DomainLowerer",
"SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
"ResetInserter", "EnableInserter"]
"ResetInserter", "EnableInserter", "AssignmentLegalizer"]
class ValueVisitor(metaclass=ABCMeta):
@ -670,3 +670,85 @@ class EnableInserter(_ControlInserter):
if port._domain in self.controls:
port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en)))
return new_fragment
class AssignmentLegalizer(FragmentTransformer, StatementTransformer):
"""Ensures all assignments in switches have one of the following on the LHS:
- a `Signal`
- a `Slice` with `value` that is a `Signal`
"""
def emit_assign(self, lhs, rhs, lhs_start=0, lhs_stop=None):
if isinstance(lhs, ArrayProxy):
# Lower into a switch.
cases = {}
for idx, val in enumerate(lhs.elems):
cases[idx] = self.emit_assign(val, rhs, lhs_start, lhs_stop)
return [Switch(lhs.index, cases)]
elif isinstance(lhs, Part):
offset = lhs.offset
width = lhs.width
if lhs_start != 0:
width -= lhs_start
if lhs_stop is not None:
width = lhs_stop - lhs_start
cases = {}
lhs_width = len(lhs.value)
for idx in range(lhs_width):
start = lhs_start + idx * lhs.stride
if start >= lhs_width:
break
stop = min(start + width, lhs_width)
cases[idx] = self.emit_assign(lhs.value, rhs, start, stop)
return [Switch(offset, cases)]
elif isinstance(lhs, Slice):
part_start = lhs_start + lhs.start
if lhs_stop is not None:
part_stop = lhs_stop + lhs.start
else:
part_stop = lhs_start + lhs.stop
return self.emit_assign(lhs.value, rhs, part_start, part_stop)
elif isinstance(lhs, Cat):
# Split into several assignments.
part_stop = 0
res = []
if lhs_stop is None:
lhs_len = len(lhs) - lhs_start
else:
lhs_len = lhs_stop - lhs_start
if len(rhs) < lhs_len:
rhs |= Const(0, Shape(lhs_len, signed=rhs.shape().signed))
for val in lhs.parts:
part_start = part_stop
part_len = len(val)
part_stop = part_start + part_len
if lhs_start >= part_stop:
continue
if lhs_start < part_start:
part_lhs_start = 0
part_rhs_start = part_start - lhs_start
else:
part_lhs_start = lhs_start - part_start
part_rhs_start = 0
if lhs_stop is not None and lhs_stop <= part_start:
continue
elif lhs_stop is None or lhs_stop >= part_stop:
part_lhs_stop = None
else:
part_lhs_stop = lhs_stop - part_start
res += self.emit_assign(val, rhs[part_rhs_start:], part_lhs_start, part_lhs_stop)
return res
elif isinstance(lhs, Signal):
# Already ok.
if lhs_start != 0 or lhs_stop is not None:
return [Assign(lhs[lhs_start:lhs_stop], rhs)]
else:
return [Assign(lhs, rhs)]
elif isinstance(lhs, Operator):
assert lhs.operator in ('u', 's')
return self.emit_assign(lhs.operands[0], rhs, lhs_start, lhs_stop)
else:
raise TypeError
def on_Assign(self, stmt):
return self.emit_assign(stmt.lhs, stmt.rhs)