fhdl.ast, back.pysim: implement shifts.

This commit is contained in:
whitequark 2018-12-15 09:58:30 +00:00
parent 46f5addf05
commit db4600d52b
5 changed files with 50 additions and 16 deletions

View file

@ -75,15 +75,23 @@ class _RHSValueCompiler(ValueTransformer):
elif len(value.operands) == 2: elif len(value.operands) == 2:
lhs, rhs = map(self, value.operands) lhs, rhs = map(self, value.operands)
if value.op == "+": if value.op == "+":
return lambda state: normalize(lhs(state) + rhs(state), shape) return lambda state: normalize(lhs(state) + rhs(state), shape)
if value.op == "-": if value.op == "-":
return lambda state: normalize(lhs(state) - rhs(state), shape) return lambda state: normalize(lhs(state) - rhs(state), shape)
if value.op == "&": if value.op == "&":
return lambda state: normalize(lhs(state) & rhs(state), shape) return lambda state: normalize(lhs(state) & rhs(state), shape)
if value.op == "|": if value.op == "|":
return lambda state: normalize(lhs(state) | rhs(state), shape) return lambda state: normalize(lhs(state) | rhs(state), shape)
if value.op == "^": if value.op == "^":
return lambda state: normalize(lhs(state) ^ rhs(state), shape) return lambda state: normalize(lhs(state) ^ rhs(state), shape)
if value.op == "<<":
def sshl(lhs, rhs):
return lhs << rhs if rhs >= 0 else lhs >> -rhs
return lambda state: normalize(sshl(lhs(state), rhs(state)), shape)
if value.op == ">>":
def sshr(lhs, rhs):
return lhs >> rhs if rhs >= 0 else lhs << -rhs
return lambda state: normalize(sshr(lhs(state), rhs(state)), shape)
if value.op == "==": if value.op == "==":
return lambda state: normalize(lhs(state) == rhs(state), shape) return lambda state: normalize(lhs(state) == rhs(state), shape)
if value.op == "!=": if value.op == "!=":

View file

@ -206,8 +206,8 @@ class _ValueTransformer(xfrm.ValueTransformer):
(2, "/"): "$div", (2, "/"): "$div",
(2, "%"): "$mod", (2, "%"): "$mod",
(2, "**"): "$pow", (2, "**"): "$pow",
(2, "<<<"): "$sshl", (2, "<<"): "$sshl",
(2, ">>>"): "$sshr", (2, ">>"): "$sshr",
(2, "&"): "$and", (2, "&"): "$and",
(2, "^"): "$xor", (2, "^"): "$xor",
(2, "|"): "$or", (2, "|"): "$or",

View file

@ -75,13 +75,13 @@ class Value(metaclass=ABCMeta):
def __rdiv__(self, other): def __rdiv__(self, other):
return Operator("/", [other, self]) return Operator("/", [other, self])
def __lshift__(self, other): def __lshift__(self, other):
return Operator("<<<", [self, other]) return Operator("<<", [self, other])
def __rlshift__(self, other): def __rlshift__(self, other):
return Operator("<<<", [other, self]) return Operator("<<", [other, self])
def __rshift__(self, other): def __rshift__(self, other):
return Operator(">>>", [self, other]) return Operator(">>", [self, other])
def __rrshift__(self, other): def __rrshift__(self, other):
return Operator(">>>", [other, self]) return Operator(">>", [other, self])
def __and__(self, other): def __and__(self, other):
return Operator("&", [self, other]) return Operator("&", [self, other])
def __rand__(self, other): def __rand__(self, other):
@ -306,15 +306,15 @@ class Operator(Value):
return 1, False return 1, False
if self.op in ("&", "^", "|"): if self.op in ("&", "^", "|"):
return self._bitwise_binary_shape(*op_shapes) return self._bitwise_binary_shape(*op_shapes)
if self.op == "<<<": if self.op == "<<":
if b_sign: if b_sign:
extra = 2**(b_bits - 1) - 1 extra = 2 ** (b_bits - 1) - 1
else: else:
extra = 2**b_bits - 1 extra = 2 ** (b_bits) - 1
return a_bits + extra, a_sign return a_bits + extra, a_sign
if self.op == ">>>": if self.op == ">>":
if b_sign: if b_sign:
extra = 2**(b_bits - 1) extra = 2 ** (b_bits - 1)
else: else:
extra = 0 extra = 0
return a_bits + extra, a_sign return a_bits + extra, a_sign

View file

@ -182,6 +182,20 @@ class OperatorTestCase(FHDLTestCase):
v5 = 10 ^ Const(0, 4) v5 = 10 ^ Const(0, 4)
self.assertEqual(v5.shape(), (4, False)) self.assertEqual(v5.shape(), (4, False))
def test_shl(self):
v1 = Const(1, 4) << Const(4)
self.assertEqual(repr(v1), "(<< (const 4'd1) (const 3'd4))")
self.assertEqual(v1.shape(), (11, False))
v2 = Const(1, 4) << Const(-3)
self.assertEqual(v2.shape(), (7, False))
def test_shr(self):
v1 = Const(1, 4) >> Const(4)
self.assertEqual(repr(v1), "(>> (const 4'd1) (const 3'd4))")
self.assertEqual(v1.shape(), (4, False))
v2 = Const(1, 4) >> Const(-3)
self.assertEqual(v2.shape(), (8, False))
def test_lt(self): def test_lt(self):
v = Const(0, 4) < Const(0, 6) v = Const(0, 4) < Const(0, 6)
self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))") self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")

View file

@ -71,6 +71,18 @@ class SimulatorUnitTestCase(FHDLTestCase):
stmt = lambda a, b: a ^ b stmt = lambda a, b: a ^ b
self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4)) self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4))
def test_shl(self):
stmt = lambda a, b: a << b
self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 5))
self.assertOperator(stmt, [C(0b1001, 4), C(3)], C(0b1001000, 7))
self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b10, 7))
def test_shr(self):
stmt = lambda a, b: a >> b
self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 4))
self.assertOperator(stmt, [C(0b1001, 4), C(2)], C(0b10, 4))
self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b100100, 5))
def test_eq(self): def test_eq(self):
stmt = lambda a, b: a == b stmt = lambda a, b: a == b
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1)) self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1))