back.pysim: implement LHS for Part, Slice, Cat, ArrayProxy.
This commit is contained in:
		
							parent
							
								
									d9579219ee
								
							
						
					
					
						commit
						d4e8d3e95a
					
				|  | @ -45,8 +45,9 @@ normalize = Const.normalize | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class _RHSValueCompiler(ValueTransformer): | class _RHSValueCompiler(ValueTransformer): | ||||||
|     def __init__(self, sensitivity=None): |     def __init__(self, sensitivity=None, mode="rhs"): | ||||||
|         self.sensitivity = sensitivity |         self.sensitivity = sensitivity | ||||||
|  |         self.signal_mode = mode | ||||||
| 
 | 
 | ||||||
|     def on_Const(self, value): |     def on_Const(self, value): | ||||||
|         return lambda state: value.value |         return lambda state: value.value | ||||||
|  | @ -54,7 +55,12 @@ class _RHSValueCompiler(ValueTransformer): | ||||||
|     def on_Signal(self, value): |     def on_Signal(self, value): | ||||||
|         if self.sensitivity is not None: |         if self.sensitivity is not None: | ||||||
|             self.sensitivity.add(value) |             self.sensitivity.add(value) | ||||||
|         return lambda state: state.curr[value] |         if self.signal_mode == "rhs": | ||||||
|  |             return lambda state: state.curr[value] | ||||||
|  |         elif self.signal_mode == "lhs": | ||||||
|  |             return lambda state: state.next[value] | ||||||
|  |         else: | ||||||
|  |             raise ValueError # :nocov: | ||||||
| 
 | 
 | ||||||
|     def on_ClockSignal(self, value): |     def on_ClockSignal(self, value): | ||||||
|         raise NotImplementedError # :nocov: |         raise NotImplementedError # :nocov: | ||||||
|  | @ -160,11 +166,17 @@ class _RHSValueCompiler(ValueTransformer): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class _LHSValueCompiler(ValueTransformer): | class _LHSValueCompiler(ValueTransformer): | ||||||
|  |     def __init__(self, rhs_compiler): | ||||||
|  |         self.rhs_compiler = rhs_compiler | ||||||
|  | 
 | ||||||
|     def on_Const(self, value): |     def on_Const(self, value): | ||||||
|         raise TypeError # :nocov: |         raise TypeError # :nocov: | ||||||
| 
 | 
 | ||||||
|     def on_Signal(self, value): |     def on_Signal(self, value): | ||||||
|         return lambda state, arg: state.set(value, arg) |         shape = value.shape() | ||||||
|  |         def eval(state, rhs): | ||||||
|  |             state.set(value, normalize(rhs, shape)) | ||||||
|  |         return eval | ||||||
| 
 | 
 | ||||||
|     def on_ClockSignal(self, value): |     def on_ClockSignal(self, value): | ||||||
|         raise NotImplementedError # :nocov: |         raise NotImplementedError # :nocov: | ||||||
|  | @ -176,37 +188,69 @@ class _LHSValueCompiler(ValueTransformer): | ||||||
|         raise TypeError # :nocov: |         raise TypeError # :nocov: | ||||||
| 
 | 
 | ||||||
|     def on_Slice(self, value): |     def on_Slice(self, value): | ||||||
|         raise NotImplementedError |         lhs_r = self.rhs_compiler(value.value) | ||||||
|  |         lhs_l = self(value.value) | ||||||
|  |         shift = value.start | ||||||
|  |         mask  = (1 << (value.end - value.start)) - 1 | ||||||
|  |         def eval(state, rhs): | ||||||
|  |             lhs_value  = lhs_r(state) | ||||||
|  |             lhs_value &= ~(mask << shift) | ||||||
|  |             lhs_value |= (rhs & mask) << shift | ||||||
|  |             lhs_l(state, lhs_value) | ||||||
|  |         return eval | ||||||
| 
 | 
 | ||||||
|     def on_Part(self, value): |     def on_Part(self, value): | ||||||
|         raise NotImplementedError |         lhs_r = self.rhs_compiler(value.value) | ||||||
|  |         lhs_l = self(value.value) | ||||||
|  |         shift = self.rhs_compiler(value.offset) | ||||||
|  |         mask  = (1 << value.width) - 1 | ||||||
|  |         def eval(state, rhs): | ||||||
|  |             lhs_value   = lhs_r(state) | ||||||
|  |             shift_value = shift(state) | ||||||
|  |             lhs_value  &= ~(mask << shift_value) | ||||||
|  |             lhs_value  |= (rhs & mask) << shift_value | ||||||
|  |             lhs_l(state, lhs_value) | ||||||
|  |         return eval | ||||||
| 
 | 
 | ||||||
