hdl.MemoryInstance: refactor and add first-class simulation support.

This commit is contained in:
Wanda 2024-02-04 08:15:29 +01:00 committed by Catherine
parent f4daf74634
commit b6c5294e50
9 changed files with 525 additions and 201 deletions

View file

@ -787,69 +787,71 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
return f"\\{fragment.type}", port_map, params return f"\\{fragment.type}", port_map, params
if isinstance(fragment, _mem.MemoryInstance): if isinstance(fragment, _mem.MemoryInstance):
memory = fragment.memory init = "".join(format(_ast.Const(elem, _ast.unsigned(fragment._width)).value, f"0{fragment._width}b") for elem in reversed(fragment._init))
init = "".join(format(_ast.Const(elem, _ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init)) init = _ast.Const(int(init or "0", 2), fragment._depth * fragment._width)
init = _ast.Const(int(init or "0", 2), memory.depth * memory.width)
rd_clk = [] rd_clk = []
rd_clk_enable = 0 rd_clk_enable = 0
rd_clk_polarity = 0 rd_clk_polarity = 0
rd_transparency_mask = 0 rd_transparency_mask = 0
for index, port in enumerate(fragment.read_ports): for index, port in enumerate(fragment._read_ports):
if port.domain != "comb": if port._domain is not None:
cd = fragment.domains[port.domain] cd = fragment.domains[port._domain]
rd_clk.append(cd.clk) rd_clk.append(cd.clk)
if cd.clk_edge == "pos": if cd.clk_edge == "pos":
rd_clk_polarity |= 1 << index rd_clk_polarity |= 1 << index
rd_clk_enable |= 1 << index rd_clk_enable |= 1 << index
if port.transparent: for write_index in port._transparency:
for write_index, write_port in enumerate(fragment.write_ports): rd_transparency_mask |= 1 << (index * len(fragment._write_ports) + write_index)
if port.domain == write_port.domain:
rd_transparency_mask |= 1 << (index * len(fragment.write_ports) + write_index)
else: else:
rd_clk.append(_ast.Const(0, 1)) rd_clk.append(_ast.Const(0, 1))
wr_clk = [] wr_clk = []
wr_clk_enable = 0 wr_clk_enable = 0
wr_clk_polarity = 0 wr_clk_polarity = 0
for index, port in enumerate(fragment.write_ports): for index, port in enumerate(fragment._write_ports):
cd = fragment.domains[port.domain] cd = fragment.domains[port._domain]
wr_clk.append(cd.clk) wr_clk.append(cd.clk)
wr_clk_enable |= 1 << index wr_clk_enable |= 1 << index
if cd.clk_edge == "pos": if cd.clk_edge == "pos":
wr_clk_polarity |= 1 << index wr_clk_polarity |= 1 << index
params = { params = {
"MEMID": builder._make_name(hierarchy[-1], local=False), "MEMID": builder._make_name(hierarchy[-1], local=False),
"SIZE": memory.depth, "SIZE": fragment._depth,
"OFFSET": 0, "OFFSET": 0,
"ABITS": _ast.Shape.cast(range(memory.depth)).width, "ABITS": _ast.Shape.cast(range(fragment._depth)).width,
"WIDTH": memory.width, "WIDTH": fragment._width,
"INIT": init, "INIT": init,
"RD_PORTS": len(fragment.read_ports), "RD_PORTS": len(fragment._read_ports),
"RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment.read_ports))), "RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment._read_ports))),
"RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment.read_ports))), "RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment._read_ports))),
"RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment.read_ports) * len(fragment.write_ports))), "RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment._read_ports) * len(fragment._write_ports))),
"RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment.read_ports) * len(fragment.write_ports))), "RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment._read_ports) * len(fragment._write_ports))),
"RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment.read_ports))), "RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._read_ports))),
"RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment.read_ports))), "RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment._read_ports))),
"RD_ARST_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width), "RD_ARST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width),
"RD_SRST_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width), "RD_SRST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width),
"RD_INIT_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width), "RD_INIT_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width),
"WR_PORTS": len(fragment.write_ports), "WR_PORTS": len(fragment._write_ports),
"WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment.write_ports))), "WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment._write_ports))),
"WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment.write_ports))), "WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment._write_ports))),
"WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment.write_ports) * len(fragment.write_ports))), "WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment._write_ports) * len(fragment._write_ports))),
"WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment.write_ports))), "WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._write_ports))),
} }
def make_en(port):
if len(port._data) == 0:
return _ast.Const(0, 0)
granularity = len(port._data) // len(port._en)
return _ast.Cat(en_bit.replicate(granularity) for en_bit in port._en)
port_map = { port_map = {
"\\RD_CLK": _ast.Cat(rd_clk), "\\RD_CLK": _ast.Cat(rd_clk),
"\\RD_EN": _ast.Cat(port.en for port in fragment.read_ports), "\\RD_EN": _ast.Cat(port._en for port in fragment._read_ports),
"\\RD_ARST": _ast.Const(0, len(fragment.read_ports)), "\\RD_ARST": _ast.Const(0, len(fragment._read_ports)),
"\\RD_SRST": _ast.Const(0, len(fragment.read_ports)), "\\RD_SRST": _ast.Const(0, len(fragment._read_ports)),
"\\RD_ADDR": _ast.Cat(port.addr for port in fragment.read_ports), "\\RD_ADDR": _ast.Cat(port._addr for port in fragment._read_ports),
"\\RD_DATA": _ast.Cat(port.data for port in fragment.read_ports), "\\RD_DATA": _ast.Cat(port._data for port in fragment._read_ports),
"\\WR_CLK": _ast.Cat(wr_clk), "\\WR_CLK": _ast.Cat(wr_clk),
"\\WR_EN": _ast.Cat(_ast.Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in fragment.write_ports), "\\WR_EN": _ast.Cat(make_en(port) for port in fragment._write_ports),
"\\WR_ADDR": _ast.Cat(port.addr for port in fragment.write_ports), "\\WR_ADDR": _ast.Cat(port._addr for port in fragment._write_ports),
"\\WR_DATA": _ast.Cat(port.data for port in fragment.write_ports), "\\WR_DATA": _ast.Cat(port._data for port in fragment._write_ports),
} }
return "$mem_v2", port_map, params return "$mem_v2", port_map, params
@ -913,7 +915,7 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
if isinstance(subfragment, _ir.Instance): if isinstance(subfragment, _ir.Instance):
src = _src(subfragment.src_loc) src = _src(subfragment.src_loc)
elif isinstance(subfragment, _mem.MemoryInstance): elif isinstance(subfragment, _mem.MemoryInstance):
src = _src(subfragment.memory.src_loc) src = _src(subfragment._src_loc)
else: else:
src = "" src = ""

