hdl.ast: add const-shift operations.

Also, clean up the rotate code a bit.

Fixes #378.
This commit is contained in:
whitequark 2020-05-20 03:18:33 +00:00
parent 72ef4303a9
commit 7ea81f5f06
3 changed files with 145 additions and 22 deletions

View file

@ -423,41 +423,85 @@ class Value(metaclass=ABCMeta):
else: else:
return Cat(*matches).any() return Cat(*matches).any()
def rotate_left(self, offset): def shift_left(self, amount):
"""Shift left by constant amount.
Parameters
----------
amount : int
Amount to shift by.
Returns
-------
Value, out
If the amount is positive, the input shifted left. Otherwise, the input shifted right.
"""
if not isinstance(amount, int):
raise TypeError("Shift amount must be an integer, not {!r}".format(amount))
if amount < 0:
return self.shift_right(-amount)
if self.shape().signed:
return Cat(Const(0, amount), self).as_signed()
else:
return Cat(Const(0, amount), self) # unsigned
def shift_right(self, amount):
"""Shift right by constant amount.
Parameters
----------
amount : int
Amount to shift by.
Returns
-------
Value, out
If the amount is positive, the input shifted right. Otherwise, the input shifted left.
"""
if not isinstance(amount, int):
raise TypeError("Shift amount must be an integer, not {!r}".format(amount))
if amount < 0:
return self.shift_left(-amount)
if self.shape().signed:
return self[amount:].as_signed()
else:
return self[amount:] # unsigned
def rotate_left(self, amount):
"""Rotate left by constant amount. """Rotate left by constant amount.
Parameters Parameters
---------- ----------
offset : int amount : int
Amount to rotate by. Amount to rotate by.
Returns Returns
------- -------
Value, out Value, out
If the offset is positive, the input rotated left. Otherwise, the input rotated right. If the amount is positive, the input rotated left. Otherwise, the input rotated right.
""" """
if not isinstance(offset, int): if not isinstance(amount, int):
raise TypeError("Rotate amount must be an integer, not {!r}".format(offset)) raise TypeError("Rotate amount must be an integer, not {!r}".format(amount))
offset %= len(self) amount %= len(self)
return Cat(self[-offset:], self[:-offset]) # meow :3 return Cat(self[-amount:], self[:-amount]) # meow :3
def rotate_right(self, offset): def rotate_right(self, amount):
"""Rotate right by constant amount. """Rotate right by constant amount.
Parameters Parameters
---------- ----------
offset : int amount : int
Amount to rotate by. Amount to rotate by.
Returns Returns
------- -------
Value, out Value, out
If the offset is positive, the input rotated right. Otherwise, the input rotated right. If the amount is positive, the input rotated right. Otherwise, the input rotated right.
""" """
if not isinstance(offset, int): if not isinstance(amount, int):
raise TypeError("Rotate amount must be an integer, not {!r}".format(offset)) raise TypeError("Rotate amount must be an integer, not {!r}".format(amount))
offset %= len(self) amount %= len(self)
return Cat(self[offset:], self[:offset]) return Cat(self[amount:], self[:amount])
def eq(self, value): def eq(self, value):
"""Assignment. """Assignment.

View file

