amaranth/nmigen/back/pysim.py
2018-12-13 18:02:46 +00:00

373 lines
14 KiB
Python

from vcd import VCDWriter
from ..tools import flatten
from ..fhdl.ast import *
from ..fhdl.xfrm import ValueTransformer, StatementTransformer
# Public names exported by ``from <this module> import *``.
# NOTE(review): Delay and Passive are used by Simulator but are not defined in
# the visible part of this file — confirm they are defined or imported elsewhere.
__all__ = ["Simulator", "Delay", "Passive"]
class _State:
__slots__ = ("curr", "curr_dirty", "next", "next_dirty")
def __init__(self):
self.curr = ValueDict()
self.next = ValueDict()
self.curr_dirty = ValueSet()
self.next_dirty = ValueSet()
def get(self, signal):
return self.curr[signal]
def set_curr(self, signal, value):
assert isinstance(value, Const)
if self.curr[signal].value != value.value:
self.curr_dirty.add(signal)
self.curr[signal] = value
def set_next(self, signal, value):
assert isinstance(value, Const)
if self.next[signal].value != value.value:
self.next_dirty.add(signal)
self.next[signal] = value
def commit(self, signal):
old_value = self.curr[signal]
if self.curr[signal].value != self.next[signal].value:
self.next_dirty.remove(signal)
self.curr_dirty.add(signal)
self.curr[signal] = self.next[signal]
new_value = self.curr[signal]
return old_value, new_value
def iter_dirty(self):
dirty, self.dirty = self.dirty, ValueSet()
for signal in dirty:
yield signal, self.curr[signal], self.next[signal]
class _RHSValueCompiler(ValueTransformer):
    """Compile a right-hand-side value into a closure ``state -> Const``.

    Every signal read by a compiled expression is added to ``sensitivity``,
    so the caller knows which signal changes must re-evaluate it.
    """

    # Binary operators: operator string -> function of the two operand values.
    # Comparisons produce 0/1 so the result fits the 1-bit result shape.
    _BINARY_OPS = {
        "+":  lambda a, b: a + b,
        "-":  lambda a, b: a - b,
        "*":  lambda a, b: a * b,
        "&":  lambda a, b: a & b,
        "|":  lambda a, b: a | b,
        "^":  lambda a, b: a ^ b,
        "==": lambda a, b: int(a == b),
        "!=": lambda a, b: int(a != b),
        "<":  lambda a, b: int(a < b),
        "<=": lambda a, b: int(a <= b),
        ">":  lambda a, b: int(a > b),
        ">=": lambda a, b: int(a >= b),
    }

    def __init__(self, sensitivity):
        self.sensitivity = sensitivity

    def on_Const(self, value):
        return lambda state: value

    def on_Signal(self, value):
        self.sensitivity.add(value)
        return lambda state: state.get(value)

    def on_ClockSignal(self, value):
        raise NotImplementedError

    def on_ResetSignal(self, value):
        raise NotImplementedError

    def on_Operator(self, value):
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: Const(~arg(state).value, shape)
            if value.op == "-":
                return lambda state: Const(-arg(state).value, shape)
        elif len(value.operands) == 2:
            binop = self._BINARY_OPS.get(value.op)
            if binop is not None:
                lhs, rhs = map(self, value.operands)
                return lambda state: Const(binop(lhs(state).value, rhs(state).value), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                sel, val1, val0 = map(self, value.operands)
                # Re-wrap so the result carries the mux's own shape, like
                # every other operator result.
                return lambda state: Const(
                    (val1 if sel(state).value else val0)(state).value, shape)
        raise NotImplementedError("Operator '{}' not implemented".format(value.op))

    def on_Slice(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: Const((arg(state).value >> shift) & mask, shape)

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)

        def evaluate(state):
            result = 0
            for part_offset, mask, part in parts:
                result |= (part(state).value & mask) << part_offset
            return Const(result, shape)
        return evaluate

    def on_Repl(self, value):
        shape = value.shape()
        width = len(value.value)
        mask = (1 << width) - 1
        count = value.count
        opnd = self(value.value)

        def evaluate(state):
            # Mask the operand as on_Cat does: if its value is negative,
            # OR-ing it unmasked would set every higher bit of the result.
            bits = opnd(state).value & mask
            result = 0
            for _ in range(count):
                result = (result << width) | bits
            return Const(result, shape)
        return evaluate
class _StatementCompiler(StatementTransformer):
    """Compile statements into closures ``state -> None`` that schedule next values.

    ``sensitivity`` accumulates every signal read by any compiled right-hand side.
    """

    def __init__(self):
        self.sensitivity = ValueSet()
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)

    def lhs_compiler(self, value):
        # TODO: only plain Signal targets are supported (on_Assign asserts this).
        return lambda state, arg: state.set_next(value, arg)

    def on_Assign(self, stmt):
        assert isinstance(stmt.lhs, Signal)
        shape = stmt.lhs.shape()
        lhs = self.lhs_compiler(stmt.lhs)
        rhs = self.rhs_compiler(stmt.rhs)

        def run(state):
            # Re-wrap so the assigned value takes the target signal's shape.
            lhs(state, Const(rhs(state).value, shape))
        return run

    @staticmethod
    def _parse_case(pattern):
        """Turn a case pattern such as ``"1-0"`` into an integer ``(mask, value)``.

        ``-`` marks a don't-care bit: it is 0 in both the mask and the value.
        """
        mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
        value = int("".join("0" if b == "-" else b for b in pattern), 2)
        return mask, value

    def on_Switch(self, stmt):
        test = self.rhs_compiler(stmt.test)
        # Store plain (mask, value, body) tuples rather than per-case lambdas:
        # lambdas closing over the loop variables would all see the last
        # iteration's mask/value, making every case test identical.
        cases = []
        for pattern, stmts in stmt.cases.items():
            mask, value = self._parse_case(pattern)
            cases.append((mask, value, self.on_statements(stmts)))

        def run(state):
            test_value = test(state).value
            # First matching case wins.
            for mask, value, body in cases:
                if (test_value & mask) == value:
                    body(state)
                    return
        return run

    def on_statements(self, stmts):
        compiled = [self.on_statement(stmt) for stmt in stmts]

        def run(state):
            for stmt in compiled:
                stmt(state)
        return run