View file

@ -183,6 +183,7 @@ class Fragment:
def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"): def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
assert mode in ("silent", "warn", "error") assert mode in ("silent", "warn", "error")
from ._mem import MemoryInstance
driver_subfrags = SignalDict() driver_subfrags = SignalDict()
def add_subfrag(registry, entity, entry): def add_subfrag(registry, entity, entry):
@ -214,7 +215,7 @@ class Fragment:
# Always flatten subfragments that explicitly request it. # Always flatten subfragments that explicitly request it.
flatten_subfrags.add((subfrag, subfrag_hierarchy)) flatten_subfrags.add((subfrag, subfrag_hierarchy))
if isinstance(subfrag, Instance): if isinstance(subfrag, (Instance, MemoryInstance)):
# Never flatten instances. # Never flatten instances.
continue continue
@ -368,6 +369,8 @@ class Fragment:
return new_domains return new_domains
def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top): def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
from ._mem import MemoryInstance
def add_uses(*sigs, self=self): def add_uses(*sigs, self=self):
for sig in flatten(sigs): for sig in flatten(sigs):
if sig not in uses: if sig not in uses:
@ -416,6 +419,22 @@ class Fragment:
if dir == "io": if dir == "io":
subfrag.add_ports(value._lhs_signals(), dir=dir) subfrag.add_ports(value._lhs_signals(), dir=dir)
add_io(value._lhs_signals()) add_io(value._lhs_signals())
elif isinstance(subfrag, MemoryInstance):
for port in subfrag._read_ports:
subfrag.add_ports(port._data._lhs_signals(), dir="o")
add_defs(port._data._lhs_signals())
for value in [port._addr, port._en]:
subfrag.add_ports(value._rhs_signals(), dir="i")
add_uses(value._rhs_signals())
for port in subfrag._write_ports:
for value in [port._addr, port._en, port._data]:
subfrag.add_ports(value._rhs_signals(), dir="i")
add_uses(value._rhs_signals())
for domain, _ in subfrag.iter_sync():
cd = subfrag.domains[domain]
add_uses(cd.clk)
if cd.rst is not None:
add_uses(cd.rst)
else: else:
parent[subfrag] = self parent[subfrag] = self
level [subfrag] = level[self] + 1 level [subfrag] = level[self] + 1

View file

