back.pysim: new simulator backend (WIP).

This commit is contained in:
whitequark 2018-12-13 18:00:05 +00:00
parent 71f1f717c4
commit fb27c2520b
9 changed files with 437 additions and 17 deletions

2
.gitignore vendored
View file

@ -2,5 +2,7 @@
*.egg-info *.egg-info
*.il *.il
*.v *.v
*.vcd
*.gtkw
/.coverage /.coverage
/htmlcov /htmlcov

View file

@ -1,5 +1,5 @@
from nmigen.fhdl import * from nmigen.fhdl import *
from nmigen.back import rtlil, verilog from nmigen.back import rtlil, verilog, pysim
class ClockDivisor: class ClockDivisor:
@ -16,5 +16,10 @@ class ClockDivisor:
ctr = ClockDivisor(factor=16) ctr = ClockDivisor(factor=16)
frag = ctr.get_fragment(platform=None) frag = ctr.get_fragment(platform=None)
# print(rtlil.convert(frag, ports=[ctr.o])) # print(rtlil.convert(frag, ports=[ctr.o]))
print(verilog.convert(frag, ports=[ctr.o])) print(verilog.convert(frag, ports=[ctr.o]))
sim = pysim.Simulator(frag, vcd_file=open("clkdiv.vcd", "w"))
sim.add_clock("sync", 1e-6)
with sim: sim.run_until(100e-6, run_passive=True)

372
nmigen/back/pysim.py Normal file
View file

