amaranth/nmigen/sim/_pyrtl.py
Marcelina Kościelnicka 44318149e0
sim._pyrtl: mask Mux selection operand.
Otherwise it behaves funny when it's eg. the result of operator ~.
2020-11-14 15:22:34 +00:00

457 lines
16 KiB
Python

import os
import tempfile
from contextlib import contextmanager
from ..hdl import *
from ..hdl.ast import SignalSet
from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
from ._base import BaseProcess
__all__ = ["PyRTLProcess"]
class PyRTLProcess(BaseProcess):
__slots__ = ("is_comb", "runnable", "passive", "run")
def __init__(self, *, is_comb):
self.is_comb = is_comb
self.reset()
def reset(self):
self.runnable = self.is_comb
self.passive = True
class _PythonEmitter:
def __init__(self):
self._buffer = []
self._suffix = 0
self._level = 0
def append(self, code):
self._buffer.append(" " * self._level)
self._buffer.append(code)
self._buffer.append("\n")
@contextmanager
def indent(self):
self._level += 1
yield
self._level -= 1
def flush(self, indent=""):
code = "".join(self._buffer)
self._buffer.clear()
return code
def gen_var(self, prefix):
name = f"{prefix}_{self._suffix}"
self._suffix += 1
return name
def def_var(self, prefix, value):
name = self.gen_var(prefix)
self.append(f"{name} = {value}")
return name
class _Compiler:
def __init__(self, state, emitter):
self.state = state
self.emitter = emitter
class _ValueCompiler(ValueVisitor, _Compiler):
helpers = {
"sign": lambda value, sign: value | sign if value & sign else value,
"zdiv": lambda lhs, rhs: 0 if rhs == 0 else lhs // rhs,
"zmod": lambda lhs, rhs: 0 if rhs == 0 else lhs % rhs,
}
def on_ClockSignal(self, value):
raise NotImplementedError # :nocov:
def on_ResetSignal(self, value):
raise NotImplementedError # :nocov:
def on_AnyConst(self, value):
raise NotImplementedError # :nocov:
def on_AnySeq(self, value):
raise NotImplementedError # :nocov:
def on_Sample(self, value):
raise NotImplementedError # :nocov:
def on_Initial(self, value):
raise NotImplementedError # :nocov:
class _RHSValueCompiler(_ValueCompiler):
def __init__(self, state, emitter, *, mode, inputs=None):
super().__init__(state, emitter)
assert mode in ("curr", "next")
self.mode = mode
# If not None, `inputs` gets populated with RHS signals.
self.inputs = inputs
def on_Const(self, value):
return f"{value.value}"
def on_Signal(self, value):
if self.inputs is not None:
self.inputs.add(value)
if self.mode == "curr":
return f"slots[{self.state.get_signal(value)}].{self.mode}"
else:
return f"next_{self.state.get_signal(value)}"
def on_Operator(self, value):
def mask(value):
value_mask = (1 << len(value)) - 1
return f"({value_mask} & {self(value)})"
def sign(value):
if value.shape().signed:
return f"sign({mask(value)}, {-1 << (len(value) - 1)})"
else: # unsigned
return mask(value)
if len(value.operands) == 1:
arg, = value.operands
if value.operator == "~":
return f"(~{self(arg)})"
if value.operator == "-":
return f"(-{sign(arg)})"
if value.operator == "b":
return f"bool({mask(arg)})"
if value.operator == "r|":
return f"(0 != {mask(arg)})"
if value.operator == "r&":
return f"({(1 << len(arg)) - 1} == {mask(arg)})"
if value.operator == "r^":
# Believe it or not, this is the fastest way to compute a sideways XOR in Python.
return f"(format({mask(arg)}, 'b').count('1') % 2)"
if value.operator in ("u", "s"):
# These operators don't change the bit pattern, only its interpretation.
return self(arg)
elif len(value.operands) == 2:
lhs, rhs = value.operands
lhs_mask = (1 << len(lhs)) - 1
rhs_mask = (1 << len(rhs)) - 1
if value.operator == "+":
return f"({sign(lhs)} + {sign(rhs)})"
if value.operator == "-":
return f"({sign(lhs)} - {sign(rhs)})"
if value.operator == "*":
return f"({sign(lhs)} * {sign(rhs)})"
if value.operator == "//":
return f"zdiv({sign(lhs)}, {sign(rhs)})"
if value.operator == "%":
return f"zmod({sign(lhs)}, {sign(rhs)})"
if value.operator == "&":
return f"({self(lhs)} & {self(rhs)})"
if value.operator == "|":
return f"({self(lhs)} | {self(rhs)})"
if value.operator == "^":
return f"({self(lhs)} ^ {self(rhs)})"
if value.operator == "<<":
return f"({sign(lhs)} << {sign(rhs)})"
if value.operator == ">>":
return f"({sign(lhs)} >> {sign(rhs)})"
if value.operator == "==":
return f"({sign(lhs)} == {sign(rhs)})"
if value.operator == "!=":
return f"({sign(lhs)} != {sign(rhs)})"
if value.operator == "<":
return f"({sign(lhs)} < {sign(rhs)})"
if value.operator == "<=":
return f"({sign(lhs)} <= {sign(rhs)})"
if value.operator == ">":
return f"({sign(lhs)} > {sign(rhs)})"
if value.operator == ">=":
return f"({sign(lhs)} >= {sign(rhs)})"
elif len(value.operands) == 3:
if value.operator == "m":
sel, val1, val0 = value.operands
return f"({self(val1)} if {mask(sel)} else {self(val0)})"
raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:
def on_Slice(self, value):
return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"
def on_Part(self, value):
offset_mask = (1 << len(value.offset)) - 1
offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
return f"({(1 << value.width) - 1} & " \
f"{self(value.value)} >> {offset})"
def on_Cat(self, value):
gen_parts = []
offset = 0
for part in value.parts:
part_mask = (1 << len(part)) - 1
gen_parts.append(f"(({part_mask} & {self(part)}) << {offset})")
offset += len(part)
if gen_parts:
return f"({' | '.join(gen_parts)})"
return f"0"
def on_Repl(self, value):
part_mask = (1 << len(value.value)) - 1
gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
gen_parts = []
offset = 0
for _ in range(value.count):
gen_parts.append(f"({gen_part} << {offset})")
offset += len(value.value)
if gen_parts:
return f"({' | '.join(gen_parts)})"
return f"0"
def on_ArrayProxy(self, value):
index_mask = (1 << len(value.index)) - 1
gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
gen_value = self.emitter.gen_var("rhs_proxy")
if value.elems:
gen_elems = []
for index, elem in enumerate(value.elems):
if index == 0:
self.emitter.append(f"if {index} == {gen_index}:")
else:
self.emitter.append(f"elif {index} == {gen_index}:")
with self.emitter.indent():
self.emitter.append(f"{gen_value} = {self(elem)}")
self.emitter.append(f"else:")
with self.emitter.indent():
self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
return gen_value
else:
return f"0"
@classmethod
def compile(cls, state, value, *, mode):
emitter = _PythonEmitter()
compiler = cls(state, emitter, mode=mode)
emitter.append(f"result = {compiler(value)}")
return emitter.flush()
class _LHSValueCompiler(_ValueCompiler):
def __init__(self, state, emitter, *, rhs, outputs=None):
super().__init__(state, emitter)
# `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
# the offset of a Part.
self.rrhs = rhs
# `lrhs` is used to translate the read part of a read-modify-write cycle during partial
# update of an lvalue.
self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
# If not None, `outputs` gets populated with signals on LHS.
self.outputs = outputs
def on_Const(self, value):
raise TypeError # :nocov:
def on_Signal(self, value):
if self.outputs is not None:
self.outputs.add(value)
def gen(arg):
value_mask = (1 << len(value)) - 1
if value.shape().signed:
value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
else: # unsigned
value_sign = f"{value_mask} & {arg}"
self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
return gen
def on_Operator(self, value):
raise TypeError # :nocov:
def on_Slice(self, value):
def gen(arg):
width_mask = (1 << (value.stop - value.start)) - 1
self(value.value)(f"({self.lrhs(value.value)} & " \
f"{~(width_mask << value.start)} | " \
f"(({width_mask} & {arg}) << {value.start}))")
return gen
def on_Part(self, value):
def gen(arg):
width_mask = (1 << value.width) - 1
offset_mask = (1 << len(value.offset)) - 1
offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
self(value.value)(f"({self.lrhs(value.value)} & " \
f"~({width_mask} << {offset}) | " \
f"(({width_mask} & {arg}) << {offset}))")
return gen
def on_Cat(self, value):
def gen(arg):
gen_arg = self.emitter.def_var("cat", arg)
gen_parts = []
offset = 0
for part in value.parts:
part_mask = (1 << len(part)) - 1
self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
offset += len(part)
return gen
def on_Repl(self, value):
raise TypeError # :nocov:
def on_ArrayProxy(self, value):
def gen(arg):
index_mask = (1 << len(value.index)) - 1
gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
if value.elems:
gen_elems = []
for index, elem in enumerate(value.elems):
if index == 0:
self.emitter.append(f"if {index} == {gen_index}:")
else:
self.emitter.append(f"elif {index} == {gen_index}:")
with self.emitter.indent():
self(elem)(arg)
self.emitter.append(f"else:")
with self.emitter.indent():
self(value.elems[-1])(arg)
else:
self.emitter.append(f"pass")
return gen
class _StatementCompiler(StatementVisitor, _Compiler):
def __init__(self, state, emitter, *, inputs=None, outputs=None):
super().__init__(state, emitter)
self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)
def on_statements(self, stmts):
for stmt in stmts:
self(stmt)
if not stmts:
self.emitter.append("pass")
def on_Assign(self, stmt):
gen_rhs = f"({(1 << len(stmt.rhs)) - 1} & {self.rhs(stmt.rhs)})"
if stmt.rhs.shape().signed:
gen_rhs = f"sign({gen_rhs}, {-1 << (len(stmt.rhs) - 1)})"
return self.lhs(stmt.lhs)(gen_rhs)
def on_Switch(self, stmt):
gen_test = self.emitter.def_var("test",
f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
for index, (patterns, stmts) in enumerate(stmt.cases.items()):
gen_checks = []
if not patterns:
gen_checks.append(f"True")
else:
for pattern in patterns:
if "-" in pattern:
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
value = int("".join("0" if b == "-" else b for b in pattern), 2)
gen_checks.append(f"{value} == ({mask} & {gen_test})")
else:
value = int(pattern, 2)
gen_checks.append(f"{value} == {gen_test}")
if index == 0:
self.emitter.append(f"if {' or '.join(gen_checks)}:")
else:
self.emitter.append(f"elif {' or '.join(gen_checks)}:")
with self.emitter.indent():
self(stmts)
def on_Assert(self, stmt):
raise NotImplementedError # :nocov:
def on_Assume(self, stmt):
raise NotImplementedError # :nocov:
def on_Cover(self, stmt):
raise NotImplementedError # :nocov:
@classmethod
def compile(cls, state, stmt):
output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
emitter = _PythonEmitter()
for signal_index in output_indexes:
emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
compiler = cls(state, emitter)
compiler(stmt)
for signal_index in output_indexes:
emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
return emitter.flush()
class _FragmentCompiler:
def __init__(self, state):
self.state = state
def __call__(self, fragment):
processes = set()
for domain_name, domain_signals in fragment.drivers.items():
domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
domain_process = PyRTLProcess(is_comb=domain_name is None)
emitter = _PythonEmitter()
emitter.append(f"def run():")
emitter._level += 1
if domain_name is None:
for signal in domain_signals:
signal_index = self.state.get_signal(signal)
emitter.append(f"next_{signal_index} = {signal.reset}")
inputs = SignalSet()
_StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)
for input in inputs:
self.state.add_trigger(domain_process, input)
else:
domain = fragment.domains[domain_name]
clk_trigger = 1 if domain.clk_edge == "pos" else 0
self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
if domain.rst is not None and domain.async_reset:
rst_trigger = 1
self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)
for signal in domain_signals:
signal_index = self.state.get_signal(signal)
emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
_StatementCompiler(self.state, emitter)(domain_stmts)
for signal in domain_signals:
signal_index = self.state.get_signal(signal)
emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
# There shouldn't be any exceptions raised by the generated code, but if there are
# (almost certainly due to a bug in the code generator), use this environment variable
# to make backtraces useful.
code = emitter.flush()
if os.getenv("NMIGEN_pysim_dump"):
file = tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_", delete=False)
file.write(code)
filename = file.name
else:
filename = "<string>"
exec_locals = {"slots": self.state.slots, **_ValueCompiler.helpers}
exec(compile(code, filename, "exec"), exec_locals)
domain_process.run = exec_locals["run"]
processes.add(domain_process)
for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
if subfragment_name is None:
subfragment_name = "U${}".format(subfragment_index)
processes.update(self(subfragment))
return processes