@ -3,12 +3,104 @@ from collections import OrderedDict
from .. import tracer from .. import tracer
from ._ast import * from ._ast import *
from ._ir import Elaboratable, Instance, Fragment from ._ir import Elaboratable, Fragment
from ..utils import ceil_log2
__all__ = ["Memory", "ReadPort", "WritePort", "DummyPort"] __all__ = ["Memory", "ReadPort", "WritePort", "DummyPort"]
class MemoryIdentity: pass
class MemorySimRead:
def __init__(self, identity, addr):
assert isinstance(identity, MemoryIdentity)
self._identity = identity
self._addr = Value.cast(addr)
def eq(self, value):
return MemorySimWrite(self._identity, self._addr, value)
class MemorySimWrite:
def __init__(self, identity, addr, data):
assert isinstance(identity, MemoryIdentity)
self._identity = identity
self._addr = Value.cast(addr)
self._data = Value.cast(data)
class MemoryInstance(Fragment):
class _ReadPort:
def __init__(self, *, domain, addr, data, en, transparency):
assert domain is None or isinstance(domain, str)
if domain == "comb":
domain = None
self._domain = domain
self._addr = Value.cast(addr)
self._data = Value.cast(data)
self._en = Value.cast(en)
self._transparency = tuple(transparency)
assert len(self._en) == 1
if domain is None:
assert isinstance(self._en, Const)
assert self._en.width == 1
assert self._en.value == 1
class _WritePort:
def __init__(self, *, domain, addr, data, en):
assert isinstance(domain, str)
assert domain != "comb"
self._domain = domain
self._addr = Value.cast(addr)
self._data = Value.cast(data)
self._en = Value.cast(en)
if len(self._data):
assert len(self._data) % len(self._en) == 0
@property
def _granularity(self):
if not len(self._data):
return 1
return len(self._data) // len(self._en)
def __init__(self, *, identity, width, depth, init=None, attrs=None, src_loc=None):
super().__init__()
assert isinstance(identity, MemoryIdentity)
self._identity = identity
self._width = operator.index(width)
self._depth = operator.index(depth)
self._init = tuple(init) if init is not None else ()
assert len(self._init) <= self._depth
self._init += (0,) * (self._depth - len(self._init))
for x in self._init:
assert isinstance(x, int)
self._attrs = attrs or {}
self._src_loc = src_loc
self._read_ports = []
self._write_ports = []
def read_port(self, *, domain, addr, data, en, transparency):
port = self._ReadPort(domain=domain, addr=addr, data=data, en=en, transparency=transparency)
assert len(port._data) == self._width
assert len(port._addr) == ceil_log2(self._depth)
for x in port._transparency:
assert isinstance(x, int)
assert x in range(len(self._write_ports))
for signal in port._data._rhs_signals():
self.add_driver(signal, port._domain)
self._read_ports.append(port)
def write_port(self, *, domain, addr, data, en):
port = self._WritePort(domain=domain, addr=addr, data=data, en=en)
assert len(port._data) == self._width
assert len(port._addr) == ceil_log2(self._depth)
self._write_ports.append(port)
return len(self._write_ports) - 1
class Memory(Elaboratable): class Memory(Elaboratable):
"""A word addressable storage. """A word addressable storage.
@ -50,16 +142,10 @@ class Memory(Elaboratable):
self.depth = depth self.depth = depth
self.attrs = OrderedDict(() if attrs is None else attrs) self.attrs = OrderedDict(() if attrs is None else attrs)
# Array of signals for simulation.
self._array = Array()
if simulate:
for addr in range(self.depth):
self._array.append(Signal(self.width, name="{}({})"
.format(name or "memory", addr)))
self.init = init self.init = init
self._read_ports = [] self._read_ports = []
self._write_ports = [] self._write_ports = []
self._identity = MemoryIdentity()
@property @property
def init(self): def init(self):
@ -73,11 +159,8 @@ class Memory(Elaboratable):
.format(len(self.init), self.depth)) .format(len(self.init), self.depth))
try: try:
for addr in range(len(self._array)): for addr, val in enumerate(self._init):
if addr < len(self._init): operator.index(val)
self._array[addr].reset = operator.index(self._init[addr])
else:
self._array[addr].reset = 0
except TypeError as e: except TypeError as e:
raise TypeError("Memory initialization value at address {:x}: {}" raise TypeError("Memory initialization value at address {:x}: {}"
.format(addr, e)) from None .format(addr, e)) from None
@ -116,52 +199,24 @@ class Memory(Elaboratable):
def __getitem__(self, index): def __getitem__(self, index):
"""Simulation only.""" """Simulation only."""
return self._array[index] return MemorySimRead(self._identity, index)
def elaborate(self, platform): def elaborate(self, platform):
f = MemoryInstance(self, self._read_ports, self._write_ports) f = MemoryInstance(identity=self._identity, width=self.width, depth=self.depth, init=self.init, attrs=self.attrs, src_loc=self.src_loc)
write_ports = {}
for port in self._write_ports:
port._MustUse__used = True
iport = f.write_port(domain=port.domain, addr=port.addr, data=port.data, en=port.en)
write_ports.setdefault(port.domain, []).append(iport)
for port in self._read_ports: for port in self._read_ports:
port._MustUse__used = True port._MustUse__used = True
if port.domain == "comb": if port.domain == "comb":
# Asynchronous port f.read_port(domain="comb", addr=port.addr, data=port.data, en=Const(1), transparency=())
f.add_statements(None, port.data.eq(self._array[port.addr]))
f.add_driver(port.data)
else: else:
# Synchronous port transparency = []
data = self._array[port.addr] if port.transparent:
for write_port in self._write_ports: transparency = write_ports.get(port.domain, [])
if port.domain == write_port.domain and port.transparent: f.read_port(domain=port.domain, addr=port.addr, data=port.data, en=port.en, transparency=transparency)
if len(write_port.en) > 1:
parts = []
for index, en_bit in enumerate(write_port.en):
offset = index * write_port.granularity
bits = slice(offset, offset + write_port.granularity)
cond = en_bit & (port.addr == write_port.addr)
parts.append(Mux(cond, write_port.data[bits], data[bits]))
data = Cat(parts)
else:
cond = write_port.en & (port.addr == write_port.addr)
data = Mux(cond, write_port.data, data)
f.add_statements(
port.domain,
Switch(port.en, {
1: port.data.eq(data)
})
)
f.add_driver(port.data, port.domain)
for port in self._write_ports:
port._MustUse__used = True
if len(port.en) > 1:
for index, en_bit in enumerate(port.en):
offset = index * port.granularity
bits = slice(offset, offset + port.granularity)
write_data = self._array[port.addr][bits].eq(port.data[bits])
f.add_statements(port.domain, Switch(en_bit, { 1: write_data }))
else:
write_data = self._array[port.addr].eq(port.data)
f.add_statements(port.domain, Switch(port.en, { 1: write_data }))
for signal in self._array:
f.add_driver(signal, port.domain)
return f return f
@ -308,12 +363,3 @@ class DummyPort:
name=f"{name}_data", src_loc_at=1) name=f"{name}_data", src_loc_at=1)
self.en = Signal(data_width // granularity, self.en = Signal(data_width // granularity,
name=f"{name}_en", src_loc_at=1) name=f"{name}_en", src_loc_at=1)
class MemoryInstance(Fragment):
def __init__(self, memory, read_ports, write_ports):
super().__init__()
self.memory = memory
self.read_ports = read_ports
self.write_ports = write_ports
self.attrs = memory.attrs

View file