@ -0,0 +1,372 @@
from vcd import VCDWriter
from ..tools import flatten
from ..fhdl.ast import *
from ..fhdl.xfrm import ValueTransformer, StatementTransformer
__all__ = ["Simulator", "Delay", "Passive"]
class _State:
__slots__ = ("curr", "curr_dirty", "next", "next_dirty")
def __init__(self):
self.curr = ValueDict()
self.next = ValueDict()
self.curr_dirty = ValueSet()
self.next_dirty = ValueSet()
def get(self, signal):
return self.curr[signal]
def set_curr(self, signal, value):
assert isinstance(value, Const)
if self.curr[signal].value != value.value:
self.curr_dirty.add(signal)
self.curr[signal] = value
def set_next(self, signal, value):
assert isinstance(value, Const)
if self.next[signal].value != value.value:
self.next_dirty.add(signal)
self.next[signal] = value
def commit(self, signal):
old_value = self.curr[signal]
if self.curr[signal].value != self.next[signal].value:
self.next_dirty.remove(signal)
self.curr_dirty.add(signal)
self.curr[signal] = self.next[signal]
new_value = self.curr[signal]
return old_value, new_value
def iter_dirty(self):
dirty, self.dirty = self.dirty, ValueSet()
for signal in dirty:
yield signal, self.curr[signal], self.next[signal]
class _RHSValueCompiler(ValueTransformer):
def __init__(self, sensitivity):
self.sensitivity = sensitivity
def on_Const(self, value):
return lambda state: value
def on_Signal(self, value):
self.sensitivity.add(value)
return lambda state: state.get(value)
def on_ClockSignal(self, value):
raise NotImplementedError
def on_ResetSignal(self, value):
raise NotImplementedError
def on_Operator(self, value):
shape = value.shape()
if len(value.operands) == 1:
arg, = map(self, value.operands)
if value.op == "~":
return lambda state: Const(~arg(state).value, shape)
elif value.op == "-":
return lambda state: Const(-arg(state).value, shape)
elif len(value.operands) == 2:
lhs, rhs = map(self, value.operands)
if value.op == "+":
return lambda state: Const(lhs(state).value + rhs(state).value, shape)
if value.op == "-":
return lambda state: Const(lhs(state).value - rhs(state).value, shape)
if value.op == "&":
return lambda state: Const(lhs(state).value & rhs(state).value, shape)
if value.op == "|":
return lambda state: Const(lhs(state).value | rhs(state).value, shape)
if value.op == "^":
return lambda state: Const(lhs(state).value ^ rhs(state).value, shape)
elif value.op == "==":
lhs, rhs = map(self, value.operands)
return lambda state: Const(lhs(state).value == rhs(state).value, shape)
elif len(value.operands) == 3:
if value.op == "m":
sel, val1, val0 = map(self, value.operands)
return lambda state: val1(state) if sel(state).value else val0(state)
raise NotImplementedError("Operator '{}' not implemented".format(value.op))
def on_Slice(self, value):
shape = value.shape()
arg = self(value.value)
shift = value.start
mask = (1 << (value.end - value.start)) - 1
return lambda state: Const((arg(state).value >> shift) & mask, shape)
def on_Part(self, value):
raise NotImplementedError
def on_Cat(self, value):
shape = value.shape()
parts = []
offset = 0
for opnd in value.operands:
parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
offset += len(opnd)
def eval(state):
result = 0
for offset, mask, opnd in parts:
result |= (opnd(state).value & mask) << offset
return Const(result, shape)
return eval
def on_Repl(self, value):
shape = value.shape()
offset = len(value.value)
mask = (1 << len(value.value)) - 1
count = value.count
opnd = self(value.value)
def eval(state):
result = 0
for _ in range(count):
result <<= offset
result |= opnd(state).value
return Const(result, shape)
return eval
class _StatementCompiler(StatementTransformer):
def __init__(self):
self.sensitivity = ValueSet()
self.rhs_compiler = _RHSValueCompiler(self.sensitivity)
def lhs_compiler(self, value):
# TODO
return lambda state, arg: state.set_next(value, arg)
def on_Assign(self, stmt):
assert isinstance(stmt.lhs, Signal)
shape = stmt.lhs.shape()
lhs = self.lhs_compiler(stmt.lhs)
rhs = self.rhs_compiler(stmt.rhs)
def run(state):
lhs(state, Const(rhs(state).value, shape))
return run
def on_Switch(self, stmt):
test = self.rhs_compiler(stmt.test)
cases = []
for value, stmts in stmt.cases.items():
if "-" in value:
mask = "".join("0" if b == "-" else "1" for b in value)
value = "".join("0" if b == "-" else b for b in value)
else:
mask = "1" * len(value)
mask = int(mask, 2)
value = int(value, 2)
cases.append((lambda test: test & mask == value,
self.on_statements(stmts)))
def run(state):
test_value = test(state).value
for check, body in cases:
if check(test_value):
body(state)
return
return run
def on_statements(self, stmts):
stmts = [self.on_statement(stmt) for stmt in stmts]
def run(state):
for stmt in stmts:
stmt(state)
return run
class Simulator:
def __init__(self, fragment=None, vcd_file=None):
self._fragments = {} # fragment -> hierarchy
self._domains = {} # str -> ClockDomain
self._domain_triggers = ValueDict() # Signal -> str
self._domain_signals = {} # str -> {Signal}
self._signals = ValueSet() # {Signal}
self._comb_signals = ValueSet() # {Signal}
self._sync_signals = ValueSet() # {Signal}
self._user_signals = ValueSet() # {Signal}
self._started = False
self._timestamp = 0.
self._state = _State()
self._processes = set() # {process}
self._passive = set() # {process}
self._suspended = {} # process -> until
self._handlers = ValueDict() # Signal -> lambda
self._vcd_file = vcd_file
self._vcd_writer = None
self._vcd_signals = ValueDict() # signal -> set(vcd_signal)
if fragment is not None:
fragment = fragment.prepare()
self._add_fragment(fragment)
self._domains = fragment.domains
for domain, cd in self._domains.items():
self._domain_triggers[cd.clk] = domain
if cd.rst is not None:
self._domain_triggers[cd.rst] = domain
self._domain_signals[domain] = ValueSet()
def _add_fragment(self, fragment, hierarchy=("top",)):
self._fragments[fragment] = hierarchy
for subfragment, name in fragment.subfragments:
self._add_fragment(subfragment, (*hierarchy, name))
def add_process(self, fn):
self._processes.add(fn)
def add_clock(self, domain, period):
clk = self._domains[domain].clk
half_period = period / 2
def clk_process():
yield Passive()
while True:
yield clk.eq(1)
yield Delay(half_period)
yield clk.eq(0)
yield Delay(half_period)
self.add_process(clk_process())
def _signal_name_in_fragment(self, fragment, signal):
for subfragment, name in fragment.subfragments:
if signal in subfragment.ports:
return "{}_{}".format(name, signal.name)
return signal.name
def __enter__(self):
if self._vcd_file:
self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
comment="Generated by nMigen")
for fragment in self._fragments:
for signal in fragment.iter_signals():
self._signals.add(signal)
self._state.curr[signal] = self._state.next[signal] = \
Const(signal.reset, signal.shape())
self._state.curr_dirty.add(signal)
if signal not in self._vcd_signals:
self._vcd_signals[signal] = set()
name = self._signal_name_in_fragment(fragment, signal)
suffix = None
while True:
try:
if suffix is None:
name_suffix = name
else:
name_suffix = "{}${}".format(name, suffix)
self._vcd_signals[signal].add(self._vcd_writer.register_var(
scope=".".join(self._fragments[fragment]), name=name_suffix,
var_type="wire", size=signal.nbits, init=signal.reset))
break
except KeyError:
suffix = (suffix or 0) + 1
for domain, signals in fragment.drivers.items():
if domain is None:
self._comb_signals.update(signals)
else:
self._sync_signals.update(signals)
self._domain_signals[domain].update(signals)
compiler = _StatementCompiler()
handler = compiler(fragment.statements)
for signal in compiler.sensitivity:
self._handlers[signal] = handler
for domain, cd in fragment.domains.items():
self._handlers[cd.clk] = handler
if cd.rst is not None:
self._handlers[cd.rst] = handler
self._user_signals = self._signals - self._comb_signals - self._sync_signals
def _commit_signal(self, signal):
old, new = self._state.commit(signal)
if old.value == 0 and new.value == 1 and signal in self._domain_triggers:
domain = self._domain_triggers[signal]
for sync_signal in self._state.next_dirty:
if sync_signal in self._domain_signals[domain]:
self._commit_signal(sync_signal)
if self._vcd_writer:
for vcd_signal in self._vcd_signals[signal]:
self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new.value)
def _handle_event(self):
while self._state.curr_dirty:
signal = self._state.curr_dirty.pop()
if signal in self._handlers:
self._handlers[signal](self._state)
for signal in self._state.next_dirty:
if signal in self._comb_signals or signal in self._user_signals:
self._commit_signal(signal)
def _force_signal(self, signal, value):
assert signal in self._comb_signals or signal in self._user_signals
self._state.set_next(signal, value)
self._commit_signal(signal)
def _run_process(self, proc):
try:
stmt = proc.send(None)
except StopIteration:
self._processes.remove(proc)
self._passive.remove(proc)
self._suspended.remove(proc)
return
if isinstance(stmt, Delay):
self._suspended[proc] = self._timestamp + stmt.interval
elif isinstance(stmt, Passive):
self._passive.add(proc)
elif isinstance(stmt, Assign):
assert isinstance(stmt.lhs, Signal)
assert isinstance(stmt.rhs, Const)
self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape()))
else:
raise TypeError("Received unsupported statement '{!r}' from process {}"
.format(stmt, proc))
def step(self, run_passive=False):
# Are there any delta cycles we should run?
while self._state.curr_dirty:
self._timestamp += 1e-10
self._handle_event()
# Are there any processes that haven't had a chance to run yet?
if len(self._processes) > len(self._suspended):
# Schedule an arbitrary one.
proc = (self._processes - set(self._suspended)).pop()
self._run_process(proc)
return True
# All processes are suspended. Are any of them active?
if len(self._processes) > len(self._passive) or run_passive:
# Schedule the one with the lowest deadline.
proc, deadline = min(self._suspended.items(), key=lambda x: x[1])
del self._suspended[proc]
self._timestamp = deadline
self._run_process(proc)
return True
# No processes, or all processes are passive. Nothing to do!
return False
def run_until(self, deadline, run_passive=False):
while self._timestamp < deadline:
if not self.step(run_passive):
return False
return True
def __exit__(self, *args):
if self._vcd_writer:
self._vcd_writer.close(self._timestamp * 1e10)

