hdl._ast, hdl._ir: Deduplicate shape unification logic. NFC
This commit is contained in:
parent
31a12c03d1
commit
161b01450e
|
@ -159,6 +159,28 @@ class Shape:
|
|||
return (isinstance(other, Shape) and
|
||||
self.width == other.width and self.signed == other.signed)
|
||||
|
||||
@staticmethod
|
||||
def _unify(shapes):
|
||||
"""Returns the minimal shape that contains all shapes from the list.
|
||||
|
||||
If no shapes passed in, returns unsigned(0).
|
||||
"""
|
||||
unsigned_width = signed_width = 0
|
||||
has_signed = False
|
||||
for shape in shapes:
|
||||
assert isinstance(shape, Shape)
|
||||
if shape.signed:
|
||||
has_signed = True
|
||||
signed_width = max(signed_width, shape.width)
|
||||
else:
|
||||
unsigned_width = max(unsigned_width, shape.width)
|
||||
# If all shapes unsigned, simply take max.
|
||||
if not has_signed:
|
||||
return unsigned(unsigned_width)
|
||||
# Otherwise, result is signed. All unsigned inputs, if any,
|
||||
# need to be converted to signed by adding a zero bit.
|
||||
return signed(max(signed_width, unsigned_width + 1))
|
||||
|
||||
|
||||
def unsigned(width):
|
||||
"""Returns :py:`Shape(width, signed=False)`."""
|
||||
|
@ -1524,20 +1546,6 @@ class Operator(Value):
|
|||
return self._operands
|
||||
|
||||
def shape(self):
|
||||
def _bitwise_binary_shape(a_shape, b_shape):
|
||||
if not a_shape.signed and not b_shape.signed:
|
||||
# both operands unsigned
|
||||
return unsigned(max(a_shape.width, b_shape.width))
|
||||
elif a_shape.signed and b_shape.signed:
|
||||
# both operands signed
|
||||
return signed(max(a_shape.width, b_shape.width))
|
||||
elif not a_shape.signed and b_shape.signed:
|
||||
# first operand unsigned (add sign bit), second operand signed
|
||||
return signed(max(a_shape.width + 1, b_shape.width))
|
||||
else:
|
||||
# first signed, second operand unsigned (add sign bit)
|
||||
return signed(max(a_shape.width, b_shape.width + 1))
|
||||
|
||||
op_shapes = list(map(lambda x: x.shape(), self.operands))
|
||||
if len(op_shapes) == 1:
|
||||
a_shape, = op_shapes
|
||||
|
@ -1554,10 +1562,10 @@ class Operator(Value):
|
|||
elif len(op_shapes) == 2:
|
||||
a_shape, b_shape = op_shapes
|
||||
if self.operator == "+":
|
||||
o_shape = _bitwise_binary_shape(*op_shapes)
|
||||
o_shape = Shape._unify(op_shapes)
|
||||
return Shape(o_shape.width + 1, o_shape.signed)
|
||||
if self.operator == "-":
|
||||
o_shape = _bitwise_binary_shape(*op_shapes)
|
||||
o_shape = Shape._unify(op_shapes)
|
||||
return Shape(o_shape.width + 1, True)
|
||||
if self.operator == "*":
|
||||
return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed)
|
||||
|
@ -1568,7 +1576,7 @@ class Operator(Value):
|
|||
if self.operator in ("<", "<=", "==", "!=", ">", ">="):
|
||||
return Shape(1, False)
|
||||
if self.operator in ("&", "|", "^"):
|
||||
return _bitwise_binary_shape(*op_shapes)
|
||||
return Shape._unify(op_shapes)
|
||||
if self.operator == "<<":
|
||||
assert not b_shape.signed
|
||||
return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed)
|
||||
|
@ -1578,7 +1586,7 @@ class Operator(Value):
|
|||
elif len(op_shapes) == 3:
|
||||
if self.operator == "m":
|
||||
s_shape, a_shape, b_shape = op_shapes
|
||||
return _bitwise_binary_shape(a_shape, b_shape)
|
||||
return Shape._unify((a_shape, b_shape))
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def _lhs_signals(self):
|
||||
|
@ -2254,27 +2262,9 @@ class ArrayProxy(Value):
|
|||
return (Value.cast(elem) for elem in self.elems)
|
||||
|
||||
def shape(self):
|
||||
unsigned_width = signed_width = 0
|
||||
has_unsigned = has_signed = False
|
||||
for elem_shape in (elem.shape() for elem in self._iter_as_values()):
|
||||
if elem_shape.signed:
|
||||
has_signed = True
|
||||
signed_width = max(signed_width, elem_shape.width)
|
||||
else:
|
||||
has_unsigned = True
|
||||
unsigned_width = max(unsigned_width, elem_shape.width)
|
||||
# The shape of the proxy must be such that it preserves the mathematical value of the array
|
||||
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
|
||||
# To ensure this holds, if the array contains both signed and unsigned values, make sure
|
||||
# that every unsigned value is zero-extended by at least one bit.
|
||||
if has_signed and has_unsigned and unsigned_width >= signed_width:
|
||||
# Array contains both signed and unsigned values, and at least one of the unsigned
|
||||
# values won't be zero-extended otherwise.
|
||||
return signed(unsigned_width + 1)
|
||||
else:
|
||||
# Array contains values of the same signedness, or else all of the unsigned values
|
||||
# are zero-extended.
|
||||
return Shape(max(unsigned_width, signed_width), has_signed)
|
||||
return Shape._unify(elem.shape() for elem in self._iter_as_values())
|
||||
|
||||
def _lhs_signals(self):
|
||||
signals = union((elem._lhs_signals() for elem in self._iter_as_values()),
|
||||
|
|
|
@ -677,16 +677,13 @@ class NetlistEmitter:
|
|||
|
||||
def unify_shapes_bitwise(self,
|
||||
operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
|
||||
if signed_a == signed_b:
|
||||
width = max(len(operand_a), len(operand_b))
|
||||
elif signed_a:
|
||||
width = max(len(operand_a), len(operand_b) + 1)
|
||||
else: # signed_b
|
||||
width = max(len(operand_a) + 1, len(operand_b))
|
||||
operand_a = self.extend(operand_a, signed_a, width)
|
||||
operand_b = self.extend(operand_b, signed_b, width)
|
||||
signed = signed_a or signed_b
|
||||
return (operand_a, operand_b, signed)
|
||||
shape = _ast.Shape._unify((
|
||||
_ast.Shape(len(operand_a), signed_a),
|
||||
_ast.Shape(len(operand_b), signed_b),
|
||||
))
|
||||
operand_a = self.extend(operand_a, signed_a, shape.width)
|
||||
operand_b = self.extend(operand_b, signed_b, shape.width)
|
||||
return (operand_a, operand_b, shape.signed)
|
||||
|
||||
def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]:
|
||||
"""Emits a RHS value, returns a tuple of (value, is_signed)"""
|
||||
|
@ -825,19 +822,11 @@ class NetlistEmitter:
|
|||
signed = False
|
||||
elif isinstance(value, _ast.ArrayProxy):
|
||||
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
|
||||
width = 0
|
||||
signed = False
|
||||
for elem, elem_signed in elems:
|
||||
if elem_signed:
|
||||
if not signed:
|
||||
width += 1
|
||||
signed = True
|
||||
width = max(width, len(elem))
|
||||
elif signed:
|
||||
width = max(width, len(elem) + 1)
|
||||
else:
|
||||
width = max(width, len(elem))
|
||||
elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems)
|
||||
shape = _ast.Shape._unify(
|
||||
_ast.Shape(len(value), signed)
|
||||
for value, signed in elems
|
||||
)
|
||||
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
|
||||
index, _signed = self.emit_rhs(module_idx, value.index)
|
||||
conds = []
|
||||
for case_index in range(len(elems)):
|
||||
|
@ -855,7 +844,8 @@ class NetlistEmitter:
|
|||
]
|
||||
cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments,
|
||||
src_loc=value.src_loc)
|
||||
result = self.netlist.add_value_cell(width, cell)
|
||||
result = self.netlist.add_value_cell(shape.width, cell)
|
||||
signed = shape.signed
|
||||
elif isinstance(value, _ast.Cat):
|
||||
nets = []
|
||||
for val in value.parts:
|
||||
|
|
Loading…
Reference in a new issue