@ -1,9 +1,8 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable from collections.abc import Iterable
from copy import copy
from .._utils import flatten, _ignore_deprecated from .._utils import flatten
from .. import tracer from .. import tracer
from ._ast import * from ._ast import *
from ._ast import _StatementList, AnyValue, Property from ._ast import _StatementList, AnyValue, Property
@ -239,27 +238,45 @@ class FragmentTransformer:
new_fragment.add_driver(signal, domain) new_fragment.add_driver(signal, domain)
def map_memory_ports(self, fragment, new_fragment): def map_memory_ports(self, fragment, new_fragment):
new_fragment.read_ports = [
copy(port)
for port in fragment.read_ports
]
new_fragment.write_ports = [
copy(port)
for port in fragment.write_ports
]
if hasattr(self, "on_value"): if hasattr(self, "on_value"):
for port in new_fragment.read_ports: for port in new_fragment._read_ports:
port.en = self.on_value(port.en) port._en = self.on_value(port._en)
port.addr = self.on_value(port.addr) port._addr = self.on_value(port._addr)
port.data = self.on_value(port.data) port._data = self.on_value(port._data)
for port in new_fragment.write_ports: for port in new_fragment._write_ports:
port.en = self.on_value(port.en) port._en = self.on_value(port._en)
port.addr = self.on_value(port.addr) port._addr = self.on_value(port._addr)
port.data = self.on_value(port.data) port._data = self.on_value(port._data)
def on_fragment(self, fragment): def on_fragment(self, fragment):
if isinstance(fragment, MemoryInstance): if isinstance(fragment, MemoryInstance):
new_fragment = MemoryInstance(fragment.memory, [], []) new_fragment = MemoryInstance(
identity=fragment._identity,
width=fragment._width,
depth=fragment._depth,
init=fragment._init,
attrs=fragment._attrs,
src_loc=fragment._src_loc
)
new_fragment._read_ports = [
MemoryInstance._ReadPort(
domain=port._domain,
addr=port._addr,
data=port._data,
en=port._en,
transparency=port._transparency,
)
for port in fragment._read_ports
]
new_fragment._write_ports = [
MemoryInstance._WritePort(
domain=port._domain,
addr=port._addr,
data=port._data,
en=port._en,
)
for port in fragment._write_ports
]
self.map_memory_ports(fragment, new_fragment) self.map_memory_ports(fragment, new_fragment)
elif isinstance(fragment, Instance): elif isinstance(fragment, Instance):
new_fragment = Instance(fragment.type, src_loc=fragment.src_loc) new_fragment = Instance(fragment.type, src_loc=fragment.src_loc)
@ -376,17 +393,16 @@ class DomainCollector(ValueVisitor, StatementVisitor):
def on_fragment(self, fragment): def on_fragment(self, fragment):
if isinstance(fragment, MemoryInstance): if isinstance(fragment, MemoryInstance):
for port in fragment.read_ports: for port in fragment._read_ports:
self.on_value(port.addr) self.on_value(port._addr)
self.on_value(port.data) self.on_value(port._data)
self.on_value(port.en) self.on_value(port._en)
if port.domain != "comb": self._add_used_domain(port._domain)
self._add_used_domain(port.domain) for port in fragment._write_ports:
for port in fragment.write_ports: self.on_value(port._addr)
self.on_value(port.addr) self.on_value(port._data)
self.on_value(port.data) self.on_value(port._en)
self.on_value(port.en) self._add_used_domain(port._domain)
self._add_used_domain(port.domain)
if isinstance(fragment, Instance): if isinstance(fragment, Instance):
for name, (value, dir) in fragment.named_ports.items(): for name, (value, dir) in fragment.named_ports.items():
@ -460,12 +476,12 @@ class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer)
def map_memory_ports(self, fragment, new_fragment): def map_memory_ports(self, fragment, new_fragment):
super().map_memory_ports(fragment, new_fragment) super().map_memory_ports(fragment, new_fragment)
for port in new_fragment.read_ports: for port in new_fragment._read_ports:
if port.domain in self.domain_map: if port._domain in self.domain_map:
port.domain = self.domain_map[port.domain] port._domain = self.domain_map[port._domain]
for port in new_fragment.write_ports: for port in new_fragment._write_ports:
if port.domain in self.domain_map: if port._domain in self.domain_map:
port.domain = self.domain_map[port.domain] port._domain = self.domain_map[port._domain]
class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer): class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
@ -645,10 +661,10 @@ class EnableInserter(_ControlInserter):
def on_fragment(self, fragment): def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment) new_fragment = super().on_fragment(fragment)
if isinstance(new_fragment, MemoryInstance): if isinstance(new_fragment, MemoryInstance):
for port in new_fragment.read_ports: for port in new_fragment._read_ports:
if port.domain in self.controls: if port._domain in self.controls:
port.en = port.en & self.controls[port.domain] port._en = port._en & self.controls[port._domain]
for port in new_fragment.write_ports: for port in new_fragment._write_ports:
if port.domain in self.controls: if port._domain in self.controls:
port.en = Mux(self.controls[port.domain], port.en, Const(0, len(port.en))) port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en)))
return new_fragment return new_fragment

View file