View file

@ -10,7 +10,7 @@ from ..tools import *
__all__ = [ __all__ = [
"Value", "Const", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Value", "Const", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
"Signal", "ClockSignal", "ResetSignal", "Signal", "ClockSignal", "ResetSignal",
"Statement", "Assign", "Switch", "Statement", "Assign", "Switch", "Delay", "Passive",
"ValueKey", "ValueDict", "ValueSet", "ValueKey", "ValueDict", "ValueSet",
] ]
@ -216,17 +216,23 @@ class Const(Value):
nbits : int nbits : int
signed : bool signed : bool
""" """
src_loc = None
def __init__(self, value, shape=None): def __init__(self, value, shape=None):
super().__init__()
self.value = int(value) self.value = int(value)
if shape is None: if shape is None:
shape = self.value.bit_length(), self.value < 0 shape = bits_for(self.value), self.value < 0
if isinstance(shape, int): if isinstance(shape, int):
shape = shape, self.value < 0 shape = shape, self.value < 0
self.nbits, self.signed = shape self.nbits, self.signed = shape
if not isinstance(self.nbits, int) or self.nbits < 0: if not isinstance(self.nbits, int) or self.nbits < 0:
raise TypeError("Width must be a positive integer") raise TypeError("Width must be a positive integer")
mask = (1 << self.nbits) - 1
self.value &= mask
if self.signed and self.value >> (self.nbits - 1):
self.value |= ~mask
def shape(self): def shape(self):
return self.nbits, self.signed return self.nbits, self.signed
@ -347,6 +353,8 @@ class Slice(Value):
raise IndexError("Cannot end slice {} bits into {}-bit value".format(end, n)) raise IndexError("Cannot end slice {} bits into {}-bit value".format(end, n))
if end < 0: if end < 0:
end += n end += n
if start > end:
raise IndexError("Slice start {} must be less than slice end {}".format(start, end))
super().__init__() super().__init__()
self.value = Value.wrap(value) self.value = Value.wrap(value)
@ -680,6 +688,25 @@ class Switch(Statement):
return "(switch {!r} {})".format(self.test, " ".join(cases)) return "(switch {!r} {})".format(self.test, " ".join(cases))
class Delay(Statement):
def __init__(self, interval):
self.interval = float(interval)
def _rhs_signals(self):
return ValueSet()
def __repr__(self):
return "(delay {:.3}us)".format(self.interval * 10e6)
class Passive(Statement):
def _rhs_signals(self):
return ValueSet()
def __repr__(self):
return "(passive)"
class ValueKey: class ValueKey:
def __init__(self, value): def __init__(self, value):
self.value = Value.wrap(value) self.value = Value.wrap(value)