class Simulator:
    """Event-driven simulator for a prepared fragment hierarchy.

    Use as a context manager: ``__enter__`` initialises signal state (and the
    optional VCD writer) and returns the simulator; ``__exit__`` closes the
    VCD writer. Processes are generators that yield ``Delay``, ``Passive`` or
    ``Assign`` statements (see ``_run_process``).
    """

    def __init__(self, fragment=None, vcd_file=None):
        self._fragments       = {}          # fragment -> hierarchy
        self._domains         = {}          # str -> ClockDomain
        self._domain_triggers = ValueDict() # Signal -> str
        self._domain_signals  = {}          # str -> {Signal}
        self._signals         = ValueSet()  # {Signal}
        self._comb_signals    = ValueSet()  # {Signal}
        self._sync_signals    = ValueSet()  # {Signal}
        self._user_signals    = ValueSet()  # {Signal}

        self._started         = False
        self._timestamp       = 0.
        self._state           = _State()

        self._processes       = set()       # {process}
        self._passive         = set()       # {process}
        self._suspended       = {}          # process -> until

        self._handlers        = ValueDict() # Signal -> lambda

        self._vcd_file        = vcd_file
        self._vcd_writer      = None
        self._vcd_signals     = ValueDict() # signal -> set(vcd_signal)

        if fragment is not None:
            fragment = fragment.prepare()
            self._add_fragment(fragment)

            self._domains = fragment.domains
            for domain, cd in self._domains.items():
                # Edges on clk (and rst, if present) commit the domain's signals.
                self._domain_triggers[cd.clk] = domain
                if cd.rst is not None:
                    self._domain_triggers[cd.rst] = domain
                self._domain_signals[domain] = ValueSet()

    def _add_fragment(self, fragment, hierarchy=("top",)):
        """Record ``fragment`` and its subfragments with their hierarchy paths."""
        self._fragments[fragment] = hierarchy
        for subfragment, name in fragment.subfragments:
            self._add_fragment(subfragment, (*hierarchy, name))

    def add_process(self, fn):
        """Register a process (a generator object) to be run by ``step``."""
        self._processes.add(fn)

    def add_clock(self, domain, period):
        """Add a passive process toggling ``domain``'s clock every ``period / 2``."""
        clk = self._domains[domain].clk
        half_period = period / 2

        def clk_process():
            yield Passive()
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process())

    def _signal_name_in_fragment(self, fragment, signal):
        """Return the VCD name of ``signal``, prefixed if it is a subfragment port."""
        for subfragment, name in fragment.subfragments:
            if signal in subfragment.ports:
                return "{}_{}".format(name, signal.name)
        return signal.name

    def _register_vcd_var(self, fragment, signal):
        """Register a VCD variable for ``signal`` in ``fragment``'s scope.

        A KeyError from ``register_var`` is treated as a name clash and the
        name is retried with a ``$N`` suffix until it is accepted.
        """
        if signal not in self._vcd_signals:
            self._vcd_signals[signal] = set()
        name = self._signal_name_in_fragment(fragment, signal)
        scope = ".".join(self._fragments[fragment])
        suffix = None
        while True:
            if suffix is None:
                name_suffix = name
            else:
                name_suffix = "{}${}".format(name, suffix)
            try:
                vcd_var = self._vcd_writer.register_var(
                    scope=scope, name=name_suffix,
                    var_type="wire", size=signal.nbits, init=signal.reset)
            except KeyError:
                suffix = (suffix or 0) + 1
            else:
                self._vcd_signals[signal].add(vcd_var)
                return

    def __enter__(self):
        """Initialise all signals and statement handlers; return the simulator."""
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        for fragment in self._fragments:
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    Const(signal.reset, signal.shape())
                self._state.curr_dirty.add(signal)

                # Only register VCD variables when a writer exists; without
                # this guard, simulating without a vcd_file crashed on None.
                if self._vcd_writer:
                    self._register_vcd_var(fragment, signal)

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            # One compiled handler per fragment, re-run whenever a signal it
            # reads, or one of its clocks/resets, changes.
            compiler = _StatementCompiler()
            handler = compiler(fragment.statements)
            for signal in compiler.sensitivity:
                self._handlers[signal] = handler
            for domain, cd in fragment.domains.items():
                self._handlers[cd.clk] = handler
                if cd.rst is not None:
                    self._handlers[cd.rst] = handler

        self._user_signals = self._signals - self._comb_signals - self._sync_signals
        # Return self so that ``with Simulator(...) as sim:`` binds the simulator.
        return self

    def _commit_signal(self, signal):
        """Commit ``signal``; on a 0->1 edge of a domain trigger, commit that domain's pending signals."""
        old, new = self._state.commit(signal)
        if old.value == 0 and new.value == 1 and signal in self._domain_triggers:
            domain = self._domain_triggers[signal]
            # Iterate a snapshot: committing removes entries from next_dirty,
            # and mutating a collection while iterating it is unsafe
            # (RuntimeError for built-in sets/dicts).
            for sync_signal in list(self._state.next_dirty):
                if sync_signal in self._domain_signals[domain]:
                    self._commit_signal(sync_signal)

        if self._vcd_writer:
            for vcd_signal in self._vcd_signals[signal]:
                self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new.value)

    def _handle_event(self):
        """Run handlers for changed signals, then commit comb/user signals."""
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._handlers:
                self._handlers[signal](self._state)

        # Snapshot for the same reason as in _commit_signal. Sync signals stay
        # pending here; they are committed on their domain's trigger edge.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals or signal in self._user_signals:
                self._commit_signal(signal)

    def _force_signal(self, signal, value):
        """Set and immediately commit a comb/user signal (used by processes)."""
        assert signal in self._comb_signals or signal in self._user_signals
        self._state.set_next(signal, value)
        self._commit_signal(signal)

    def _run_process(self, proc):
        """Advance ``proc`` by one yield and act on the statement it produced."""
        try:
            stmt = proc.send(None)
        except StopIteration:
            # The process finished: forget it everywhere. It may never have
            # been passive, and step() deletes it from _suspended before
            # resuming it, so these removals must tolerate absence.
            self._processes.remove(proc)
            self._passive.discard(proc)
            self._suspended.pop(proc, None)
            return

        if isinstance(stmt, Delay):
            self._suspended[proc] = self._timestamp + stmt.interval
        elif isinstance(stmt, Passive):
            self._passive.add(proc)
        elif isinstance(stmt, Assign):
            assert isinstance(stmt.lhs, Signal)
            assert isinstance(stmt.rhs, Const)
            self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape()))
        else:
            raise TypeError("Received unsupported statement '{!r}' from process {}"
                            .format(stmt, proc))

    def step(self, run_passive=False):
        """Advance the simulation by one scheduling decision.

        Returns True if something ran, False if there is nothing left to do.
        """
        # Are there any delta cycles we should run? Each one advances the
        # timestamp by 1e-10 s, i.e. one tick of the 100 ps VCD timescale.
        while self._state.curr_dirty:
            self._timestamp += 1e-10
            self._handle_event()

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            proc = (self._processes - set(self._suspended)).pop()
            self._run_process(proc)
            return True

        # All processes are suspended. Are any of them active?
        # The _suspended check avoids min() on an empty dict when
        # run_passive is True but no processes exist.
        if self._suspended and (len(self._processes) > len(self._passive) or run_passive):
            # Schedule the one with the lowest deadline.
            proc, deadline = min(self._suspended.items(), key=lambda x: x[1])
            del self._suspended[proc]
            self._timestamp = deadline
            self._run_process(proc)
            return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run_until(self, deadline, run_passive=False):
        """Step until the timestamp reaches ``deadline``; False if work ran out first."""
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False
        return True

    def __exit__(self, *args):
        if self._vcd_writer:
            self._vcd_writer.close(self._timestamp * 1e10)