|     def on_Cat(self, value): |     def on_Cat(self, value): | ||||||
|         raise NotImplementedError |         parts  = [] | ||||||
|  |         offset = 0 | ||||||
|  |         for opnd in value.operands: | ||||||
|  |             parts.append((offset, (1 << len(opnd)) - 1, self(opnd))) | ||||||
|  |             offset += len(opnd) | ||||||
|  |         def eval(state, rhs): | ||||||
|  |             for offset, mask, opnd in parts: | ||||||
|  |                 opnd(state, (rhs >> offset) & mask) | ||||||
|  |         return eval | ||||||
| 
 | 
 | ||||||
|     def on_Repl(self, value): |     def on_Repl(self, value): | ||||||
|         raise TypeError # :nocov: |         raise TypeError # :nocov: | ||||||
| 
 | 
 | ||||||
|     def on_ArrayProxy(self, value): |     def on_ArrayProxy(self, value): | ||||||
|         raise NotImplementedError |         elems = list(map(self, value.elems)) | ||||||
|  |         index = self.rhs_compiler(value.index) | ||||||
|  |         def eval(state, rhs): | ||||||
|  |             elems[index(state)](state, rhs) | ||||||
|  |         return eval | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class _StatementCompiler(StatementTransformer): | class _StatementCompiler(StatementTransformer): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.sensitivity  = ValueSet() |         self.sensitivity   = ValueSet() | ||||||
|         self.rhs_compiler = _RHSValueCompiler(self.sensitivity) |         self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs") | ||||||
|         self.lhs_compiler = _LHSValueCompiler() |         self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs") | ||||||
|  |         self.lhs_compiler  = _LHSValueCompiler(self.lrhs_compiler) | ||||||
| 
 | 
 | ||||||
|     def on_Assign(self, stmt): |     def on_Assign(self, stmt): | ||||||
|         shape = stmt.lhs.shape() |         shape = stmt.lhs.shape() | ||||||
|         lhs   = self.lhs_compiler(stmt.lhs) |         lhs   = self.lhs_compiler(stmt.lhs) | ||||||
|         rhs   = self.rhs_compiler(stmt.rhs) |         rhs   = self.rrhs_compiler(stmt.rhs) | ||||||
|         def run(state): |         def run(state): | ||||||
|             lhs(state, normalize(rhs(state), shape)) |             lhs(state, normalize(rhs(state), shape)) | ||||||
|         return run |         return run | ||||||
| 
 | 
 | ||||||
|     def on_Switch(self, stmt): |     def on_Switch(self, stmt): | ||||||
|         test  = self.rhs_compiler(stmt.test) |         test  = self.rrhs_compiler(stmt.test) | ||||||
|         cases = [] |         cases = [] | ||||||
|         for value, stmts in stmt.cases.items(): |         for value, stmts in stmt.cases.items(): | ||||||
|             if "-" in value: |             if "-" in value: | ||||||
|  |  | ||||||
|  | @ -813,7 +813,7 @@ class Assign(Statement): | ||||||
|         return self.lhs._lhs_signals() |         return self.lhs._lhs_signals() | ||||||
| 
 | 
 | ||||||
|     def _rhs_signals(self): |     def _rhs_signals(self): | ||||||
|         return self.rhs._rhs_signals() |         return self.lhs._rhs_signals() | self.rhs._rhs_signals() | ||||||
| 
 | 
 | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "(eq {!r} {!r})".format(self.lhs, self.rhs) |         return "(eq {!r} {!r})".format(self.lhs, self.rhs) | ||||||
|  |  | ||||||
|  | @ -1,20 +1,23 @@ | ||||||
| from .tools import * | from .tools import * | ||||||
|  | from ..tools import flatten, union | ||||||
| from ..hdl.ast import * | from ..hdl.ast import * | ||||||
| from ..hdl.ir import * | from ..hdl.ir import * | ||||||
| from ..back.pysim import * | from ..back.pysim import * | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SimulatorUnitTestCase(FHDLTestCase): | class SimulatorUnitTestCase(FHDLTestCase): | ||||||
|     def assertStatement(self, stmt, inputs, output): |     def assertStatement(self, stmt, inputs, output, reset=0): | ||||||
|         inputs = [Value.wrap(i) for i in inputs] |         inputs = [Value.wrap(i) for i in inputs] | ||||||
|         output = Value.wrap(output) |         output = Value.wrap(output) | ||||||
| 
 | 
 | ||||||
|         isigs = [Signal(i.shape(), name=n) for i, n in zip(inputs, "abcd")] |         isigs = [Signal(i.shape(), name=n) for i, n in zip(inputs, "abcd")] | ||||||
|         osig  = Signal(output.shape(), name="y") |         osig  = Signal(output.shape(), name="y", reset=reset) | ||||||
| 
 | 
 | ||||||