View file

@ -52,6 +52,11 @@ class Fragment:
signals = ValueSet() signals = ValueSet()
signals |= self.ports.keys() signals |= self.ports.keys()
for domain, domain_signals in self.drivers.items(): for domain, domain_signals in self.drivers.items():
if domain is not None:
cd = self.domains[domain]
signals.add(cd.clk)
if cd.rst is not None:
signals.add(cd.rst)
signals |= domain_signals signals |= domain_signals
return signals return signals

View file

@ -116,7 +116,7 @@ class DSLTestCase(FHDLTestCase):
( (
(switch (cat (sig s1) (sig s2)) (switch (cat (sig s1) (sig s2))
(case -1 (eq (sig c1) (const 1'd1))) (case -1 (eq (sig c1) (const 1'd1)))
(case 1- (eq (sig c2) (const 0'd0))) (case 1- (eq (sig c2) (const 1'd0)))
) )
) )
""") """)
@ -134,7 +134,7 @@ class DSLTestCase(FHDLTestCase):
( (
(switch (cat (sig s1) (sig s2)) (switch (cat (sig s1) (sig s2))
(case -1 (eq (sig c1) (const 1'd1))) (case -1 (eq (sig c1) (const 1'd1)))
(case 1- (eq (sig c2) (const 0'd0))) (case 1- (eq (sig c2) (const 1'd0)))
(case -- (eq (sig c3) (const 1'd1))) (case -- (eq (sig c3) (const 1'd1)))
) )
) )

View file