@ -1,4 +1,4 @@
__all__ = ["BaseProcess", "BaseSignalState", "BaseSimulation", "BaseEngine"] __all__ = ["BaseProcess", "BaseSignalState", "BaseMemoryState", "BaseSimulation", "BaseEngine"]
class BaseProcess: class BaseProcess:
@ -12,7 +12,7 @@ class BaseProcess:
self.passive = True self.passive = True
def run(self): def run(self):
raise NotImplementedError raise NotImplementedError # :nocov:
class BaseSignalState: class BaseSignalState:
@ -24,44 +24,62 @@ class BaseSignalState:
next = NotImplemented next = NotImplemented
def set(self, value): def set(self, value):
raise NotImplementedError raise NotImplementedError # :nocov:
class BaseMemoryState:
__slots__ = ()
memory = NotImplemented
def read(self, addr):
raise NotImplementedError # :nocov:
def write(self, addr, value):
raise NotImplementedError # :nocov:
class BaseSimulation: class BaseSimulation:
def reset(self): def reset(self):
raise NotImplementedError raise NotImplementedError # :nocov:
def get_signal(self, signal): def get_signal(self, signal):
raise NotImplementedError raise NotImplementedError # :nocov:
slots = NotImplemented slots = NotImplemented
def add_trigger(self, process, signal, *, trigger=None): def add_trigger(self, process, signal, *, trigger=None):
raise NotImplementedError raise NotImplementedError # :nocov:
def remove_trigger(self, process, signal): def remove_trigger(self, process, signal):
raise NotImplementedError raise NotImplementedError # :nocov:
def add_memory_trigger(self, process, identity):
raise NotImplementedError # :nocov:
def remove_memory_trigger(self, process, identity):
raise NotImplementedError # :nocov:
def wait_interval(self, process, interval): def wait_interval(self, process, interval):
raise NotImplementedError raise NotImplementedError # :nocov:
class BaseEngine: class BaseEngine:
def add_coroutine_process(self, process, *, default_cmd): def add_coroutine_process(self, process, *, default_cmd):
raise NotImplementedError raise NotImplementedError # :nocov:
def add_clock_process(self, clock, *, phase, period): def add_clock_process(self, clock, *, phase, period):
raise NotImplementedError raise NotImplementedError # :nocov:
def reset(self): def reset(self):
raise NotImplementedError raise NotImplementedError # :nocov:
@property @property
def now(self): def now(self):
raise NotImplementedError raise NotImplementedError # :nocov:
def advance(self): def advance(self):
raise NotImplementedError raise NotImplementedError # :nocov:
def write_vcd(self, *, vcd_file, gtkw_file, traces): def write_vcd(self, *, vcd_file, gtkw_file, traces):
raise NotImplementedError raise NotImplementedError # :nocov:

View file

@ -2,8 +2,9 @@ import inspect
from ..hdl import * from ..hdl import *
from ..hdl._ast import Statement, SignalSet, ValueCastable from ..hdl._ast import Statement, SignalSet, ValueCastable
from ..hdl._mem import MemorySimRead, MemorySimWrite
from .core import Tick, Settle, Delay, Passive, Active from .core import Tick, Settle, Delay, Passive, Active
from ._base import BaseProcess from ._base import BaseProcess, BaseMemoryState
from ._pyrtl import _ValueCompiler, _RHSValueCompiler, _StatementCompiler from ._pyrtl import _ValueCompiler, _RHSValueCompiler, _StatementCompiler
@ -119,6 +120,27 @@ class PyCoroProcess(BaseProcess):
elif type(command) is Active: elif type(command) is Active:
self.passive = False self.passive = False
elif type(command) is MemorySimRead:
exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"),
self.exec_locals)
addr = Const(self.exec_locals["result"], command._addr.shape()).value
index = self.state.memories[command._identity]
state = self.state.slots[index]
assert isinstance(state, BaseMemoryState)
response = state.read(addr)
elif type(command) is MemorySimWrite:
exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"),
self.exec_locals)
addr = Const(self.exec_locals["result"], command._addr.shape()).value
exec(_RHSValueCompiler.compile(self.state, command._data, mode="curr"),
self.exec_locals)
data = Const(self.exec_locals["result"], command._data.shape()).value
index = self.state.memories[command._identity]
state = self.state.slots[index]
assert isinstance(state, BaseMemoryState)
state.write(addr, data)
elif command is None: # only possible if self.default_cmd is None elif command is None: # only possible if self.default_cmd is None
raise TypeError("Received default command from process {!r} that was added " raise TypeError("Received default command from process {!r} that was added "
"with add_process(); did you mean to add this process with " "with add_process(); did you mean to add this process with "

View file

@ -4,8 +4,9 @@ from contextlib import contextmanager
import sys import sys
from ..hdl import * from ..hdl import *
from ..hdl._ast import SignalSet from ..hdl._ast import SignalSet, _StatementList
from ..hdl._xfrm import ValueVisitor, StatementVisitor from ..hdl._xfrm import ValueVisitor, StatementVisitor
from ..hdl._mem import MemoryInstance
from ._base import BaseProcess from ._base import BaseProcess
@ -409,10 +410,25 @@ class _FragmentCompiler:
def __call__(self, fragment): def __call__(self, fragment):
processes = set() processes = set()
for domain_name, domain_stmts in fragment.statements.items(): domains = set(fragment.statements)
if isinstance(fragment, MemoryInstance):
self.state.add_memory(fragment)
for port in fragment._read_ports:
domains.add(port._domain)
for port in fragment._write_ports:
domains.add(port._domain)
for domain_name in domains:
domain_stmts = fragment.statements.get(domain_name, _StatementList())
domain_process = PyRTLProcess(is_comb=domain_name is None) domain_process = PyRTLProcess(is_comb=domain_name is None)
domain_signals = domain_stmts._lhs_signals() domain_signals = domain_stmts._lhs_signals()
if isinstance(fragment, MemoryInstance):
for port in fragment._read_ports:
if port._domain == domain_name:
domain_signals.update(port._data._lhs_signals())
emitter = _PythonEmitter() emitter = _PythonEmitter()
emitter.append(f"def run():") emitter.append(f"def run():")
emitter._level += 1 emitter._level += 1
@ -425,6 +441,21 @@ class _FragmentCompiler:
inputs = SignalSet() inputs = SignalSet()
_StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts) _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)
if isinstance(fragment, MemoryInstance):
self.state.add_memory_trigger(domain_process, fragment._identity)
memory_index = self.state.memories[fragment._identity]
rhs = _RHSValueCompiler(self.state, emitter, mode="curr", inputs=inputs)
lhs = _LHSValueCompiler(self.state, emitter, rhs=rhs)
for port in fragment._read_ports:
if port._domain is not None:
continue
addr = rhs(port._addr)
addr = f"({(1 << len(port._addr)) - 1:#x} & {addr})"
data = emitter.def_var("read_data", f"slots[{memory_index}].read({addr})")
lhs(port._data)(data)
for input in inputs: for input in inputs:
self.state.add_trigger(domain_process, input) self.state.add_trigger(domain_process, input)
@ -442,6 +473,47 @@ class _FragmentCompiler:
_StatementCompiler(self.state, emitter)(domain_stmts) _StatementCompiler(self.state, emitter)(domain_stmts)
if isinstance(fragment, MemoryInstance):
memory_index = self.state.memories[fragment._identity]
rhs = _RHSValueCompiler(self.state, emitter, mode="curr")
lhs = _LHSValueCompiler(self.state, emitter, rhs=rhs)
write_vals = {}
for idx, port in enumerate(fragment._write_ports):
if port._domain != domain_name:
continue
addr = rhs(port._addr)
addr = emitter.def_var("write_addr", f"({(1 << len(port._addr)) - 1:#x} & {addr})")
data = rhs(port._data)
data = emitter.def_var("write_data", f"({(1 << len(port._data)) - 1:#x} & {data})")
en = rhs(Cat(bit.replicate(port._granularity) for bit in port._en))
en = emitter.def_var("write_en", f"({(1 << len(port._data)) - 1:#x} & {en})")
emitter.append(f"slots[{memory_index}].write({addr}, {data}, {en})")
write_vals[idx] = addr, data, en
for port in fragment._read_ports:
if port._domain != domain_name:
continue
en = rhs(port._en)
en = f"(1 & {en})"
emitter.append(f"if {en}:")
with emitter.indent():
addr = rhs(port._addr)
addr = emitter.def_var("read_addr", f"({(1 << len(port._addr)) - 1:#x} & {addr})")
data = emitter.def_var("read_data", f"slots[{memory_index}].read({addr})")
for idx in port._transparency:
waddr, wdata, wen = write_vals[idx]
emitter.append(f"if {addr} == {waddr}:")
with emitter.indent():
emitter.append(f"{data} &= ~{wen}")
emitter.append(f"{data} |= {wdata} & {wen}")
lhs(port._data)(data)
for signal in domain_signals: for signal in domain_signals:
signal_index = self.state.get_signal(signal) signal_index = self.state.get_signal(signal)
emitter.append(f"slots[{signal_index}].set(next_{signal_index})") emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

