sim: evaluate simulator commands in-place instead of compiling them.
This commit is contained in:
parent
967dabc2fe
commit
f71bee499d
|
@ -18,7 +18,7 @@ from .._unused import *
|
|||
__all__ = [
|
||||
"SyntaxError", "SyntaxWarning",
|
||||
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
|
||||
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue",
|
||||
"Value", "Const", "C", "AnyValue", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue",
|
||||
"Array", "ArrayProxy",
|
||||
"Signal", "ClockSignal", "ResetSignal",
|
||||
"ValueCastable", "ValueLike",
|
||||
|
|
|
@ -5,7 +5,7 @@ from ..hdl._ast import Statement, Assign, SignalSet, ValueCastable
|
|||
from ..hdl._mem import MemorySimRead, MemorySimWrite
|
||||
from .core import Tick, Settle, Delay, Passive, Active
|
||||
from ._base import BaseProcess, BaseMemoryState
|
||||
from ._pyrtl import _ValueCompiler, _RHSValueCompiler, _StatementCompiler
|
||||
from ._pyeval import eval_value, eval_assign
|
||||
|
||||
|
||||
__all__ = ["PyCoroProcess"]
|
||||
|
@ -28,11 +28,6 @@ class PyCoroProcess(BaseProcess):
|
|||
self.passive = False
|
||||
|
||||
self.coroutine = self.constructor()
|
||||
self.exec_locals = {
|
||||
"slots": self.state.slots,
|
||||
"result": None,
|
||||
**_ValueCompiler.helpers
|
||||
}
|
||||
self.waits_on = SignalSet()
|
||||
|
||||
def src_loc(self):
|
||||
|
@ -87,14 +82,11 @@ class PyCoroProcess(BaseProcess):
|
|||
if isinstance(command, ValueCastable):
|
||||
command = Value.cast(command)
|
||||
if isinstance(command, Value):
|
||||
exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
|
||||
self.exec_locals)
|
||||
response = Const(self.exec_locals["result"], command.shape()).value
|
||||
response = eval_value(self.state, command)
|
||||
|
||||
elif isinstance(command, Statement):
|
||||
exec(_StatementCompiler.compile(self.state, command),
|
||||
self.exec_locals)
|
||||
if isinstance(command, Assign) and self.testbench:
|
||||
elif isinstance(command, Assign):
|
||||
eval_assign(self.state, command.lhs, eval_value(self.state, command.rhs))
|
||||
if self.testbench:
|
||||
return True # assignment; run a delta cycle
|
||||
|
||||
elif type(command) is Tick:
|
||||
|
@ -132,21 +124,15 @@ class PyCoroProcess(BaseProcess):
|
|||
self.passive = False
|
||||
|
||||
elif type(command) is MemorySimRead:
|
||||
exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"),
|
||||
self.exec_locals)
|
||||
addr = Const(self.exec_locals["result"], command._addr.shape()).value
|
||||
addr = eval_value(self.state, command._addr)
|
||||
index = self.state.get_memory(command._memory)
|
||||
state = self.state.slots[index]
|
||||
assert isinstance(state, BaseMemoryState)
|
||||
response = state.read(addr)
|
||||
|
||||
elif type(command) is MemorySimWrite:
|
||||
exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"),
|
||||
self.exec_locals)
|
||||
addr = Const(self.exec_locals["result"], command._addr.shape()).value
|
||||
exec(_RHSValueCompiler.compile(self.state, command._data, mode="curr"),
|
||||
self.exec_locals)
|
||||
data = Const(self.exec_locals["result"], command._data.shape()).value
|
||||
addr = eval_value(self.state, command._addr)
|
||||
data = eval_value(self.state, command._data)
|
||||
index = self.state.get_memory(command._memory)
|
||||
state = self.state.slots[index]
|
||||
assert isinstance(state, BaseMemoryState)
|
||||
|
|
184
amaranth/sim/_pyeval.py
Normal file
184
amaranth/sim/_pyeval.py
Normal file
|
@ -0,0 +1,184 @@
|
|||
from amaranth.hdl._ast import *
|
||||
|
||||
|
||||
def _eval_matches(test, patterns):
|
||||
if patterns is None:
|
||||
return True
|
||||
for pattern in patterns:
|
||||
if isinstance(pattern, str):
|
||||
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
|
||||
value = int("".join("0" if b == "-" else b for b in pattern), 2)
|
||||
if value == (mask & test):
|
||||
return True
|
||||
else:
|
||||
if pattern == test:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def eval_value(sim, value):
|
||||
if isinstance(value, Const):
|
||||
return value.value
|
||||
elif isinstance(value, Operator):
|
||||
if len(value.operands) == 1:
|
||||
op_a = eval_value(sim, value.operands[0])
|
||||
if value.operator in ("u", "s"):
|
||||
width = value.shape().width
|
||||
res = op_a
|
||||
res &= (1 << width) - 1
|
||||
if value.operator == "s" and res & (1 << (width - 1)):
|
||||
res |= -1 << (width - 1)
|
||||
return res
|
||||
elif value.operator == "-":
|
||||
return -op_a
|
||||
elif value.operator == "~":
|
||||
shape = value.shape()
|
||||
if shape.signed:
|
||||
return ~op_a
|
||||
else:
|
||||
return ~op_a & ((1 << shape.width) - 1)
|
||||
elif value.operator in ("b", "r|"):
|
||||
return int(op_a != 0)
|
||||
elif value.operator == "r&":
|
||||
width = value.operands[0].shape().width
|
||||
mask = (1 << width) - 1
|
||||
return int((op_a & mask) == mask)
|
||||
elif value.operator == "r^":
|
||||
width = value.operands[0].shape().width
|
||||
mask = (1 << width) - 1
|
||||
# Believe it or not, this is the fastest way to compute a sideways XOR in Python.
|
||||
return format(op_a & mask, 'b').count('1') % 2
|
||||
elif len(value.operands) == 2:
|
||||
op_a = eval_value(sim, value.operands[0])
|
||||
op_b = eval_value(sim, value.operands[1])
|
||||
if value.operator == "|":
|
||||
return op_a | op_b
|
||||
elif value.operator == "&":
|
||||
return op_a & op_b
|
||||
elif value.operator == "^":
|
||||
return op_a ^ op_b
|
||||
elif value.operator == "+":
|
||||
return op_a + op_b
|
||||
elif value.operator == "-":
|
||||
return op_a - op_b
|
||||
elif value.operator == "*":
|
||||
return op_a * op_b
|
||||
elif value.operator == "//":
|
||||
if op_b == 0:
|
||||
return 0
|
||||
return op_a // op_b
|
||||
elif value.operator == "%":
|
||||
if op_b == 0:
|
||||
return 0
|
||||
return op_a % op_b
|
||||
elif value.operator == "<<":
|
||||
return op_a << op_b
|
||||
elif value.operator == ">>":
|
||||
return op_a >> op_b
|
||||
elif value.operator == "==":
|
||||
return int(op_a == op_b)
|
||||
elif value.operator == "!=":
|
||||
return int(op_a != op_b)
|
||||
elif value.operator == "<":
|
||||
return int(op_a < op_b)
|
||||
elif value.operator == "<=":
|
||||
return int(op_a <= op_b)
|
||||
elif value.operator == ">":
|
||||
return int(op_a > op_b)
|
||||
elif value.operator == ">=":
|
||||
return int(op_a >= op_b)
|
||||
assert False # :nocov:
|
||||
elif isinstance(value, Slice):
|
||||
res = eval_value(sim, value.value)
|
||||
res >>= value.start
|
||||
width = value.stop - value.start
|
||||
return res & ((1 << width) - 1)
|
||||
elif isinstance(value, Part):
|
||||
res = eval_value(sim, value.value)
|
||||
offset = eval_value(sim, value.offset)
|
||||
offset *= value.stride
|
||||
res >>= offset
|
||||
return res & ((1 << value.width) - 1)
|
||||
elif isinstance(value, Concat):
|
||||
res = 0
|
||||
pos = 0
|
||||
for part in value.parts:
|
||||
width = len(part)
|
||||
part = eval_value(sim, part)
|
||||
part &= (1 << width) - 1
|
||||
res |= part << pos
|
||||
pos += width
|
||||
return res
|
||||
elif isinstance(value, SwitchValue):
|
||||
test = eval_value(sim, value.test)
|
||||
for patterns, val in value.cases:
|
||||
if _eval_matches(test, patterns):
|
||||
return eval_value(sim, val)
|
||||
return 0
|
||||
elif isinstance(value, Signal):
|
||||
slot = sim.get_signal(value)
|
||||
return sim.slots[slot].curr
|
||||
elif isinstance(value, (ResetSignal, ClockSignal, AnyValue, Initial)):
|
||||
raise ValueError(f"Value {value!r} cannot be used in simulation")
|
||||
else:
|
||||
assert False # :nocov:
|
||||
|
||||
|
||||
def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len):
|
||||
if isinstance(lhs, Operator) and lhs.operator in ("u", "s"):
|
||||
_eval_assign_inner(sim, lhs.operands[0], lhs_start, rhs, rhs_len)
|
||||
elif isinstance(lhs, Signal):
|
||||
lhs_stop = lhs_start + rhs_len
|
||||
if lhs_stop > len(lhs):
|
||||
lhs_stop = len(lhs)
|
||||
if lhs_start >= len(lhs):
|
||||
return
|
||||
slot = sim.get_signal(lhs)
|
||||
value = sim.slots[slot].next
|
||||
mask = (1 << lhs_stop) - (1 << lhs_start)
|
||||
value &= ~mask
|
||||
value |= (rhs << lhs_start) & mask
|
||||
value &= (1 << len(lhs)) - 1
|
||||
if lhs._signed and (value & (1 << (len(lhs) - 1))):
|
||||
value |= -1 << (len(lhs) - 1)
|
||||
sim.slots[slot].set(value)
|
||||
elif isinstance(lhs, Slice):
|
||||
_eval_assign_inner(sim, lhs.value, lhs_start + lhs.start, rhs, rhs_len)
|
||||
elif isinstance(lhs, Concat):
|
||||
part_stop = 0
|
||||
for part in lhs.parts:
|
||||
part_start = part_stop
|
||||
part_len = len(part)
|
||||
part_stop = part_start + part_len
|
||||
if lhs_start >= part_stop:
|
||||
continue
|
||||
if lhs_start + rhs_len <= part_start:
|
||||
continue
|
||||
if lhs_start < part_start:
|
||||
part_lhs_start = 0
|
||||
part_rhs_start = part_start - lhs_start
|
||||
else:
|
||||
part_lhs_start = lhs_start - part_start
|
||||
part_rhs_start = 0
|
||||
if lhs_start + rhs_len >= part_stop:
|
||||
part_rhs_len = part_stop - lhs_start - part_rhs_start
|
||||
else:
|
||||
part_rhs_len = rhs_len - part_rhs_start
|
||||
part_rhs = rhs >> part_rhs_start
|
||||
part_rhs &= (1 << part_rhs_len) - 1
|
||||
_eval_assign_inner(sim, part, part_lhs_start, part_rhs, part_rhs_len)
|
||||
elif isinstance(lhs, Part):
|
||||
offset = eval_value(sim, lhs.offset)
|
||||
offset *= lhs.stride
|
||||
_eval_assign_inner(sim, lhs.value, lhs_start + offset, rhs, rhs_len)
|
||||
elif isinstance(lhs, SwitchValue):
|
||||
test = eval_value(sim, lhs.test)
|
||||
for patterns, val in lhs.cases:
|
||||
if _eval_matches(test, patterns):
|
||||
_eval_assign_inner(sim, val, lhs_start, rhs, rhs_len)
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Value {lhs!r} cannot be assigned")
|
||||
|
||||
def eval_assign(sim, lhs, value):
|
||||
_eval_assign_inner(sim, lhs, 0, value, len(lhs))
|
|
@ -43,11 +43,30 @@ class SimulatorUnitTestCase(FHDLTestCase):
|
|||
with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]):
|
||||
sim.run()
|
||||
|
||||
frag = Fragment()
|
||||
sim = Simulator(frag)
|
||||
def process():
|
||||
for isig, input in zip(isigs, inputs):
|
||||
yield isig.eq(input)
|
||||
yield Delay(0)
|
||||
if isinstance(stmt, Assign):
|
||||
yield stmt
|
||||
else:
|
||||
yield from stmt
|
||||
yield Delay(0)
|
||||
self.assertEqual((yield osig), output.value)
|
||||
sim.add_testbench(process)
|
||||
with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]):
|
||||
sim.run()
|
||||
|
||||
|
||||
def test_invert(self):
|
||||
stmt = lambda y, a: y.eq(~a)
|
||||
self.assertStatement(stmt, [C(0b0000, 4)], C(0b1111, 4))
|
||||
self.assertStatement(stmt, [C(0b1010, 4)], C(0b0101, 4))
|
||||
self.assertStatement(stmt, [C(0, 4)], C(-1, 4))
|
||||
self.assertStatement(stmt, [C(0b0000, signed(4))], C(-1, signed(4)))
|
||||
self.assertStatement(stmt, [C(0b1010, signed(4))], C(0b0101, signed(4)))
|
||||
|
||||
def test_neg(self):
|
||||
stmt = lambda y, a: y.eq(-a)
|
||||
|
@ -126,6 +145,7 @@ class SimulatorUnitTestCase(FHDLTestCase):
|
|||
|
||||
def test_floordiv(self):
|
||||
stmt = lambda y, a, b: y.eq(a // b)
|
||||
self.assertStatement(stmt, [C(2, 4), C(0, 4)], C(0, 8))
|
||||
self.assertStatement(stmt, [C(2, 4), C(1, 4)], C(2, 8))
|
||||
self.assertStatement(stmt, [C(2, 4), C(2, 4)], C(1, 8))
|
||||
self.assertStatement(stmt, [C(7, 4), C(2, 4)], C(3, 8))
|
||||
|
@ -285,6 +305,17 @@ class SimulatorUnitTestCase(FHDLTestCase):
|
|||
stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))]
|
||||
self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9))
|
||||
|
||||
def test_cat_slice_lhs(self):
|
||||
l = Signal(3)
|
||||
m = Signal(3)
|
||||
n = Signal(3)
|
||||
o = Signal(3)
|
||||
p = Signal(3)
|
||||
stmt = lambda y, a: [Cat(l, m, n, o, p).eq(-1), Cat(l, m, n, o, p)[4:11].eq(a), y.eq(Cat(p, o, n, m, l))]
|
||||
self.assertStatement(stmt, [C(0b0000000, 7)], C(0b111001000100111, 15))
|
||||
self.assertStatement(stmt, [C(0b1001011, 7)], C(0b111111010110111, 15))
|
||||
self.assertStatement(stmt, [C(0b1111111, 7)], C(0b111111111111111, 15))
|
||||
|
||||
def test_nested_cat_lhs(self):
|
||||
l = Signal(3)
|
||||
m = Signal(3)
|
||||
|
@ -327,6 +358,16 @@ class SimulatorUnitTestCase(FHDLTestCase):
|
|||
self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001))
|
||||
self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001))
|
||||
|
||||
def test_array_lhs_heterogenous(self):
|
||||
l = Signal(1, init=1)
|
||||
m = Signal(3, init=4)
|
||||
n = Signal(5, init=7)
|
||||
array = Array([l, m, n])
|
||||
stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))]
|
||||
self.assertStatement(stmt, [C(0), C(0b000)], C(0b001111000, 9))
|
||||
self.assertStatement(stmt, [C(1), C(0b010)], C(0b001110101, 9))
|
||||
self.assertStatement(stmt, [C(2), C(0b100)], C(0b001001001, 9))
|
||||
|
||||
def test_array_lhs_oob(self):
|
||||
l = Signal(3)
|
||||
m = Signal(3)
|
||||
|
|
Loading…
Reference in a new issue