@ -59,10 +59,10 @@ class ValueTestCase(FHDLTestCase):
class ConstTestCase(FHDLTestCase): class ConstTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
self.assertEqual(Const(0).shape(), (0, False)) self.assertEqual(Const(0).shape(), (1, False))
self.assertEqual(Const(1).shape(), (1, False)) self.assertEqual(Const(1).shape(), (1, False))
self.assertEqual(Const(10).shape(), (4, False)) self.assertEqual(Const(10).shape(), (4, False))
self.assertEqual(Const(-10).shape(), (4, True)) self.assertEqual(Const(-10).shape(), (5, True))
self.assertEqual(Const(1, 4).shape(), (4, False)) self.assertEqual(Const(1, 4).shape(), (4, False))
self.assertEqual(Const(1, (4, True)).shape(), (4, True)) self.assertEqual(Const(1, (4, True)).shape(), (4, True))
@ -70,12 +70,15 @@ class ConstTestCase(FHDLTestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
Const(1, -1) Const(1, -1)
def test_normalization(self):
self.assertEqual(Const(0b10110, (5, True)).value, -10)
def test_value(self): def test_value(self):
self.assertEqual(Const(10).value, 10) self.assertEqual(Const(10).value, 10)
def test_repr(self): def test_repr(self):
self.assertEqual(repr(Const(10)), "(const 4'd10)") self.assertEqual(repr(Const(10)), "(const 4'd10)")
self.assertEqual(repr(Const(-10)), "(const 4'sd-10)") self.assertEqual(repr(Const(-10)), "(const 5'sd-10)")
def test_hash(self): def test_hash(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
@ -205,7 +208,7 @@ class OperatorTestCase(FHDLTestCase):
def test_mux(self): def test_mux(self):
s = Const(0) s = Const(0)
v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False))) v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False)))
self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))") self.assertEqual(repr(v1), "(m (const 1'd0) (const 4'd0) (const 6'd0))")
self.assertEqual(v1.shape(), (6, False)) self.assertEqual(v1.shape(), (6, False))
v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True))) v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True)))
self.assertEqual(v2.shape(), (6, True)) self.assertEqual(v2.shape(), (6, True))
@ -216,7 +219,7 @@ class OperatorTestCase(FHDLTestCase):
def test_bool(self): def test_bool(self):
v = Const(0).bool() v = Const(0).bool()
self.assertEqual(repr(v), "(b (const 0'd0))") self.assertEqual(repr(v), "(b (const 1'd0))")
self.assertEqual(v.shape(), (1, False)) self.assertEqual(v.shape(), (1, False))
def test_hash(self): def test_hash(self):
@ -243,7 +246,7 @@ class CatTestCase(FHDLTestCase):
c2 = Cat(Const(10), Const(1)) c2 = Cat(Const(10), Const(1))
self.assertEqual(c2.shape(), (5, False)) self.assertEqual(c2.shape(), (5, False))
c3 = Cat(Const(10), Const(1), Const(0)) c3 = Cat(Const(10), Const(1), Const(0))
self.assertEqual(c3.shape(), (5, False)) self.assertEqual(c3.shape(), (6, False))
def test_repr(self): def test_repr(self):
c1 = Cat(Const(10), Const(1)) c1 = Cat(Const(10), Const(1))

View file

@ -32,7 +32,7 @@ class DomainRenamerTestCase(FHDLTestCase):
( (
(eq (sig s1) (clk pix)) (eq (sig s1) (clk pix))
(eq (rst pix) (sig s2)) (eq (rst pix) (sig s2))
(eq (sig s3) (const 0'd0)) (eq (sig s3) (const 1'd0))
(eq (sig s4) (clk other)) (eq (sig s4) (clk other))
(eq (sig s5) (rst other)) (eq (sig s5) (rst other))
) )
@ -127,7 +127,7 @@ class ResetInserterTestCase(FHDLTestCase):
self.assertRepr(f.statements, """ self.assertRepr(f.statements, """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
(eq (sig s2) (const 0'd0)) (eq (sig s2) (const 1'd0))
(switch (sig c1) (switch (sig c1)
(case 1 (eq (sig s2) (const 1'd1))) (case 1 (eq (sig s2) (const 1'd1)))
) )
@ -144,7 +144,7 @@ class ResetInserterTestCase(FHDLTestCase):
f = ResetInserter(self.c1)(f) f = ResetInserter(self.c1)(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements, """
( (
(eq (sig s2) (const 0'd0)) (eq (sig s2) (const 1'd0))
(switch (sig c1) (switch (sig c1)
(case 1 (eq (sig s2) (const 1'd1))) (case 1 (eq (sig s2) (const 1'd1)))
) )
@ -161,7 +161,7 @@ class ResetInserterTestCase(FHDLTestCase):
f = ResetInserter(self.c1)(f) f = ResetInserter(self.c1)(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements, """
( (
(eq (sig s3) (const 0'd0)) (eq (sig s3) (const 1'd0))
(switch (sig c1) (switch (sig c1)
(case 1 ) (case 1 )
) )
@ -206,7 +206,7 @@ class CEInserterTestCase(FHDLTestCase):
self.assertRepr(f.statements, """ self.assertRepr(f.statements, """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
(eq (sig s2) (const 0'd0)) (eq (sig s2) (const 1'd0))
(switch (sig c1) (switch (sig c1)
(case 0 (eq (sig s2) (sig s2))) (case 0 (eq (sig s2) (sig s2)))
) )

View file

@ -13,5 +13,11 @@ setup(
description="Python toolbox for building complex digital hardware", description="Python toolbox for building complex digital hardware",
#long_description="""TODO""", #long_description="""TODO""",
license="BSD", license="BSD",
install_requires=["pyvcd"],
packages=find_packages(), packages=find_packages(),
project_urls={
#"Documentation": "https://glasgow.readthedocs.io/",
"Source Code": "https://github.com/m-labs/nmigen",
"Bug Tracker": "https://github.com/m-labs/nmigen/issues",
}
) )