@ -205,11 +205,73 @@ class ValueTestCase(FHDLTestCase):
msg="Cannot index value with 'str'"): msg="Cannot index value with 'str'"):
Const(31)["str"] Const(31)["str"]
def test_shift_left(self):
self.assertRepr(Const(256, unsigned(9)).shift_left(0),
"(cat (const 0'd0) (const 9'd256))")
self.assertRepr(Const(256, unsigned(9)).shift_left(1),
"(cat (const 1'd0) (const 9'd256))")
self.assertRepr(Const(256, unsigned(9)).shift_left(5),
"(cat (const 5'd0) (const 9'd256))")
self.assertRepr(Const(256, signed(9)).shift_left(1),
"(s (cat (const 1'd0) (const 9'sd-256)))")
self.assertRepr(Const(256, signed(9)).shift_left(5),
"(s (cat (const 5'd0) (const 9'sd-256)))")
self.assertRepr(Const(256, unsigned(9)).shift_left(-1),
"(slice (const 9'd256) 1:9)")
self.assertRepr(Const(256, unsigned(9)).shift_left(-5),
"(slice (const 9'd256) 5:9)")
self.assertRepr(Const(256, signed(9)).shift_left(-1),
"(s (slice (const 9'sd-256) 1:9))")
self.assertRepr(Const(256, signed(9)).shift_left(-5),
"(s (slice (const 9'sd-256) 5:9))")
self.assertRepr(Const(256, signed(9)).shift_left(-15),
"(s (slice (const 9'sd-256) 9:9))")
def test_shift_left_wrong(self):
with self.assertRaises(TypeError,
msg="Shift amount must be an integer, not 'str'"):
Const(31).shift_left("str")
def test_shift_right(self):
self.assertRepr(Const(256, unsigned(9)).shift_right(0),
"(slice (const 9'd256) 0:9)")
self.assertRepr(Const(256, unsigned(9)).shift_right(-1),
"(cat (const 1'd0) (const 9'd256))")
self.assertRepr(Const(256, unsigned(9)).shift_right(-5),
"(cat (const 5'd0) (const 9'd256))")
self.assertRepr(Const(256, signed(9)).shift_right(-1),
"(s (cat (const 1'd0) (const 9'sd-256)))")
self.assertRepr(Const(256, signed(9)).shift_right(-5),
"(s (cat (const 5'd0) (const 9'sd-256)))")
self.assertRepr(Const(256, unsigned(9)).shift_right(1),
"(slice (const 9'd256) 1:9)")
self.assertRepr(Const(256, unsigned(9)).shift_right(5),
"(slice (const 9'd256) 5:9)")
self.assertRepr(Const(256, signed(9)).shift_right(1),
"(s (slice (const 9'sd-256) 1:9))")
self.assertRepr(Const(256, signed(9)).shift_right(5),
"(s (slice (const 9'sd-256) 5:9))")
self.assertRepr(Const(256, signed(9)).shift_right(15),
"(s (slice (const 9'sd-256) 9:9))")
def test_shift_right_wrong(self):
with self.assertRaises(TypeError,
msg="Shift amount must be an integer, not 'str'"):
Const(31).shift_left("str")
def test_rotate_left(self): def test_rotate_left(self):
self.assertRepr(Value.cast(256).rotate_left(1), "(cat (slice (const 9'd256) 8:9) (slice (const 9'd256) 0:8))") self.assertRepr(Const(256).rotate_left(1),
self.assertRepr(Value.cast(256).rotate_left(7), "(cat (slice (const 9'd256) 2:9) (slice (const 9'd256) 0:2))") "(cat (slice (const 9'd256) 8:9) (slice (const 9'd256) 0:8))")
self.assertRepr(Value.cast(256).rotate_left(-1), "(cat (slice (const 9'd256) 1:9) (slice (const 9'd256) 0:1))") self.assertRepr(Const(256).rotate_left(7),
self.assertRepr(Value.cast(256).rotate_left(-7), "(cat (slice (const 9'd256) 7:9) (slice (const 9'd256) 0:7))") "(cat (slice (const 9'd256) 2:9) (slice (const 9'd256) 0:2))")
self.assertRepr(Const(256).rotate_left(-1),
"(cat (slice (const 9'd256) 1:9) (slice (const 9'd256) 0:1))")
self.assertRepr(Const(256).rotate_left(-7),
"(cat (slice (const 9'd256) 7:9) (slice (const 9'd256) 0:7))")
def test_rotate_left_wrong(self): def test_rotate_left_wrong(self):
with self.assertRaises(TypeError, with self.assertRaises(TypeError,
@ -217,16 +279,21 @@ class ValueTestCase(FHDLTestCase):
Const(31).rotate_left("str") Const(31).rotate_left("str")
def test_rotate_right(self): def test_rotate_right(self):
self.assertRepr(Value.cast(256).rotate_right(1), "(cat (slice (const 9'd256) 1:9) (slice (const 9'd256) 0:1))") self.assertRepr(Const(256).rotate_right(1),
self.assertRepr(Value.cast(256).rotate_right(7), "(cat (slice (const 9'd256) 7:9) (slice (const 9'd256) 0:7))") "(cat (slice (const 9'd256) 1:9) (slice (const 9'd256) 0:1))")
self.assertRepr(Value.cast(256).rotate_right(-1), "(cat (slice (const 9'd256) 8:9) (slice (const 9'd256) 0:8))") self.assertRepr(Const(256).rotate_right(7),
self.assertRepr(Value.cast(256).rotate_right(-7), "(cat (slice (const 9'd256) 2:9) (slice (const 9'd256) 0:2))") "(cat (slice (const 9'd256) 7:9) (slice (const 9'd256) 0:7))")
self.assertRepr(Const(256).rotate_right(-1),
"(cat (slice (const 9'd256) 8:9) (slice (const 9'd256) 0:8))")
self.assertRepr(Const(256).rotate_right(-7),
"(cat (slice (const 9'd256) 2:9) (slice (const 9'd256) 0:2))")
def test_rotate_right_wrong(self): def test_rotate_right_wrong(self):
with self.assertRaises(TypeError, with self.assertRaises(TypeError,
msg="Rotate amount must be an integer, not 'str'"): msg="Rotate amount must be an integer, not 'str'"):
Const(31).rotate_right("str") Const(31).rotate_right("str")
class ConstTestCase(FHDLTestCase): class ConstTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
self.assertEqual(Const(0).shape(), unsigned(1)) self.assertEqual(Const(0).shape(), unsigned(1))

View file

@ -301,6 +301,18 @@ class SimulatorUnitTestCase(FHDLTestCase):
for i in range(10): for i in range(10):
self.assertStatement(stmt, [C(i)], C(0)) self.assertStatement(stmt, [C(i)], C(0))
def test_shift_left(self):
stmt1 = lambda y, a: y.eq(a.shift_left(1))
self.assertStatement(stmt1, [C(0b10100010, 8)], C( 0b101000100, 9))
stmt2 = lambda y, a: y.eq(a.shift_left(4))
self.assertStatement(stmt2, [C(0b10100010, 8)], C(0b101000100000, 12))
def test_shift_right(self):
stmt1 = lambda y, a: y.eq(a.shift_right(1))
self.assertStatement(stmt1, [C(0b10100010, 8)], C(0b1010001, 7))
stmt2 = lambda y, a: y.eq(a.shift_right(4))
self.assertStatement(stmt2, [C(0b10100010, 8)], C( 0b1010, 4))
def test_rotate_left(self): def test_rotate_left(self):
stmt = lambda y, a: y.eq(a.rotate_left(1)) stmt = lambda y, a: y.eq(a.rotate_left(1))
self.assertStatement(stmt, [C(0b1)], C(0b1)) self.assertStatement(stmt, [C(0b1)], C(0b1))