|  |         stmt = stmt(osig, *isigs) | ||||||
|         frag = Fragment() |         frag = Fragment() | ||||||
|         frag.add_statements(stmt(osig, *isigs)) |         frag.add_statements(stmt) | ||||||
|         frag.add_driver(osig) |         for signal in flatten(s._lhs_signals() for s in Statement.wrap(stmt)): | ||||||
|  |             frag.add_driver(signal) | ||||||
| 
 | 
 | ||||||
|         with Simulator(frag, |         with Simulator(frag, | ||||||
|                 vcd_file =open("test.vcd",  "w"), |                 vcd_file =open("test.vcd",  "w"), | ||||||
|  | @ -130,16 +133,35 @@ class SimulatorUnitTestCase(FHDLTestCase): | ||||||
|         stmt2 = lambda y, a: y.eq(a[2:4]) |         stmt2 = lambda y, a: y.eq(a[2:4]) | ||||||
|         self.assertStatement(stmt2, [C(0b10110100, 8)], C(0b01, 2)) |         self.assertStatement(stmt2, [C(0b10110100, 8)], C(0b01, 2)) | ||||||
| 
 | 
 | ||||||
|  |     def test_slice_lhs(self): | ||||||
|  |         stmt1 = lambda y, a: y[2].eq(a) | ||||||
|  |         self.assertStatement(stmt1, [C(0b0,  1)], C(0b11111011, 8), reset=0b11111111) | ||||||
|  |         stmt2 = lambda y, a: y[2:4].eq(a) | ||||||
|  |         self.assertStatement(stmt2, [C(0b01, 2)], C(0b11110111, 8), reset=0b11111011) | ||||||
|  | 
 | ||||||
|     def test_part(self): |     def test_part(self): | ||||||
|         stmt = lambda y, a, b: y.eq(a.part(b, 3)) |         stmt = lambda y, a, b: y.eq(a.part(b, 3)) | ||||||
|         self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3)) |         self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3)) | ||||||
|         self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3)) |         self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3)) | ||||||
|         self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3)) |         self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3)) | ||||||
| 
 | 
 | ||||||
|  |     def test_part_lhs(self): | ||||||
|  |         stmt = lambda y, a, b: y.part(a, 3).eq(b) | ||||||
|  |         self.assertStatement(stmt, [C(0), C(0b100, 3)], C(0b11111100, 8), reset=0b11111111) | ||||||
|  |         self.assertStatement(stmt, [C(2), C(0b101, 3)], C(0b11110111, 8), reset=0b11111111) | ||||||
|  |         self.assertStatement(stmt, [C(3), C(0b110, 3)], C(0b11110111, 8), reset=0b11111111) | ||||||
|  | 
 | ||||||
|     def test_cat(self): |     def test_cat(self): | ||||||
|         stmt = lambda y, *xs: y.eq(Cat(*xs)) |         stmt = lambda y, *xs: y.eq(Cat(*xs)) | ||||||
|         self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4)) |         self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4)) | ||||||
| 
 | 
 | ||||||
|  |     def test_cat_lhs(self): | ||||||
|  |         l = Signal(3) | ||||||
|  |         m = Signal(3) | ||||||
|  |         n = Signal(3) | ||||||
|  |         stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))] | ||||||
|  |         self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9)) | ||||||
|  | 
 | ||||||
|     def test_repl(self): |     def test_repl(self): | ||||||
|         stmt = lambda y, a: y.eq(Repl(a, 3)) |         stmt = lambda y, a: y.eq(Repl(a, 3)) | ||||||
|         self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6)) |         self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6)) | ||||||
|  | @ -151,6 +173,16 @@ class SimulatorUnitTestCase(FHDLTestCase): | ||||||
|         self.assertStatement(stmt, [C(1)], C(4)) |         self.assertStatement(stmt, [C(1)], C(4)) | ||||||
|         self.assertStatement(stmt, [C(2)], C(10)) |         self.assertStatement(stmt, [C(2)], C(10)) | ||||||
| 
 | 
 | ||||||
|  |     def test_array_lhs(self): | ||||||
|  |         l = Signal(3, reset=1) | ||||||
|  |         m = Signal(3, reset=4) | ||||||
|  |         n = Signal(3, reset=7) | ||||||
|  |         array = Array([l, m, n]) | ||||||
|  |         stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))] | ||||||
|  |         self.assertStatement(stmt, [C(0), C(0b000)], C(0b111100000)) | ||||||
|  |         self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001)) | ||||||
|  |         self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001)) | ||||||
|  | 
 | ||||||
|     def test_array_index(self): |     def test_array_index(self): | ||||||
|         array = Array(Array(x * y for y in range(10)) for x in range(10)) |         array = Array(Array(x * y for y in range(10)) for x in range(10)) | ||||||
|         stmt = lambda y, a, b: y.eq(array[a][b]) |         stmt = lambda y, a, b: y.eq(array[a][b]) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue
	
	 whitequark
						whitequark