fhdl.ast: refactor Operator.shape(). NFC.
This commit is contained in:
parent
3a8685c352
commit
46f5addf05
|
@ -259,61 +259,71 @@ class Operator(Value):
|
||||||
self.operands = [Value.wrap(o) for o in operands]
|
self.operands = [Value.wrap(o) for o in operands]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _bitwise_binary_shape(a, b):
|
def _bitwise_binary_shape(a_shape, b_shape):
|
||||||
if not a[1] and not b[1]:
|
a_bits, a_sign = a_shape
|
||||||
|
b_bits, b_sign = b_shape
|
||||||
|
if not a_sign and not b_sign:
|
||||||
# both operands unsigned
|
# both operands unsigned
|
||||||
return max(a[0], b[0]), False
|
return max(a_bits, b_bits), False
|
||||||
elif a[1] and b[1]:
|
elif a_sign and b_sign:
|
||||||
# both operands signed
|
# both operands signed
|
||||||
return max(a[0], b[0]), True
|
return max(a_bits, b_bits), True
|
||||||
elif not a[1] and b[1]:
|
elif not a_sign and b_sign:
|
||||||
# first operand unsigned (add sign bit), second operand signed
|
# first operand unsigned (add sign bit), second operand signed
|
||||||
return max(a[0] + 1, b[0]), True
|
return max(a_bits + 1, b_bits), True
|
||||||
else:
|
else:
|
||||||
# first signed, second operand unsigned (add sign bit)
|
# first signed, second operand unsigned (add sign bit)
|
||||||
return max(a[0], b[0] + 1), True
|
return max(a_bits, b_bits + 1), True
|
||||||
|
|
||||||
def shape(self):
|
def shape(self):
|
||||||
obs = list(map(lambda x: x.shape(), self.operands))
|
op_shapes = list(map(lambda x: x.shape(), self.operands))
|
||||||
|
if len(op_shapes) == 1:
|
||||||
|
(a_bits, a_sign), = op_shapes
|
||||||
|
if self.op in ("+", "~"):
|
||||||
|
return a_bits, a_sign
|
||||||
|
if self.op == "-":
|
||||||
|
if not a_sign:
|
||||||
|
return a_bits + 1, True
|
||||||
|
else:
|
||||||
|
return a_bits, a_sign
|
||||||
|
if self.op == "b":
|
||||||
|
return 1, False
|
||||||
|
elif len(op_shapes) == 2:
|
||||||
|
(a_bits, a_sign), (b_bits, b_sign) = op_shapes
|
||||||
if self.op == "+" or self.op == "-":
|
if self.op == "+" or self.op == "-":
|
||||||
if len(obs) == 1:
|
bits, sign = self._bitwise_binary_shape(*op_shapes)
|
||||||
if self.op == "-" and not obs[0][1]:
|
return bits + 1, sign
|
||||||
return obs[0][0] + 1, True
|
if self.op == "*":
|
||||||
else:
|
if not a_sign and not b_sign:
|
||||||
return obs[0]
|
|
||||||
n, s = self._bitwise_binary_shape(*obs)
|
|
||||||
return n + 1, s
|
|
||||||
elif self.op == "*":
|
|
||||||
if not obs[0][1] and not obs[1][1]:
|
|
||||||
# both operands unsigned
|
# both operands unsigned
|
||||||
return obs[0][0] + obs[1][0], False
|
return a_bits + b_bits, False
|
||||||
elif obs[0][1] and obs[1][1]:
|
if a_sign and b_sign:
|
||||||
# both operands signed
|
# both operands signed
|
||||||
return obs[0][0] + obs[1][0] - 1, True
|
return a_bits + b_bits - 1, True
|
||||||
else:
|
|
||||||
# one operand signed, the other unsigned (add sign bit)
|
# one operand signed, the other unsigned (add sign bit)
|
||||||
return obs[0][0] + obs[1][0] + 1 - 1, True
|
return a_bits + b_bits + 1 - 1, True
|
||||||
elif self.op == "<<<":
|
if self.op in ("<", "<=", "==", "!=", ">", ">=", "b"):
|
||||||
if obs[1][1]:
|
return 1, False
|
||||||
extra = 2**(obs[1][0] - 1) - 1
|
if self.op in ("&", "^", "|"):
|
||||||
|
return self._bitwise_binary_shape(*op_shapes)
|
||||||
|
if self.op == "<<<":
|
||||||
|
if b_sign:
|
||||||
|
extra = 2**(b_bits - 1) - 1
|
||||||
else:
|
else:
|
||||||
extra = 2**obs[1][0] - 1
|
extra = 2**b_bits - 1
|
||||||
return obs[0][0] + extra, obs[0][1]
|
return a_bits + extra, a_sign
|
||||||
elif self.op == ">>>":
|
if self.op == ">>>":
|
||||||
if obs[1][1]:
|
if b_sign:
|
||||||
extra = 2**(obs[1][0] - 1)
|
extra = 2**(b_bits - 1)
|
||||||
else:
|
else:
|
||||||
extra = 0
|
extra = 0
|
||||||
return obs[0][0] + extra, obs[0][1]
|
return a_bits + extra, a_sign
|
||||||
elif self.op in ("&", "^", "|"):
|
elif len(op_shapes) == 3:
|
||||||
return self._bitwise_binary_shape(*obs)
|
if self.op == "m":
|
||||||
elif self.op in ("<", "<=", "==", "!=", ">", ">=", "b"):
|
s_shape, a_shape, b_shape = op_shapes
|
||||||
return 1, False
|
return self._bitwise_binary_shape(a_shape, b_shape)
|
||||||
elif self.op == "~":
|
raise NotImplementedError("Operator {}/{} not implemented"
|
||||||
return obs[0]
|
.format(self.op, len(op_shapes))) # :nocov:
|
||||||
elif self.op == "m":
|
|
||||||
return self._bitwise_binary_shape(obs[1], obs[2])
|
|
||||||
raise NotImplementedError("Operator '{}' not implemented".format(self.op)) # :nocov:
|
|
||||||
|
|
||||||
def _rhs_signals(self):
|
def _rhs_signals(self):
|
||||||
return union(op._rhs_signals() for op in self.operands)
|
return union(op._rhs_signals() for op in self.operands)
|
||||||
|
|
|
@ -88,6 +88,11 @@ class ConstTestCase(FHDLTestCase):
|
||||||
|
|
||||||
|
|
||||||
class OperatorTestCase(FHDLTestCase):
|
class OperatorTestCase(FHDLTestCase):
|
||||||
|
def test_bool(self):
|
||||||
|
v = Const(0, 4).bool()
|
||||||
|
self.assertEqual(repr(v), "(b (const 4'd0))")
|
||||||
|
self.assertEqual(v.shape(), (1, False))
|
||||||
|
|
||||||
def test_invert(self):
|
def test_invert(self):
|
||||||
v = ~Const(0, 4)
|
v = ~Const(0, 4)
|
||||||
self.assertEqual(repr(v), "(~ (const 4'd0))")
|
self.assertEqual(repr(v), "(~ (const 4'd0))")
|
||||||
|
|
Loading…
Reference in a new issue