View file

@ -6,6 +6,7 @@ from vcd.gtkw import GTKWSave
from ..hdl import * from ..hdl import *
from ..hdl._repr import * from ..hdl._repr import *
from ..hdl._mem import MemoryInstance, MemoryIdentity
from ..hdl._ast import SignalDict, Slice, Operator from ..hdl._ast import SignalDict, Slice, Operator
from ._base import * from ._base import *
from ._pyrtl import _FragmentCompiler from ._pyrtl import _FragmentCompiler
@ -43,28 +44,34 @@ class _VCDWriter:
if isinstance(gtkw_file, str): if isinstance(gtkw_file, str):
gtkw_file = open(gtkw_file, "w") gtkw_file = open(gtkw_file, "w")
self.vcd_vars = SignalDict() self.vcd_signal_vars = SignalDict()
self.vcd_memory_vars = {}
self.vcd_file = vcd_file self.vcd_file = vcd_file
self.vcd_writer = vcd_file and VCDWriter(self.vcd_file, self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
timescale="1 ps", comment="Generated by Amaranth") timescale="1 ps", comment="Generated by Amaranth")
self.gtkw_names = SignalDict() self.gtkw_signal_names = SignalDict()
self.gtkw_memory_names = {}
self.gtkw_file = gtkw_file self.gtkw_file = gtkw_file
self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file) self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)
self.traces = [] self.traces = []
signal_names = SignalDict() signal_names = SignalDict()
memories = {}
for subfragment, subfragment_name in \ for subfragment, subfragment_name in \
fragment._assign_names_to_fragments(hierarchy=("bench", "top",)).items(): fragment._assign_names_to_fragments(hierarchy=("bench", "top",)).items():
for signal, signal_name in subfragment._assign_names_to_signals().items(): for signal, signal_name in subfragment._assign_names_to_signals().items():
if signal not in signal_names: if signal not in signal_names:
signal_names[signal] = set() signal_names[signal] = set()
signal_names[signal].add((*subfragment_name, signal_name)) signal_names[signal].add((*subfragment_name, signal_name))
if isinstance(subfragment, MemoryInstance):
memories[subfragment._identity] = (subfragment, subfragment_name)
trace_names = SignalDict() trace_names = SignalDict()
assigned_names = set() assigned_names = set()
for trace in traces: for trace in traces:
if isinstance(trace, ValueLike):
trace = Value.cast(trace) trace = Value.cast(trace)
for trace_signal in trace._rhs_signals(): for trace_signal in trace._rhs_signals():
if trace_signal not in signal_names: if trace_signal not in signal_names:
@ -76,14 +83,19 @@ class _VCDWriter:
trace_names[trace_signal] = {("bench", name)} trace_names[trace_signal] = {("bench", name)}
assigned_names.add(name) assigned_names.add(name)
self.traces.append(trace_signal) self.traces.append(trace_signal)
elif isinstance(trace, (MemoryInstance, Memory)):
if not trace._identity in memories:
raise ValueError(f"{trace!r} is a memory not part of the elaborated design")
self.traces.append(trace._identity)
else:
raise TypeError(f"{trace!r} is not a traceable object")
if self.vcd_writer is None: if self.vcd_writer is None:
return return
for signal, names in itertools.chain(signal_names.items(), trace_names.items()): for signal, names in itertools.chain(signal_names.items(), trace_names.items()):
fields = [] self.vcd_signal_vars[signal] = []
self.vcd_vars[signal] = [] self.gtkw_signal_names[signal] = []
self.gtkw_names[signal] = []
for repr in signal._value_repr: for repr in signal._value_repr:
var_init = self.eval_field(repr.value, signal, signal.reset) var_init = self.eval_field(repr.value, signal, signal.reset)
if isinstance(repr.format, FormatInt): if isinstance(repr.format, FormatInt):
@ -119,22 +131,45 @@ class _VCDWriter:
gtkw_field_name = '\\' + field_name gtkw_field_name = '\\' + field_name
else: else:
gtkw_field_name = field_name gtkw_field_name = field_name
self.gtkw_names[signal].append(".".join((*var_scope, gtkw_field_name)) + suffix) self.gtkw_signal_names[signal].append(".".join((*var_scope, gtkw_field_name)) + suffix)
else: else:
self.vcd_writer.register_alias( self.vcd_writer.register_alias(
scope=var_scope, name=field_name, scope=var_scope, name=field_name,
var=vcd_var) var=vcd_var)
self.vcd_vars[signal].append((vcd_var, repr)) self.vcd_signal_vars[signal].append((vcd_var, repr))
for memory, memory_name in memories.values():
self.vcd_memory_vars[memory._identity] = vcd_vars = []
self.gtkw_memory_names[memory._identity] = gtkw_names = []
if memory._width > 1:
suffix = f"[{memory._width - 1}:0]"
else:
suffix = ""
for idx, init in enumerate(memory._init):
field_name = "\\" + memory_name[-1] + f"[{idx}]"
var_scope = memory_name[:-1]
vcd_var = self.vcd_writer.register_var(
scope=var_scope, name=field_name,
var_type="wire", size=memory._width, init=init,
)
vcd_vars.append(vcd_var)
gtkw_field_name = field_name + suffix
gtkw_name = ".".join((*var_scope, gtkw_field_name))
gtkw_names.append(gtkw_name)
def update(self, timestamp, signal, value): def update_signal(self, timestamp, signal, value):
for (vcd_var, repr) in self.vcd_vars.get(signal, ()): for (vcd_var, repr) in self.vcd_signal_vars.get(signal, ()):
var_value = self.eval_field(repr.value, signal, value) var_value = self.eval_field(repr.value, signal, value)
if not isinstance(repr.format, FormatInt): if not isinstance(repr.format, FormatInt):
var_value = self.decode_to_vcd(repr.format, var_value) var_value = self.decode_to_vcd(repr.format, var_value)
self.vcd_writer.change(vcd_var, timestamp, var_value) self.vcd_writer.change(vcd_var, timestamp, var_value)
def update_memory(self, timestamp, memory, addr, value):
vcd_var = self.vcd_memory_vars[memory._identity][addr]
self.vcd_writer.change(vcd_var, timestamp, value)
def close(self, timestamp): def close(self, timestamp):
if self.vcd_writer is not None: if self.vcd_writer is not None:
self.vcd_writer.close(timestamp) self.vcd_writer.close(timestamp)
@ -145,8 +180,14 @@ class _VCDWriter:
self.gtkw_save.treeopen("top") self.gtkw_save.treeopen("top")
for signal in self.traces: for signal in self.traces:
for name in self.gtkw_names[signal]: if isinstance(signal, Signal):
for name in self.gtkw_signal_names[signal]:
self.gtkw_save.trace(name) self.gtkw_save.trace(name)
elif isinstance(signal, MemoryIdentity):
for name in self.gtkw_memory_names[signal]:
self.gtkw_save.trace(name)
else:
assert False # :nocov:
if self.vcd_file is not None: if self.vcd_file is not None:
self.vcd_file.close() self.vcd_file.close()
@ -208,7 +249,7 @@ class _PySignalState(BaseSignalState):
def __init__(self, signal, pending): def __init__(self, signal, pending):
self.signal = signal self.signal = signal
self.pending = pending self.pending = pending
self.waiters = dict() self.waiters = {}
self.curr = self.next = signal.reset self.curr = self.next = signal.reset
def set(self, value): def set(self, value):
@ -229,17 +270,83 @@ class _PySignalState(BaseSignalState):
return awoken_any return awoken_any
class _PyMemoryChange:
__slots__ = ("state", "addr")
def __init__(self, state, addr):
self.state = state
self.addr = addr
class _PyMemoryState(BaseMemoryState):
__slots__ = ("memory", "data", "write_queue", "waiters", "pending")
def __init__(self, memory, pending):
self.memory = memory
self.pending = pending
self.waiters = {}
self.reset()
def reset(self):
self.data = list(self.memory._init)
self.write_queue = []
def commit(self):
if not self.write_queue:
return False
for addr, value, mask in self.write_queue:
curr = self.data[addr]
value = (value & mask) | (curr & ~mask)
self.data[addr] = value
self.write_queue.clear()
awoken_any = False
for process in self.waiters:
process.runnable = awoken_any = True
return awoken_any
def read(self, addr):
if addr not in range(self.memory._depth):
return 0
return self.data[addr]
def write(self, addr, value, mask=None):
if addr not in range(self.memory._depth):
return
if mask == 0:
return
if mask is None:
mask = (1 << self.memory._width) - 1
self.write_queue.append((addr, value, mask))
self.pending.add(self)
class _PySimulation(BaseSimulation): class _PySimulation(BaseSimulation):
def __init__(self): def __init__(self):
self.timeline = _Timeline() self.timeline = _Timeline()
self.signals = SignalDict() self.signals = SignalDict()
self.memories = {}
self.slots = [] self.slots = []
self.pending = set() self.pending = set()
def add_memory(self, fragment):
self.memories[fragment._identity] = len(self.slots)
self.slots.append(_PyMemoryState(fragment, self.pending))
def reset(self): def reset(self):
self.timeline.reset() self.timeline.reset()
for signal, index in self.signals.items(): for signal, index in self.signals.items():
self.slots[index].curr = self.slots[index].next = signal.reset state = self.slots[index]
assert isinstance(state, _PySignalState)
state.curr = state.next = signal.reset
for index in self.memories.values():
state = self.slots[index]
assert isinstance(state, _PyMemoryState)
state.reset()
self.pending.clear() self.pending.clear()
def get_signal(self, signal): def get_signal(self, signal):
@ -262,16 +369,31 @@ class _PySimulation(BaseSimulation):
assert process in self.slots[index].waiters assert process in self.slots[index].waiters
del self.slots[index].waiters[process] del self.slots[index].waiters[process]
def add_memory_trigger(self, process, identity):
index = self.memories[identity]
self.slots[index].waiters[process] = None
def remove_memory_trigger(self, process, identity):
index = self.memories[identity]
assert process in self.slots[index].waiters
del self.slots[index].waiters[process]
def wait_interval(self, process, interval): def wait_interval(self, process, interval):
self.timeline.delay(interval, process) self.timeline.delay(interval, process)
def commit(self, changed=None): def commit(self, changed=None):
converged = True converged = True
for signal_state in self.pending: for state in self.pending:
if signal_state.commit():
converged = False
if changed is not None: if changed is not None:
changed.update(self.pending) if isinstance(state, _PyMemoryState):
for addr, _value, _mask in state.write_queue:
changed.add(_PyMemoryChange(state, addr))
elif isinstance(state, _PySignalState):
changed.add(state)
else:
assert False # :nocov:
if state.commit():
converged = False
self.pending.clear() self.pending.clear()
return converged return converged
@ -314,9 +436,16 @@ class PySimEngine(BaseEngine):
converged = self._state.commit(changed) converged = self._state.commit(changed)
for vcd_writer in self._vcd_writers: for vcd_writer in self._vcd_writers:
for signal_state in changed: for change in changed:
vcd_writer.update(self._timeline.now, if isinstance(change, _PySignalState):
signal_state = change
vcd_writer.update_signal(self._timeline.now,
signal_state.signal, signal_state.curr) signal_state.signal, signal_state.curr)
elif isinstance(change, _PyMemoryChange):
vcd_writer.update_memory(self._timeline.now, change.state.memory,
change.addr, change.state.data[change.addr])
else:
assert False # :nocov:
def advance(self): def advance(self):
self._step() self._step()

