hdl._ast, hdl._ir: Deduplicate shape unification logic. NFC
This commit is contained in:
parent
31a12c03d1
commit
161b01450e
|
@ -159,6 +159,28 @@ class Shape:
|
||||||
return (isinstance(other, Shape) and
|
return (isinstance(other, Shape) and
|
||||||
self.width == other.width and self.signed == other.signed)
|
self.width == other.width and self.signed == other.signed)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unify(shapes):
|
||||||
|
"""Returns the minimal shape that contains all shapes from the list.
|
||||||
|
|
||||||
|
If no shapes passed in, returns unsigned(0).
|
||||||
|
"""
|
||||||
|
unsigned_width = signed_width = 0
|
||||||
|
has_signed = False
|
||||||
|
for shape in shapes:
|
||||||
|
assert isinstance(shape, Shape)
|
||||||
|
if shape.signed:
|
||||||
|
has_signed = True
|
||||||
|
signed_width = max(signed_width, shape.width)
|
||||||
|
else:
|
||||||
|
unsigned_width = max(unsigned_width, shape.width)
|
||||||
|
# If all shapes unsigned, simply take max.
|
||||||
|
if not has_signed:
|
||||||
|
return unsigned(unsigned_width)
|
||||||
|
# Otherwise, result is signed. All unsigned inputs, if any,
|
||||||
|
# need to be converted to signed by adding a zero bit.
|
||||||
|
return signed(max(signed_width, unsigned_width + 1))
|
||||||
|
|
||||||
|
|
||||||
def unsigned(width):
|
def unsigned(width):
|
||||||
"""Returns :py:`Shape(width, signed=False)`."""
|
"""Returns :py:`Shape(width, signed=False)`."""
|
||||||
|
@ -1524,20 +1546,6 @@ class Operator(Value):
|
||||||
return self._operands
|
return self._operands
|
||||||
|
|
||||||
def shape(self):
|
def shape(self):
|
||||||
def _bitwise_binary_shape(a_shape, b_shape):
|
|
||||||
if not a_shape.signed and not b_shape.signed:
|
|
||||||
# both operands unsigned
|
|
||||||
return unsigned(max(a_shape.width, b_shape.width))
|
|
||||||
elif a_shape.signed and b_shape.signed:
|
|
||||||
# both operands signed
|
|
||||||
return signed(max(a_shape.width, b_shape.width))
|
|
||||||
elif not a_shape.signed and b_shape.signed:
|
|
||||||
# first operand unsigned (add sign bit), second operand signed
|
|
||||||
return signed(max(a_shape.width + 1, b_shape.width))
|
|
||||||
else:
|
|
||||||
# first signed, second operand unsigned (add sign bit)
|
|
||||||
return signed(max(a_shape.width, b_shape.width + 1))
|
|
||||||
|
|
||||||
op_shapes = list(map(lambda x: x.shape(), self.operands))
|
op_shapes = list(map(lambda x: x.shape(), self.operands))
|
||||||
if len(op_shapes) == 1:
|
if len(op_shapes) == 1:
|
||||||
a_shape, = op_shapes
|
a_shape, = op_shapes
|
||||||
|
@ -1554,10 +1562,10 @@ class Operator(Value):
|
||||||
elif len(op_shapes) == 2:
|
elif len(op_shapes) == 2:
|
||||||
a_shape, b_shape = op_shapes
|
a_shape, b_shape = op_shapes
|
||||||
if self.operator == "+":
|
if self.operator == "+":
|
||||||
o_shape = _bitwise_binary_shape(*op_shapes)
|
o_shape = Shape._unify(op_shapes)
|
||||||
return Shape(o_shape.width + 1, o_shape.signed)
|
return Shape(o_shape.width + 1, o_shape.signed)
|
||||||
if self.operator == "-":
|
if self.operator == "-":
|
||||||
o_shape = _bitwise_binary_shape(*op_shapes)
|
o_shape = Shape._unify(op_shapes)
|
||||||
return Shape(o_shape.width + 1, True)
|
return Shape(o_shape.width + 1, True)
|
||||||
if self.operator == "*":
|
if self.operator == "*":
|
||||||
return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed)
|
return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed)
|
||||||
|
@ -1568,7 +1576,7 @@ class Operator(Value):
|
||||||
if self.operator in ("<", "<=", "==", "!=", ">", ">="):
|
if self.operator in ("<", "<=", "==", "!=", ">", ">="):
|
||||||
return Shape(1, False)
|
return Shape(1, False)
|
||||||
if self.operator in ("&", "|", "^"):
|
if self.operator in ("&", "|", "^"):
|
||||||
return _bitwise_binary_shape(*op_shapes)
|
return Shape._unify(op_shapes)
|
||||||
if self.operator == "<<":
|
if self.operator == "<<":
|
||||||
assert not b_shape.signed
|
assert not b_shape.signed
|
||||||
return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed)
|
return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed)
|
||||||
|
@ -1578,7 +1586,7 @@ class Operator(Value):
|
||||||
elif len(op_shapes) == 3:
|
elif len(op_shapes) == 3:
|
||||||
if self.operator == "m":
|
if self.operator == "m":
|
||||||
s_shape, a_shape, b_shape = op_shapes
|
s_shape, a_shape, b_shape = op_shapes
|
||||||
return _bitwise_binary_shape(a_shape, b_shape)
|
return Shape._unify((a_shape, b_shape))
|
||||||
raise NotImplementedError # :nocov:
|
raise NotImplementedError # :nocov:
|
||||||
|
|
||||||
def _lhs_signals(self):
|
def _lhs_signals(self):
|
||||||
|
@ -2254,27 +2262,9 @@ class ArrayProxy(Value):
|
||||||
return (Value.cast(elem) for elem in self.elems)
|
return (Value.cast(elem) for elem in self.elems)
|
||||||
|
|
||||||
def shape(self):
|
def shape(self):
|
||||||
unsigned_width = signed_width = 0
|
|
||||||
has_unsigned = has_signed = False
|
|
||||||
for elem_shape in (elem.shape() for elem in self._iter_as_values()):
|
|
||||||
if elem_shape.signed:
|
|
||||||
has_signed = True
|
|
||||||
signed_width = max(signed_width, elem_shape.width)
|
|
||||||
else:
|
|
||||||
has_unsigned = True
|
|
||||||
unsigned_width = max(unsigned_width, elem_shape.width)
|
|
||||||
# The shape of the proxy must be such that it preserves the mathematical value of the array
|
# The shape of the proxy must be such that it preserves the mathematical value of the array
|
||||||
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
|
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
|
||||||
# To ensure this holds, if the array contains both signed and unsigned values, make sure
|
return Shape._unify(elem.shape() for elem in self._iter_as_values())
|
||||||
# that every unsigned value is zero-extended by at least one bit.
|
|
||||||
if has_signed and has_unsigned and unsigned_width >= signed_width:
|
|
||||||
# Array contains both signed and unsigned values, and at least one of the unsigned
|
|
||||||
# values won't be zero-extended otherwise.
|
|
||||||
return signed(unsigned_width + 1)
|
|
||||||
else:
|
|
||||||
# Array contains values of the same signedness, or else all of the unsigned values
|
|
||||||
# are zero-extended.
|
|
||||||
return Shape(max(unsigned_width, signed_width), has_signed)
|
|
||||||
|
|
||||||
def _lhs_signals(self):
|
def _lhs_signals(self):
|
||||||
signals = union((elem._lhs_signals() for elem in self._iter_as_values()),
|
signals = union((elem._lhs_signals() for elem in self._iter_as_values()),
|
||||||
|
|
|
@ -677,16 +677,13 @@ class NetlistEmitter:
|
||||||
|
|
||||||
def unify_shapes_bitwise(self,
|
def unify_shapes_bitwise(self,
|
||||||
operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
|
operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
|
||||||
if signed_a == signed_b:
|
shape = _ast.Shape._unify((
|
||||||
width = max(len(operand_a), len(operand_b))
|
_ast.Shape(len(operand_a), signed_a),
|
||||||
elif signed_a:
|
_ast.Shape(len(operand_b), signed_b),
|
||||||
width = max(len(operand_a), len(operand_b) + 1)
|
))
|
||||||
else: # signed_b
|
operand_a = self.extend(operand_a, signed_a, shape.width)
|
||||||
width = max(len(operand_a) + 1, len(operand_b))
|
operand_b = self.extend(operand_b, signed_b, shape.width)
|
||||||
operand_a = self.extend(operand_a, signed_a, width)
|
return (operand_a, operand_b, shape.signed)
|
||||||
operand_b = self.extend(operand_b, signed_b, width)
|
|
||||||
signed = signed_a or signed_b
|
|
||||||
return (operand_a, operand_b, signed)
|
|
||||||
|
|
||||||
def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]:
|
def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]:
|
||||||
"""Emits a RHS value, returns a tuple of (value, is_signed)"""
|
"""Emits a RHS value, returns a tuple of (value, is_signed)"""
|
||||||
|
@ -825,19 +822,11 @@ class NetlistEmitter:
|
||||||
signed = False
|
signed = False
|
||||||
elif isinstance(value, _ast.ArrayProxy):
|
elif isinstance(value, _ast.ArrayProxy):
|
||||||
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
|
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
|
||||||
width = 0
|
shape = _ast.Shape._unify(
|
||||||
signed = False
|
_ast.Shape(len(value), signed)
|
||||||
for elem, elem_signed in elems:
|
for value, signed in elems
|
||||||
if elem_signed:
|
)
|
||||||
if not signed:
|
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
|
||||||
width += 1
|
|
||||||
signed = True
|
|
||||||
width = max(width, len(elem))
|
|
||||||
elif signed:
|
|
||||||
width = max(width, len(elem) + 1)
|
|
||||||
else:
|
|
||||||
width = max(width, len(elem))
|
|
||||||
elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems)
|
|
||||||
index, _signed = self.emit_rhs(module_idx, value.index)
|
index, _signed = self.emit_rhs(module_idx, value.index)
|
||||||
conds = []
|
conds = []
|
||||||
for case_index in range(len(elems)):
|
for case_index in range(len(elems)):
|
||||||
|
@ -855,7 +844,8 @@ class NetlistEmitter:
|
||||||
]
|
]
|
||||||
cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments,
|
cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments,
|
||||||
src_loc=value.src_loc)
|
src_loc=value.src_loc)
|
||||||
result = self.netlist.add_value_cell(width, cell)
|
result = self.netlist.add_value_cell(shape.width, cell)
|
||||||
|
signed = shape.signed
|
||||||
elif isinstance(value, _ast.Cat):
|
elif isinstance(value, _ast.Cat):
|
||||||
nets = []
|
nets = []
|
||||||
for val in value.parts:
|
for val in value.parts:
|
||||||
|
|
Loading…
Reference in a new issue