From b6c5294e5031fccaf9af517c4433c94a18dd9928 Mon Sep 17 00:00:00 2001 From: Wanda Date: Sun, 4 Feb 2024 08:15:29 +0100 Subject: [PATCH] hdl.MemoryInstance: refactor and add first-class simulation support. --- amaranth/back/rtlil.py | 80 ++++++++-------- amaranth/hdl/_ir.py | 21 ++++- amaranth/hdl/_mem.py | 170 +++++++++++++++++++++------------- amaranth/hdl/_xfrm.py | 100 +++++++++++--------- amaranth/sim/_base.py | 46 +++++++--- amaranth/sim/_pycoro.py | 24 ++++- amaranth/sim/_pyrtl.py | 76 ++++++++++++++- amaranth/sim/pysim.py | 199 +++++++++++++++++++++++++++++++++------- tests/test_hdl_xfrm.py | 10 +- 9 files changed, 525 insertions(+), 201 deletions(-) diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 657b1d2..451236e 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -787,69 +787,71 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): return f"\\{fragment.type}", port_map, params if isinstance(fragment, _mem.MemoryInstance): - memory = fragment.memory - init = "".join(format(_ast.Const(elem, _ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init)) - init = _ast.Const(int(init or "0", 2), memory.depth * memory.width) + init = "".join(format(_ast.Const(elem, _ast.unsigned(fragment._width)).value, f"0{fragment._width}b") for elem in reversed(fragment._init)) + init = _ast.Const(int(init or "0", 2), fragment._depth * fragment._width) rd_clk = [] rd_clk_enable = 0 rd_clk_polarity = 0 rd_transparency_mask = 0 - for index, port in enumerate(fragment.read_ports): - if port.domain != "comb": - cd = fragment.domains[port.domain] + for index, port in enumerate(fragment._read_ports): + if port._domain is not None: + cd = fragment.domains[port._domain] rd_clk.append(cd.clk) if cd.clk_edge == "pos": rd_clk_polarity |= 1 << index rd_clk_enable |= 1 << index - if port.transparent: - for write_index, write_port in enumerate(fragment.write_ports): - if port.domain == write_port.domain: - rd_transparency_mask |= 1 << (index * len(fragment.write_ports) + write_index) + for write_index in port._transparency: + rd_transparency_mask |= 1 << (index * len(fragment._write_ports) + write_index) else: rd_clk.append(_ast.Const(0, 1)) wr_clk = [] wr_clk_enable = 0 wr_clk_polarity = 0 - for index, port in enumerate(fragment.write_ports): - cd = fragment.domains[port.domain] + for index, port in enumerate(fragment._write_ports): + cd = fragment.domains[port._domain] wr_clk.append(cd.clk) wr_clk_enable |= 1 << index if cd.clk_edge == "pos": wr_clk_polarity |= 1 << index params = { "MEMID": builder._make_name(hierarchy[-1], local=False), - "SIZE": memory.depth, + "SIZE": fragment._depth, "OFFSET": 0, - "ABITS": _ast.Shape.cast(range(memory.depth)).width, - "WIDTH": memory.width, + "ABITS": _ast.Shape.cast(range(fragment._depth)).width, + "WIDTH": fragment._width, "INIT": init, - "RD_PORTS": len(fragment.read_ports), - "RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment.read_ports))), - "RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment.read_ports))), - "RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment.read_ports) * len(fragment.write_ports))), - "RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment.read_ports) * len(fragment.write_ports))), - "RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment.read_ports))), - "RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment.read_ports))), - "RD_ARST_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width), - "RD_SRST_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width), - "RD_INIT_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width), - "WR_PORTS": len(fragment.write_ports), - "WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment.write_ports))), - "WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment.write_ports))), - "WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment.write_ports) * len(fragment.write_ports))), - "WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment.write_ports))), + "RD_PORTS": len(fragment._read_ports), + "RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment._read_ports))), + "RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment._read_ports))), + "RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment._read_ports) * len(fragment._write_ports))), + "RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment._read_ports) * len(fragment._write_ports))), + "RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._read_ports))), + "RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment._read_ports))), + "RD_ARST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), + "RD_SRST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), + "RD_INIT_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width), + "WR_PORTS": len(fragment._write_ports), + "WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment._write_ports))), + "WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment._write_ports))), + "WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment._write_ports) * len(fragment._write_ports))), + "WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._write_ports))), } + def make_en(port): + if len(port._data) == 0: + return _ast.Const(0, 0) + granularity = len(port._data) // len(port._en) + return _ast.Cat(en_bit.replicate(granularity) for en_bit in port._en) port_map = { "\\RD_CLK": _ast.Cat(rd_clk), - "\\RD_EN": _ast.Cat(port.en for port in fragment.read_ports), - "\\RD_ARST": _ast.Const(0, len(fragment.read_ports)), - "\\RD_SRST": _ast.Const(0, len(fragment.read_ports)), - "\\RD_ADDR": _ast.Cat(port.addr for port in fragment.read_ports), - "\\RD_DATA": _ast.Cat(port.data for port in fragment.read_ports), + "\\RD_EN": _ast.Cat(port._en for port in fragment._read_ports), + "\\RD_ARST": _ast.Const(0, len(fragment._read_ports)), + "\\RD_SRST": _ast.Const(0, len(fragment._read_ports)), + "\\RD_ADDR": _ast.Cat(port._addr for port in fragment._read_ports), + "\\RD_DATA": _ast.Cat(port._data for port in fragment._read_ports), "\\WR_CLK": _ast.Cat(wr_clk), - "\\WR_EN": _ast.Cat(_ast.Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in fragment.write_ports), - "\\WR_ADDR": _ast.Cat(port.addr for port in fragment.write_ports), - "\\WR_DATA": _ast.Cat(port.data for port in fragment.write_ports), + "\\WR_EN": _ast.Cat(make_en(port) for port in fragment._write_ports), + "\\WR_ADDR": _ast.Cat(port._addr for port in fragment._write_ports), + "\\WR_DATA": _ast.Cat(port._data for port in fragment._write_ports), } return "$mem_v2", port_map, params @@ -913,7 +915,7 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): if isinstance(subfragment, _ir.Instance): src = _src(subfragment.src_loc) elif isinstance(subfragment, _mem.MemoryInstance): - src = _src(subfragment.memory.src_loc) + src = _src(subfragment._src_loc) else: src = "" diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index bb67124..820dbf0 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -183,6 +183,7 @@ class Fragment: def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"): assert mode in ("silent", "warn", "error") + from ._mem import MemoryInstance driver_subfrags = SignalDict() def add_subfrag(registry, entity, entry): @@ -214,7 +215,7 @@ class Fragment: # Always flatten subfragments that explicitly request it. flatten_subfrags.add((subfrag, subfrag_hierarchy)) - if isinstance(subfrag, Instance): + if isinstance(subfrag, (Instance, MemoryInstance)): # Never flatten instances. continue @@ -368,6 +369,8 @@ class Fragment: return new_domains def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top): + from ._mem import MemoryInstance + def add_uses(*sigs, self=self): for sig in flatten(sigs): if sig not in uses: @@ -416,6 +419,22 @@ class Fragment: if dir == "io": subfrag.add_ports(value._lhs_signals(), dir=dir) add_io(value._lhs_signals()) + elif isinstance(subfrag, MemoryInstance): + for port in subfrag._read_ports: + subfrag.add_ports(port._data._lhs_signals(), dir="o") + add_defs(port._data._lhs_signals()) + for value in [port._addr, port._en]: + subfrag.add_ports(value._rhs_signals(), dir="i") + add_uses(value._rhs_signals()) + for port in subfrag._write_ports: + for value in [port._addr, port._en, port._data]: + subfrag.add_ports(value._rhs_signals(), dir="i") + add_uses(value._rhs_signals()) + for domain, _ in subfrag.iter_sync(): + cd = subfrag.domains[domain] + add_uses(cd.clk) + if cd.rst is not None: + add_uses(cd.rst) else: parent[subfrag] = self level [subfrag] = level[self] + 1 diff --git a/amaranth/hdl/_mem.py b/amaranth/hdl/_mem.py index 36c6672..b92154e 100644 --- a/amaranth/hdl/_mem.py +++ b/amaranth/hdl/_mem.py @@ -3,12 +3,104 @@ from collections import OrderedDict from .. import tracer from ._ast import * -from ._ir import Elaboratable, Instance, Fragment +from ._ir import Elaboratable, Fragment +from ..utils import ceil_log2 __all__ = ["Memory", "ReadPort", "WritePort", "DummyPort"] +class MemoryIdentity: pass + + +class MemorySimRead: + def __init__(self, identity, addr): + assert isinstance(identity, MemoryIdentity) + self._identity = identity + self._addr = Value.cast(addr) + + def eq(self, value): + return MemorySimWrite(self._identity, self._addr, value) + + +class MemorySimWrite: + def __init__(self, identity, addr, data): + assert isinstance(identity, MemoryIdentity) + self._identity = identity + self._addr = Value.cast(addr) + self._data = Value.cast(data) + + +class MemoryInstance(Fragment): + class _ReadPort: + def __init__(self, *, domain, addr, data, en, transparency): + assert domain is None or isinstance(domain, str) + if domain == "comb": + domain = None + self._domain = domain + self._addr = Value.cast(addr) + self._data = Value.cast(data) + self._en = Value.cast(en) + self._transparency = tuple(transparency) + assert len(self._en) == 1 + if domain is None: + assert isinstance(self._en, Const) + assert self._en.width == 1 + assert self._en.value == 1 + + class _WritePort: + def __init__(self, *, domain, addr, data, en): + assert isinstance(domain, str) + assert domain != "comb" + self._domain = domain + self._addr = Value.cast(addr) + self._data = Value.cast(data) + self._en = Value.cast(en) + if len(self._data): + assert len(self._data) % len(self._en) == 0 + + @property + def _granularity(self): + if not len(self._data): + return 1 + return len(self._data) // len(self._en) + + + def __init__(self, *, identity, width, depth, init=None, attrs=None, src_loc=None): + super().__init__() + assert isinstance(identity, MemoryIdentity) + self._identity = identity + self._width = operator.index(width) + self._depth = operator.index(depth) + self._init = tuple(init) if init is not None else () + assert len(self._init) <= self._depth + self._init += (0,) * (self._depth - len(self._init)) + for x in self._init: + assert isinstance(x, int) + self._attrs = attrs or {} + self._src_loc = src_loc + self._read_ports = [] + self._write_ports = [] + + def read_port(self, *, domain, addr, data, en, transparency): + port = self._ReadPort(domain=domain, addr=addr, data=data, en=en, transparency=transparency) + assert len(port._data) == self._width + assert len(port._addr) == ceil_log2(self._depth) + for x in port._transparency: + assert isinstance(x, int) + assert x in range(len(self._write_ports)) + for signal in port._data._rhs_signals(): + self.add_driver(signal, port._domain) + self._read_ports.append(port) + + def write_port(self, *, domain, addr, data, en): + port = self._WritePort(domain=domain, addr=addr, data=data, en=en) + assert len(port._data) == self._width + assert len(port._addr) == ceil_log2(self._depth) + self._write_ports.append(port) + return len(self._write_ports) - 1 + + class Memory(Elaboratable): """A word addressable storage. @@ -50,16 +142,10 @@ class Memory(Elaboratable): self.depth = depth self.attrs = OrderedDict(() if attrs is None else attrs) - # Array of signals for simulation. - self._array = Array() - if simulate: - for addr in range(self.depth): - self._array.append(Signal(self.width, name="{}({})" - .format(name or "memory", addr))) - self.init = init self._read_ports = [] self._write_ports = [] + self._identity = MemoryIdentity() @property def init(self): @@ -73,11 +159,8 @@ class Memory(Elaboratable): .format(len(self.init), self.depth)) try: - for addr in range(len(self._array)): - if addr < len(self._init): - self._array[addr].reset = operator.index(self._init[addr]) - else: - self._array[addr].reset = 0 + for addr, val in enumerate(self._init): + operator.index(val) except TypeError as e: raise TypeError("Memory initialization value at address {:x}: {}" .format(addr, e)) from None @@ -116,52 +199,24 @@ class Memory(Elaboratable): def __getitem__(self, index): """Simulation only.""" - return self._array[index] + return MemorySimRead(self._identity, index) def elaborate(self, platform): - f = MemoryInstance(self, self._read_ports, self._write_ports) + f = MemoryInstance(identity=self._identity, width=self.width, depth=self.depth, init=self.init, attrs=self.attrs, src_loc=self.src_loc) + write_ports = {} + for port in self._write_ports: + port._MustUse__used = True + iport = f.write_port(domain=port.domain, addr=port.addr, data=port.data, en=port.en) + write_ports.setdefault(port.domain, []).append(iport) for port in self._read_ports: port._MustUse__used = True if port.domain == "comb": - # Asynchronous port - f.add_statements(None, port.data.eq(self._array[port.addr])) - f.add_driver(port.data) + f.read_port(domain="comb", addr=port.addr, data=port.data, en=Const(1), transparency=()) else: - # Synchronous port - data = self._array[port.addr] - for write_port in self._write_ports: - if port.domain == write_port.domain and port.transparent: - if len(write_port.en) > 1: - parts = [] - for index, en_bit in enumerate(write_port.en): - offset = index * write_port.granularity - bits = slice(offset, offset + write_port.granularity) - cond = en_bit & (port.addr == write_port.addr) - parts.append(Mux(cond, write_port.data[bits], data[bits])) - data = Cat(parts) - else: - cond = write_port.en & (port.addr == write_port.addr) - data = Mux(cond, write_port.data, data) - f.add_statements( - port.domain, - Switch(port.en, { - 1: port.data.eq(data) - }) - ) - f.add_driver(port.data, port.domain) - for port in self._write_ports: - port._MustUse__used = True - if len(port.en) > 1: - for index, en_bit in enumerate(port.en): - offset = index * port.granularity - bits = slice(offset, offset + port.granularity) - write_data = self._array[port.addr][bits].eq(port.data[bits]) - f.add_statements(port.domain, Switch(en_bit, { 1: write_data })) - else: - write_data = self._array[port.addr].eq(port.data) - f.add_statements(port.domain, Switch(port.en, { 1: write_data })) - for signal in self._array: - f.add_driver(signal, port.domain) + transparency = [] + if port.transparent: + transparency = write_ports.get(port.domain, []) + f.read_port(domain=port.domain, addr=port.addr, data=port.data, en=port.en, transparency=transparency) return f @@ -308,12 +363,3 @@ class DummyPort: name=f"{name}_data", src_loc_at=1) self.en = Signal(data_width // granularity, name=f"{name}_en", src_loc_at=1) - - -class MemoryInstance(Fragment): - def __init__(self, memory, read_ports, write_ports): - super().__init__() - self.memory = memory - self.read_ports = read_ports - self.write_ports = write_ports - self.attrs = memory.attrs diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index ea17af8..dad2207 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -1,9 +1,8 @@ from abc import ABCMeta, abstractmethod from collections import OrderedDict from collections.abc import Iterable -from copy import copy -from .._utils import flatten, _ignore_deprecated +from .._utils import flatten from .. import tracer from ._ast import * from ._ast import _StatementList, AnyValue, Property @@ -239,27 +238,45 @@ class FragmentTransformer: new_fragment.add_driver(signal, domain) def map_memory_ports(self, fragment, new_fragment): - new_fragment.read_ports = [ - copy(port) - for port in fragment.read_ports - ] - new_fragment.write_ports = [ - copy(port) - for port in fragment.write_ports - ] if hasattr(self, "on_value"): - for port in new_fragment.read_ports: - port.en = self.on_value(port.en) - port.addr = self.on_value(port.addr) - port.data = self.on_value(port.data) - for port in new_fragment.write_ports: - port.en = self.on_value(port.en) - port.addr = self.on_value(port.addr) - port.data = self.on_value(port.data) + for port in new_fragment._read_ports: + port._en = self.on_value(port._en) + port._addr = self.on_value(port._addr) + port._data = self.on_value(port._data) + for port in new_fragment._write_ports: + port._en = self.on_value(port._en) + port._addr = self.on_value(port._addr) + port._data = self.on_value(port._data) def on_fragment(self, fragment): if isinstance(fragment, MemoryInstance): - new_fragment = MemoryInstance(fragment.memory, [], []) + new_fragment = MemoryInstance( + identity=fragment._identity, + width=fragment._width, + depth=fragment._depth, + init=fragment._init, + attrs=fragment._attrs, + src_loc=fragment._src_loc + ) + new_fragment._read_ports = [ + MemoryInstance._ReadPort( + domain=port._domain, + addr=port._addr, + data=port._data, + en=port._en, + transparency=port._transparency, + ) + for port in fragment._read_ports + ] + new_fragment._write_ports = [ + MemoryInstance._WritePort( + domain=port._domain, + addr=port._addr, + data=port._data, + en=port._en, + ) + for port in fragment._write_ports + ] self.map_memory_ports(fragment, new_fragment) elif isinstance(fragment, Instance): new_fragment = Instance(fragment.type, src_loc=fragment.src_loc) @@ -376,17 +393,16 @@ class DomainCollector(ValueVisitor, StatementVisitor): def on_fragment(self, fragment): if isinstance(fragment, MemoryInstance): - for port in fragment.read_ports: - self.on_value(port.addr) - self.on_value(port.data) - self.on_value(port.en) - if port.domain != "comb": - self._add_used_domain(port.domain) - for port in fragment.write_ports: - self.on_value(port.addr) - self.on_value(port.data) - self.on_value(port.en) - self._add_used_domain(port.domain) + for port in fragment._read_ports: + self.on_value(port._addr) + self.on_value(port._data) + self.on_value(port._en) + self._add_used_domain(port._domain) + for port in fragment._write_ports: + self.on_value(port._addr) + self.on_value(port._data) + self.on_value(port._en) + self._add_used_domain(port._domain) if isinstance(fragment, Instance): for name, (value, dir) in fragment.named_ports.items(): @@ -460,12 +476,12 @@ class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer) def map_memory_ports(self, fragment, new_fragment): super().map_memory_ports(fragment, new_fragment) - for port in new_fragment.read_ports: - if port.domain in self.domain_map: - port.domain = self.domain_map[port.domain] - for port in new_fragment.write_ports: - if port.domain in self.domain_map: - port.domain = self.domain_map[port.domain] + for port in new_fragment._read_ports: + if port._domain in self.domain_map: + port._domain = self.domain_map[port._domain] + for port in new_fragment._write_ports: + if port._domain in self.domain_map: + port._domain = self.domain_map[port._domain] class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer): @@ -645,10 +661,10 @@ class EnableInserter(_ControlInserter): def on_fragment(self, fragment): new_fragment = super().on_fragment(fragment) if isinstance(new_fragment, MemoryInstance): - for port in new_fragment.read_ports: - if port.domain in self.controls: - port.en = port.en & self.controls[port.domain] - for port in new_fragment.write_ports: - if port.domain in self.controls: - port.en = Mux(self.controls[port.domain], port.en, Const(0, len(port.en))) + for port in new_fragment._read_ports: + if port._domain in self.controls: + port._en = port._en & self.controls[port._domain] + for port in new_fragment._write_ports: + if port._domain in self.controls: + port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en))) return new_fragment diff --git a/amaranth/sim/_base.py b/amaranth/sim/_base.py index ee1061f..8f6fe40 100644 --- a/amaranth/sim/_base.py +++ b/amaranth/sim/_base.py @@ -1,4 +1,4 @@ -__all__ = ["BaseProcess", "BaseSignalState", "BaseSimulation", "BaseEngine"] +__all__ = ["BaseProcess", "BaseSignalState", "BaseMemoryState", "BaseSimulation", "BaseEngine"] class BaseProcess: @@ -12,7 +12,7 @@ class BaseProcess: self.passive = True def run(self): - raise NotImplementedError + raise NotImplementedError # :nocov: class BaseSignalState: @@ -24,44 +24,62 @@ class BaseSignalState: next = NotImplemented def set(self, value): - raise NotImplementedError + raise NotImplementedError # :nocov: + + +class BaseMemoryState: + __slots__ = () + + memory = NotImplemented + + def read(self, addr): + raise NotImplementedError # :nocov: + + def write(self, addr, value): + raise NotImplementedError # :nocov: class BaseSimulation: def reset(self): - raise NotImplementedError + raise NotImplementedError # :nocov: def get_signal(self, signal): - raise NotImplementedError + raise NotImplementedError # :nocov: slots = NotImplemented def add_trigger(self, process, signal, *, trigger=None): - raise NotImplementedError + raise NotImplementedError # :nocov: def remove_trigger(self, process, signal): - raise NotImplementedError + raise NotImplementedError # :nocov: + + def add_memory_trigger(self, process, identity): + raise NotImplementedError # :nocov: + + def remove_memory_trigger(self, process, identity): + raise NotImplementedError # :nocov: def wait_interval(self, process, interval): - raise NotImplementedError + raise NotImplementedError # :nocov: class BaseEngine: def add_coroutine_process(self, process, *, default_cmd): - raise NotImplementedError + raise NotImplementedError # :nocov: def add_clock_process(self, clock, *, phase, period): - raise NotImplementedError + raise NotImplementedError # :nocov: def reset(self): - raise NotImplementedError + raise NotImplementedError # :nocov: @property def now(self): - raise NotImplementedError + raise NotImplementedError # :nocov: def advance(self): - raise NotImplementedError + raise NotImplementedError # :nocov: def write_vcd(self, *, vcd_file, gtkw_file, traces): - raise NotImplementedError + raise NotImplementedError # :nocov: diff --git a/amaranth/sim/_pycoro.py b/amaranth/sim/_pycoro.py index 6a0487e..6baffac 100644 --- a/amaranth/sim/_pycoro.py +++ b/amaranth/sim/_pycoro.py @@ -2,8 +2,9 @@ import inspect from ..hdl import * from ..hdl._ast import Statement, SignalSet, ValueCastable +from ..hdl._mem import MemorySimRead, MemorySimWrite from .core import Tick, Settle, Delay, Passive, Active -from ._base import BaseProcess +from ._base import BaseProcess, BaseMemoryState from ._pyrtl import _ValueCompiler, _RHSValueCompiler, _StatementCompiler @@ -119,6 +120,27 @@ class PyCoroProcess(BaseProcess): elif type(command) is Active: self.passive = False + elif type(command) is MemorySimRead: + exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"), + self.exec_locals) + addr = Const(self.exec_locals["result"], command._addr.shape()).value + index = self.state.memories[command._identity] + state = self.state.slots[index] + assert isinstance(state, BaseMemoryState) + response = state.read(addr) + + elif type(command) is MemorySimWrite: + exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"), + self.exec_locals) + addr = Const(self.exec_locals["result"], command._addr.shape()).value + exec(_RHSValueCompiler.compile(self.state, command._data, mode="curr"), + self.exec_locals) + data = Const(self.exec_locals["result"], command._data.shape()).value + index = self.state.memories[command._identity] + state = self.state.slots[index] + assert isinstance(state, BaseMemoryState) + state.write(addr, data) + elif command is None: # only possible if self.default_cmd is None raise TypeError("Received default command from process {!r} that was added " "with add_process(); did you mean to add this process with " diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 30f5362..4aecc60 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -4,8 +4,9 @@ from contextlib import contextmanager import sys from ..hdl import * -from ..hdl._ast import SignalSet +from ..hdl._ast import SignalSet, _StatementList from ..hdl._xfrm import ValueVisitor, StatementVisitor +from ..hdl._mem import MemoryInstance from ._base import BaseProcess @@ -409,10 +410,25 @@ class _FragmentCompiler: def __call__(self, fragment): processes = set() - for domain_name, domain_stmts in fragment.statements.items(): + domains = set(fragment.statements) + + if isinstance(fragment, MemoryInstance): + self.state.add_memory(fragment) + for port in fragment._read_ports: + domains.add(port._domain) + for port in fragment._write_ports: + domains.add(port._domain) + + for domain_name in domains: + domain_stmts = fragment.statements.get(domain_name, _StatementList()) domain_process = PyRTLProcess(is_comb=domain_name is None) domain_signals = domain_stmts._lhs_signals() + if isinstance(fragment, MemoryInstance): + for port in fragment._read_ports: + if port._domain == domain_name: + domain_signals.update(port._data._lhs_signals()) + emitter = _PythonEmitter() emitter.append(f"def run():") emitter._level += 1 @@ -425,6 +441,21 @@ class _FragmentCompiler: inputs = SignalSet() _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts) + if isinstance(fragment, MemoryInstance): + self.state.add_memory_trigger(domain_process, fragment._identity) + memory_index = self.state.memories[fragment._identity] + rhs = _RHSValueCompiler(self.state, emitter, mode="curr", inputs=inputs) + lhs = _LHSValueCompiler(self.state, emitter, rhs=rhs) + + for port in fragment._read_ports: + if port._domain is not None: + continue + + addr = rhs(port._addr) + addr = f"({(1 << len(port._addr)) - 1:#x} & {addr})" + data = emitter.def_var("read_data", f"slots[{memory_index}].read({addr})") + lhs(port._data)(data) + for input in inputs: self.state.add_trigger(domain_process, input) @@ -442,6 +473,47 @@ class _FragmentCompiler: _StatementCompiler(self.state, emitter)(domain_stmts) + if isinstance(fragment, MemoryInstance): + memory_index = self.state.memories[fragment._identity] + rhs = _RHSValueCompiler(self.state, emitter, mode="curr") + lhs = _LHSValueCompiler(self.state, emitter, rhs=rhs) + + write_vals = {} + + for idx, port in enumerate(fragment._write_ports): + if port._domain != domain_name: + continue + + addr = rhs(port._addr) + addr = emitter.def_var("write_addr", f"({(1 << len(port._addr)) - 1:#x} & {addr})") + data = rhs(port._data) + data = emitter.def_var("write_data", f"({(1 << len(port._data)) - 1:#x} & {data})") + en = rhs(Cat(bit.replicate(port._granularity) for bit in port._en)) + en = emitter.def_var("write_en", f"({(1 << len(port._data)) - 1:#x} & {en})") + emitter.append(f"slots[{memory_index}].write({addr}, {data}, {en})") + write_vals[idx] = addr, data, en + + for port in fragment._read_ports: + if port._domain != domain_name: + continue + + en = rhs(port._en) + en = f"(1 & {en})" + emitter.append(f"if {en}:") + with emitter.indent(): + addr = rhs(port._addr) + addr = emitter.def_var("read_addr", f"({(1 << len(port._addr)) - 1:#x} & {addr})") + data = emitter.def_var("read_data", f"slots[{memory_index}].read({addr})") + + for idx in port._transparency: + waddr, wdata, wen = write_vals[idx] + emitter.append(f"if {addr} == {waddr}:") + with emitter.indent(): + emitter.append(f"{data} &= ~{wen}") + emitter.append(f"{data} |= {wdata} & {wen}") + + lhs(port._data)(data) + for signal in domain_signals: signal_index = self.state.get_signal(signal) emitter.append(f"slots[{signal_index}].set(next_{signal_index})") diff --git a/amaranth/sim/pysim.py b/amaranth/sim/pysim.py index 0fa9ea0..3f95b9a 100644 --- a/amaranth/sim/pysim.py +++ b/amaranth/sim/pysim.py @@ -6,6 +6,7 @@ from vcd.gtkw import GTKWSave from ..hdl import * from ..hdl._repr import * +from ..hdl._mem import MemoryInstance, MemoryIdentity from ..hdl._ast import SignalDict, Slice, Operator from ._base import * from ._pyrtl import _FragmentCompiler @@ -43,47 +44,58 @@ class _VCDWriter: if isinstance(gtkw_file, str): gtkw_file = open(gtkw_file, "w") - self.vcd_vars = SignalDict() + self.vcd_signal_vars = SignalDict() + self.vcd_memory_vars = {} self.vcd_file = vcd_file self.vcd_writer = vcd_file and VCDWriter(self.vcd_file, timescale="1 ps", comment="Generated by Amaranth") - self.gtkw_names = SignalDict() + self.gtkw_signal_names = SignalDict() + self.gtkw_memory_names = {} self.gtkw_file = gtkw_file self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file) self.traces = [] signal_names = SignalDict() + memories = {} for subfragment, subfragment_name in \ fragment._assign_names_to_fragments(hierarchy=("bench", "top",)).items(): for signal, signal_name in subfragment._assign_names_to_signals().items(): if signal not in signal_names: signal_names[signal] = set() signal_names[signal].add((*subfragment_name, signal_name)) + if isinstance(subfragment, MemoryInstance): + memories[subfragment._identity] = (subfragment, subfragment_name) trace_names = SignalDict() assigned_names = set() for trace in traces: - trace = Value.cast(trace) - for trace_signal in trace._rhs_signals(): - if trace_signal not in signal_names: - if trace_signal.name not in assigned_names: - name = trace_signal.name - else: - name = f"{trace_signal.name}${len(assigned_names)}" - assert name not in assigned_names - trace_names[trace_signal] = {("bench", name)} - assigned_names.add(name) - self.traces.append(trace_signal) + if isinstance(trace, ValueLike): + trace = Value.cast(trace) + for trace_signal in trace._rhs_signals(): + if trace_signal not in signal_names: + if trace_signal.name not in assigned_names: + name = trace_signal.name + else: + name = f"{trace_signal.name}${len(assigned_names)}" + assert name not in assigned_names + trace_names[trace_signal] = {("bench", name)} + assigned_names.add(name) + self.traces.append(trace_signal) + elif isinstance(trace, (MemoryInstance, Memory)): + if not trace._identity in memories: + raise ValueError(f"{trace!r} is a memory not part of the elaborated design") + self.traces.append(trace._identity) + else: + raise TypeError(f"{trace!r} is not a traceable object") if self.vcd_writer is None: return for signal, names in itertools.chain(signal_names.items(), trace_names.items()): - fields = [] - self.vcd_vars[signal] = [] - self.gtkw_names[signal] = [] + self.vcd_signal_vars[signal] = [] + self.gtkw_signal_names[signal] = [] for repr in signal._value_repr: var_init = self.eval_field(repr.value, signal, signal.reset) if isinstance(repr.format, FormatInt): @@ -119,22 +131,45 @@ class _VCDWriter: gtkw_field_name = '\\' + field_name else: gtkw_field_name = field_name - self.gtkw_names[signal].append(".".join((*var_scope, gtkw_field_name)) + suffix) + self.gtkw_signal_names[signal].append(".".join((*var_scope, gtkw_field_name)) + suffix) else: self.vcd_writer.register_alias( scope=var_scope, name=field_name, var=vcd_var) - self.vcd_vars[signal].append((vcd_var, repr)) + self.vcd_signal_vars[signal].append((vcd_var, repr)) + + for memory, memory_name in memories.values(): + self.vcd_memory_vars[memory._identity] = vcd_vars = [] + self.gtkw_memory_names[memory._identity] = gtkw_names = [] + if memory._width > 1: + suffix = f"[{memory._width - 1}:0]" + else: + suffix = "" + for idx, init in enumerate(memory._init): + field_name = "\\" + memory_name[-1] + f"[{idx}]" + var_scope = memory_name[:-1] + vcd_var = self.vcd_writer.register_var( + scope=var_scope, name=field_name, + var_type="wire", size=memory._width, init=init, + ) + vcd_vars.append(vcd_var) + gtkw_field_name = field_name + suffix + gtkw_name = ".".join((*var_scope, gtkw_field_name)) + gtkw_names.append(gtkw_name) - def update(self, timestamp, signal, value): - for (vcd_var, repr) in self.vcd_vars.get(signal, ()): + def update_signal(self, timestamp, signal, value): + for (vcd_var, repr) in self.vcd_signal_vars.get(signal, ()): var_value = self.eval_field(repr.value, signal, value) if not isinstance(repr.format, FormatInt): var_value = self.decode_to_vcd(repr.format, var_value) self.vcd_writer.change(vcd_var, timestamp, var_value) + def update_memory(self, timestamp, memory, addr, value): + vcd_var = self.vcd_memory_vars[memory._identity][addr] + self.vcd_writer.change(vcd_var, timestamp, value) + def close(self, timestamp): if self.vcd_writer is not None: self.vcd_writer.close(timestamp) @@ -145,8 +180,14 @@ class _VCDWriter: self.gtkw_save.treeopen("top") for signal in self.traces: - for name in self.gtkw_names[signal]: - self.gtkw_save.trace(name) + if isinstance(signal, Signal): + for name in self.gtkw_signal_names[signal]: + self.gtkw_save.trace(name) + elif isinstance(signal, MemoryIdentity): + for name in self.gtkw_memory_names[signal]: + self.gtkw_save.trace(name) + else: + assert False # :nocov: if self.vcd_file is not None: self.vcd_file.close() @@ -208,7 +249,7 @@ class _PySignalState(BaseSignalState): def __init__(self, signal, pending): self.signal = signal self.pending = pending - self.waiters = dict() + self.waiters = {} self.curr = self.next = signal.reset def set(self, value): @@ -229,17 +270,83 @@ class _PySignalState(BaseSignalState): return awoken_any +class _PyMemoryChange: + __slots__ = ("state", "addr") + + def __init__(self, state, addr): + self.state = state + self.addr = addr + + +class _PyMemoryState(BaseMemoryState): + __slots__ = ("memory", "data", "write_queue", "waiters", "pending") + + def __init__(self, memory, pending): + self.memory = memory + self.pending = pending + self.waiters = {} + self.reset() + + def reset(self): + self.data = list(self.memory._init) + self.write_queue = [] + + def commit(self): + if not self.write_queue: + return False + + for addr, value, mask in self.write_queue: + curr = self.data[addr] + value = (value & mask) | (curr & ~mask) + self.data[addr] = value + self.write_queue.clear() + + awoken_any = False + for process in self.waiters: + process.runnable = awoken_any = True + return awoken_any + + def read(self, addr): + if addr not in range(self.memory._depth): + return 0 + + return self.data[addr] + + def write(self, addr, value, mask=None): + if addr not in range(self.memory._depth): + return + if mask == 0: + return + + if mask is None: + mask = (1 << self.memory._width) - 1 + + self.write_queue.append((addr, value, mask)) + self.pending.add(self) + + class _PySimulation(BaseSimulation): def __init__(self): - self.timeline = _Timeline() - self.signals = SignalDict() - self.slots = [] - self.pending = set() + self.timeline = _Timeline() + self.signals = SignalDict() + self.memories = {} + self.slots = [] + self.pending = set() + + def add_memory(self, fragment): + self.memories[fragment._identity] = len(self.slots) + self.slots.append(_PyMemoryState(fragment, self.pending)) def reset(self): self.timeline.reset() for signal, index in self.signals.items(): - self.slots[index].curr = self.slots[index].next = signal.reset + state = self.slots[index] + assert isinstance(state, _PySignalState) + state.curr = state.next = signal.reset + for index in self.memories.values(): + state = self.slots[index] + assert isinstance(state, _PyMemoryState) + state.reset() self.pending.clear() def get_signal(self, signal): @@ -262,16 +369,31 @@ class _PySimulation(BaseSimulation): assert process in self.slots[index].waiters del self.slots[index].waiters[process] + def add_memory_trigger(self, process, identity): + index = self.memories[identity] + self.slots[index].waiters[process] = None + + def remove_memory_trigger(self, process, identity): + index = self.memories[identity] + assert process in self.slots[index].waiters + del self.slots[index].waiters[process] + def wait_interval(self, process, interval): self.timeline.delay(interval, process) def commit(self, changed=None): converged = True - for signal_state in self.pending: - if signal_state.commit(): + for state in self.pending: + if changed is not None: + if isinstance(state, _PyMemoryState): + for addr, _value, _mask in state.write_queue: + changed.add(_PyMemoryChange(state, addr)) + elif isinstance(state, _PySignalState): + changed.add(state) + else: + assert False # :nocov: + if state.commit(): converged = False - if changed is not None: - changed.update(self.pending) self.pending.clear() return converged @@ -314,9 +436,16 @@ class PySimEngine(BaseEngine): converged = self._state.commit(changed) for vcd_writer in self._vcd_writers: - for signal_state in changed: - vcd_writer.update(self._timeline.now, - signal_state.signal, signal_state.curr) + for change in changed: + if isinstance(change, _PySignalState): + signal_state = change + vcd_writer.update_signal(self._timeline.now, + signal_state.signal, signal_state.curr) + elif isinstance(change, _PyMemoryChange): + vcd_writer.update_memory(self._timeline.now, change.state.memory, + change.addr, change.state.data[change.addr]) + else: + assert False # :nocov: def advance(self): self._step() diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index a8724cd..27eb5c8 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -138,9 +138,9 @@ class DomainRenamerTestCase(FHDLTestCase): f = DomainRenamer({"a": "d", "c": "e"})(f) mem = f.subfragments[0][0] self.assertIsInstance(mem, MemoryInstance) - self.assertEqual(mem.read_ports[0].domain, "d") - self.assertEqual(mem.read_ports[1].domain, "b") - self.assertEqual(mem.write_ports[0].domain, "e") + self.assertEqual(mem._read_ports[0]._domain, "d") + self.assertEqual(mem._read_ports[1]._domain, "b") + self.assertEqual(mem._write_ports[0]._domain, "e") def test_rename_wrong_to_comb(self): with self.assertRaisesRegex(ValueError, @@ -530,7 +530,7 @@ class EnableInserterTestCase(FHDLTestCase): mem = Memory(width=8, depth=4) mem.read_port(transparent=False) f = EnableInserter(self.c1)(mem).elaborate(platform=None) - self.assertRepr(f.read_ports[0].en, """ + self.assertRepr(f._read_ports[0]._en, """ (& (sig mem_r_en) (sig c1)) """) @@ -538,7 +538,7 @@ class EnableInserterTestCase(FHDLTestCase): mem = Memory(width=8, depth=4) mem.write_port(granularity=2) f = EnableInserter(self.c1)(mem).elaborate(platform=None) - self.assertRepr(f.write_ports[0].en, """ + self.assertRepr(f._write_ports[0]._en, """ (m (sig c1) (sig mem_w_en)