View file

@ -138,9 +138,9 @@ class DomainRenamerTestCase(FHDLTestCase):
f = DomainRenamer({"a": "d", "c": "e"})(f) f = DomainRenamer({"a": "d", "c": "e"})(f)
mem = f.subfragments[0][0] mem = f.subfragments[0][0]
self.assertIsInstance(mem, MemoryInstance) self.assertIsInstance(mem, MemoryInstance)
self.assertEqual(mem.read_ports[0].domain, "d") self.assertEqual(mem._read_ports[0]._domain, "d")
self.assertEqual(mem.read_ports[1].domain, "b") self.assertEqual(mem._read_ports[1]._domain, "b")
self.assertEqual(mem.write_ports[0].domain, "e") self.assertEqual(mem._write_ports[0]._domain, "e")
def test_rename_wrong_to_comb(self): def test_rename_wrong_to_comb(self):
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
@ -530,7 +530,7 @@ class EnableInserterTestCase(FHDLTestCase):
mem = Memory(width=8, depth=4) mem = Memory(width=8, depth=4)
mem.read_port(transparent=False) mem.read_port(transparent=False)
f = EnableInserter(self.c1)(mem).elaborate(platform=None) f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.read_ports[0].en, """ self.assertRepr(f._read_ports[0]._en, """
(& (sig mem_r_en) (sig c1)) (& (sig mem_r_en) (sig c1))
""") """)
@ -538,7 +538,7 @@ class EnableInserterTestCase(FHDLTestCase):
mem = Memory(width=8, depth=4) mem = Memory(width=8, depth=4)
mem.write_port(granularity=2) mem.write_port(granularity=2)
f = EnableInserter(self.c1)(mem).elaborate(platform=None) f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.write_ports[0].en, """ self.assertRepr(f._write_ports[0]._en, """
(m (m
(sig c1) (sig c1)
(sig mem_w_en) (sig mem_w_en)