test.sim: generalize assertOperator. NFC.

This commit is contained in:
whitequark 2018-12-15 21:08:29 +00:00
parent bdb8db2826
commit d9579219ee

View file

@ -5,7 +5,7 @@ from ..back.pysim import *
class SimulatorUnitTestCase(FHDLTestCase): class SimulatorUnitTestCase(FHDLTestCase):
def assertOperator(self, stmt, inputs, output): def assertStatement(self, stmt, inputs, output):
inputs = [Value.wrap(i) for i in inputs] inputs = [Value.wrap(i) for i in inputs]
output = Value.wrap(output) output = Value.wrap(output)
@ -13,7 +13,7 @@ class SimulatorUnitTestCase(FHDLTestCase):
osig = Signal(output.shape(), name="y") osig = Signal(output.shape(), name="y")
frag = Fragment() frag = Fragment()
frag.add_statements(osig.eq(stmt(*isigs))) frag.add_statements(stmt(osig, *isigs))
frag.add_driver(osig) frag.add_driver(osig)
with Simulator(frag, with Simulator(frag,
@ -29,139 +29,140 @@ class SimulatorUnitTestCase(FHDLTestCase):
sim.run() sim.run()
def test_invert(self): def test_invert(self):
stmt = lambda a: ~a stmt = lambda y, a: y.eq(~a)
self.assertOperator(stmt, [C(0b0000, 4)], C(0b1111, 4)) self.assertStatement(stmt, [C(0b0000, 4)], C(0b1111, 4))
self.assertOperator(stmt, [C(0b1010, 4)], C(0b0101, 4)) self.assertStatement(stmt, [C(0b1010, 4)], C(0b0101, 4))
self.assertOperator(stmt, [C(0, 4)], C(-1, 4)) self.assertStatement(stmt, [C(0, 4)], C(-1, 4))
def test_neg(self): def test_neg(self):
stmt = lambda a: -a stmt = lambda y, a: y.eq(-a)
self.assertOperator(stmt, [C(0b0000, 4)], C(0b0000, 4)) self.assertStatement(stmt, [C(0b0000, 4)], C(0b0000, 4))
self.assertOperator(stmt, [C(0b0001, 4)], C(0b1111, 4)) self.assertStatement(stmt, [C(0b0001, 4)], C(0b1111, 4))
self.assertOperator(stmt, [C(0b1010, 4)], C(0b0110, 4)) self.assertStatement(stmt, [C(0b1010, 4)], C(0b0110, 4))
self.assertOperator(stmt, [C(1, 4)], C(-1, 4)) self.assertStatement(stmt, [C(1, 4)], C(-1, 4))
self.assertOperator(stmt, [C(5, 4)], C(-5, 4)) self.assertStatement(stmt, [C(5, 4)], C(-5, 4))
def test_bool(self): def test_bool(self):
stmt = lambda a: a.bool() stmt = lambda y, a: y.eq(a.bool())
self.assertOperator(stmt, [C(0, 4)], C(0)) self.assertStatement(stmt, [C(0, 4)], C(0))
self.assertOperator(stmt, [C(1, 4)], C(1)) self.assertStatement(stmt, [C(1, 4)], C(1))
self.assertOperator(stmt, [C(2, 4)], C(1)) self.assertStatement(stmt, [C(2, 4)], C(1))
def test_add(self): def test_add(self):
stmt = lambda a, b: a + b stmt = lambda y, a, b: y.eq(a + b)
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(1, 4)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1, 4))
self.assertOperator(stmt, [C(-5, 4), C(-5, 4)], C(-10, 5)) self.assertStatement(stmt, [C(-5, 4), C(-5, 4)], C(-10, 5))
def test_sub(self): def test_sub(self):
stmt = lambda a, b: a - b stmt = lambda y, a, b: y.eq(a - b)
self.assertOperator(stmt, [C(2, 4), C(1, 4)], C(1, 4)) self.assertStatement(stmt, [C(2, 4), C(1, 4)], C(1, 4))
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(-1, 4)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(-1, 4))
self.assertOperator(stmt, [C(0, 4), C(10, 4)], C(-10, 5)) self.assertStatement(stmt, [C(0, 4), C(10, 4)], C(-10, 5))
def test_and(self): def test_and(self):
stmt = lambda a, b: a & b stmt = lambda y, a, b: y.eq(a & b)
self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b1000, 4)) self.assertStatement(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b1000, 4))
def test_or(self): def test_or(self):
stmt = lambda a, b: a | b stmt = lambda y, a, b: y.eq(a | b)
self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b1110, 4)) self.assertStatement(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b1110, 4))
def test_xor(self): def test_xor(self):
stmt = lambda a, b: a ^ b stmt = lambda y, a, b: y.eq(a ^ b)
self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4)) self.assertStatement(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4))
def test_shl(self): def test_shl(self):
stmt = lambda a, b: a << b stmt = lambda y, a, b: y.eq(a << b)
self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 5)) self.assertStatement(stmt, [C(0b1001, 4), C(0)], C(0b1001, 5))
self.assertOperator(stmt, [C(0b1001, 4), C(3)], C(0b1001000, 7)) self.assertStatement(stmt, [C(0b1001, 4), C(3)], C(0b1001000, 7))
self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b10, 7)) self.assertStatement(stmt, [C(0b1001, 4), C(-2)], C(0b10, 7))
def test_shr(self): def test_shr(self):
stmt = lambda a, b: a >> b stmt = lambda y, a, b: y.eq(a >> b)
self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 4)) self.assertStatement(stmt, [C(0b1001, 4), C(0)], C(0b1001, 4))
self.assertOperator(stmt, [C(0b1001, 4), C(2)], C(0b10, 4)) self.assertStatement(stmt, [C(0b1001, 4), C(2)], C(0b10, 4))
self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b100100, 5)) self.assertStatement(stmt, [C(0b1001, 4), C(-2)], C(0b100100, 5))
def test_eq(self): def test_eq(self):
stmt = lambda a, b: a == b stmt = lambda y, a, b: y.eq(a == b)
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1)) self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(1))
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(0)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(0))
self.assertOperator(stmt, [C(1, 4), C(0, 4)], C(0)) self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(0))
def test_ne(self): def test_ne(self):
stmt = lambda a, b: a != b stmt = lambda y, a, b: y.eq(a != b)
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(0)) self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(0))
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(1)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1))
self.assertOperator(stmt, [C(1, 4), C(0, 4)], C(1)) self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(1))
def test_lt(self): def test_lt(self):
stmt = lambda a, b: a < b stmt = lambda y, a, b: y.eq(a < b)
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(0)) self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(0))
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(1)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1))
self.assertOperator(stmt, [C(1, 4), C(0, 4)], C(0)) self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(0))
def test_ge(self): def test_ge(self):
stmt = lambda a, b: a >= b stmt = lambda y, a, b: y.eq(a >= b)
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1)) self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(1))
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(0)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(0))
self.assertOperator(stmt, [C(1, 4), C(0, 4)], C(1)) self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(1))
def test_gt(self): def test_gt(self):
stmt = lambda a, b: a > b stmt = lambda y, a, b: y.eq(a > b)
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(0)) self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(0))
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(0)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(0))
self.assertOperator(stmt, [C(1, 4), C(0, 4)], C(1)) self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(1))
def test_le(self): def test_le(self):
stmt = lambda a, b: a <= b stmt = lambda y, a, b: y.eq(a <= b)
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1)) self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(1))
self.assertOperator(stmt, [C(0, 4), C(1, 4)], C(1)) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1))
self.assertOperator(stmt, [C(1, 4), C(0, 4)], C(0)) self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(0))
def test_mux(self): def test_mux(self):
stmt = lambda a, b, c: Mux(c, a, b) stmt = lambda y, a, b, c: y.eq(Mux(c, a, b))
self.assertOperator(stmt, [C(2, 4), C(3, 4), C(0)], C(3, 4)) self.assertStatement(stmt, [C(2, 4), C(3, 4), C(0)], C(3, 4))
self.assertOperator(stmt, [C(2, 4), C(3, 4), C(1)], C(2, 4)) self.assertStatement(stmt, [C(2, 4), C(3, 4), C(1)], C(2, 4))
def test_slice(self): def test_slice(self):
stmt1 = lambda a: a[2] stmt1 = lambda y, a: y.eq(a[2])
self.assertOperator(stmt1, [C(0b10110100, 8)], C(0b1, 1)) self.assertStatement(stmt1, [C(0b10110100, 8)], C(0b1, 1))
stmt2 = lambda a: a[2:4] stmt2 = lambda y, a: y.eq(a[2:4])
self.assertOperator(stmt2, [C(0b10110100, 8)], C(0b01, 2)) self.assertStatement(stmt2, [C(0b10110100, 8)], C(0b01, 2))
def test_part(self): def test_part(self):
stmt = lambda a, b: a.part(b, 3) stmt = lambda y, a, b: y.eq(a.part(b, 3))
self.assertOperator(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3))
self.assertOperator(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3))
self.assertOperator(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3))
def test_cat(self): def test_cat(self):
self.assertOperator(Cat, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4)) stmt = lambda y, *xs: y.eq(Cat(*xs))
self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4))
def test_repl(self): def test_repl(self):
stmt = lambda a: Repl(a, 3) stmt = lambda y, a: y.eq(Repl(a, 3))
self.assertOperator(stmt, [C(0b10, 2)], C(0b101010, 6)) self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6))
def test_array(self): def test_array(self):
array = Array([1, 4, 10]) array = Array([1, 4, 10])
stmt = lambda a: array[a] stmt = lambda y, a: y.eq(array[a])
self.assertOperator(stmt, [C(0)], C(1)) self.assertStatement(stmt, [C(0)], C(1))
self.assertOperator(stmt, [C(1)], C(4)) self.assertStatement(stmt, [C(1)], C(4))
self.assertOperator(stmt, [C(2)], C(10)) self.assertStatement(stmt, [C(2)], C(10))
def test_array_index(self): def test_array_index(self):
array = Array(Array(x * y for y in range(10)) for x in range(10)) array = Array(Array(x * y for y in range(10)) for x in range(10))
stmt = lambda a, b: array[a][b] stmt = lambda y, a, b: y.eq(array[a][b])
for x in range(10): for x in range(10):
for y in range(10): for y in range(10):
self.assertOperator(stmt, [C(x), C(y)], C(x * y)) self.assertStatement(stmt, [C(x), C(y)], C(x * y))
def test_array_attr(self): def test_array_attr(self):
from collections import namedtuple from collections import namedtuple
pair = namedtuple("pair", ("p", "n")) pair = namedtuple("pair", ("p", "n"))
array = Array(pair(x, -x) for x in range(10)) array = Array(pair(x, -x) for x in range(10))
stmt = lambda a: array[a].p + array[a].n stmt = lambda y, a: y.eq(array[a].p + array[a].n)
for i in range(10): for i in range(10):
self.assertOperator(stmt, [C(i)], C(0)) self.assertStatement(stmt, [C(i)], C(0))