hdl.ast: fix shape for subtraction.

Fixes #813.
This commit is contained in:
Marcelina Kościelnicka 2023-06-07 14:29:33 +02:00 committed by Catherine
parent 3180a17fd9
commit 1d5e090580
2 changed files with 8 additions and 3 deletions

View file

@ -730,9 +730,12 @@ class Operator(Value):
return Shape(a_shape.width, True) return Shape(a_shape.width, True)
elif len(op_shapes) == 2: elif len(op_shapes) == 2:
a_shape, b_shape = op_shapes a_shape, b_shape = op_shapes
if self.operator in ("+", "-"): if self.operator == "+":
o_shape = _bitwise_binary_shape(*op_shapes) o_shape = _bitwise_binary_shape(*op_shapes)
return Shape(o_shape.width + 1, o_shape.signed) return Shape(o_shape.width + 1, o_shape.signed)
if self.operator == "-":
o_shape = _bitwise_binary_shape(*op_shapes)
return Shape(o_shape.width + 1, True)
if self.operator == "*": if self.operator == "*":
return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed) return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed)
if self.operator == "//": if self.operator == "//":

View file

@ -443,7 +443,7 @@ class OperatorTestCase(FHDLTestCase):
def test_sub(self): def test_sub(self):
v1 = Const(0, unsigned(4)) - Const(0, unsigned(6)) v1 = Const(0, unsigned(4)) - Const(0, unsigned(6))
self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))") self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))")
self.assertEqual(v1.shape(), unsigned(7)) self.assertEqual(v1.shape(), signed(7))
v2 = Const(0, signed(4)) - Const(0, signed(6)) v2 = Const(0, signed(4)) - Const(0, signed(6))
self.assertEqual(v2.shape(), signed(7)) self.assertEqual(v2.shape(), signed(7))
v3 = Const(0, signed(4)) - Const(0, unsigned(4)) v3 = Const(0, signed(4)) - Const(0, unsigned(4))
@ -451,7 +451,9 @@ class OperatorTestCase(FHDLTestCase):
v4 = Const(0, unsigned(4)) - Const(0, signed(4)) v4 = Const(0, unsigned(4)) - Const(0, signed(4))
self.assertEqual(v4.shape(), signed(6)) self.assertEqual(v4.shape(), signed(6))
v5 = 10 - Const(0, 4) v5 = 10 - Const(0, 4)
self.assertEqual(v5.shape(), unsigned(5)) self.assertEqual(v5.shape(), signed(5))
v6 = 1 - Const(2)
self.assertEqual(v6.shape(), signed(3))
def test_mul(self): def test_mul(self):
v1 = Const(0, unsigned(4)) * Const(0, unsigned(6)) v1 = Const(0, unsigned(4)) * Const(0, unsigned(6))