hdl.mem: lower Memory directly to $mem_v2 RTLIL cell.

The design decision of using split memory ports in the internal
representation (copied from Yosys) was misguided and caused no end
of misery. Remove any uses of `$memrd`/`$memwr` and lower memories
directly to a combined memory cell, currently the RTLIL one.
This commit is contained in:
Marcelina Kościelnicka 2023-09-01 05:22:46 +00:00 committed by Catherine
parent fc85feb30d
commit 8c4a15ab92
10 changed files with 183 additions and 193 deletions

View file

@ -863,49 +863,21 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
if sub_name is None:
sub_name = module.anonymous()
sub_params = OrderedDict()
if hasattr(subfragment, "parameters"):
for param_name, param_value in subfragment.parameters.items():
if isinstance(param_value, mem.Memory):
memory = param_value
if memory not in memories:
memories[memory] = module.memory(width=memory.width, size=memory.depth,
name=memory.name, attrs=memory.attrs)
addr_bits = bits_for(memory.depth)
data_parts = []
data_mask = (1 << memory.width) - 1
for addr in range(memory.depth):
if addr < len(memory.init):
data = memory.init[addr] & data_mask
else:
data = 0
data_parts.append("{:0{}b}".format(data, memory.width))
module.cell("$meminit", ports={
"\\ADDR": rhs_compiler(ast.Const(0, addr_bits)),
"\\DATA": "{}'".format(memory.width * memory.depth) +
"".join(reversed(data_parts)),
}, params={
"MEMID": memories[memory],
"ABITS": addr_bits,
"WIDTH": memory.width,
"WORDS": memory.depth,
"PRIORITY": 0,
})
param_value = memories[memory]
sub_params[param_name] = param_value
sub_params = OrderedDict(getattr(subfragment, "parameters", {}))
sub_type, sub_port_map = \
_convert_fragment(builder, subfragment, name_map,
hierarchy=hierarchy + (sub_name,))
if sub_type == "$mem_v2" and "MEMID" not in sub_params:
sub_params["MEMID"] = "$" + sub_name
sub_ports = OrderedDict()
for port, value in sub_port_map.items():
if not isinstance(subfragment, ir.Instance):
for signal in value._rhs_signals():
compiler_state.resolve_curr(signal, prefix=sub_name)
if len(value) > 0:
if len(value) > 0 or sub_type == "$mem_v2":
sub_ports[port] = rhs_compiler(value)
module.cell(sub_type, name=sub_name, ports=sub_ports, params=sub_params,

View file

@ -83,12 +83,6 @@ class _MemoryPort(CompatModule):
self.clock = ClockSignal(clock_domain)
@extend(NativeMemory)
@deprecated("it is not necessary or permitted to add Memory as a special or submodule")
def elaborate(self, platform):
return Fragment()
class CompatMemory(NativeMemory, Elaboratable):
def __init__(self, width, depth, init=None, name=None):
super().__init__(width=width, depth=depth, init=init, name=name)

View file

@ -181,7 +181,6 @@ class Fragment:
assert mode in ("silent", "warn", "error")
driver_subfrags = SignalDict()
memory_subfrags = OrderedDict()
def add_subfrag(registry, entity, entry):
# Because of missing domain insertion, at the point when this code runs, we have
# a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to
@ -212,24 +211,16 @@ class Fragment:
flatten_subfrags.add((subfrag, subfrag_hierarchy))
if isinstance(subfrag, Instance):
# For memories (which are subfragments, but semantically a part of superfragment),
# record that this fragment is driving it.
if subfrag.type in ("$memrd", "$memwr"):
memory = subfrag.parameters["MEMID"]
add_subfrag(memory_subfrags, memory, (None, hierarchy))
# Never flatten instances.
continue
# First, recurse into subfragments and let them detect driver conflicts as well.
subfrag_drivers, subfrag_memories = \
subfrag_drivers = \
subfrag._resolve_hierarchy_conflicts(subfrag_hierarchy, mode)
# Second, classify subfragments by signals they drive and memories they use.
# Second, classify subfragments by signals they drive.
for signal in subfrag_drivers:
add_subfrag(driver_subfrags, signal, (subfrag, subfrag_hierarchy))
for memory in subfrag_memories:
add_subfrag(memory_subfrags, memory, (subfrag, subfrag_hierarchy))
# Find out the set of subfragments that needs to be flattened into this fragment
# to resolve driver-driver conflicts.
@ -253,20 +244,6 @@ class Fragment:
message += "; hierarchy will be flattened"
warnings.warn_explicit(message, DriverConflict, *signal.src_loc)
for memory, subfrags in memory_subfrags.items():
subfrag_names = flatten_subfrags_if_needed(subfrags)
if not subfrag_names:
continue
# While we're at it, show a message.
message = ("Memory '{}' is accessed from multiple fragments: {}"
.format(memory.name, ", ".join(subfrag_names)))
if mode == "error":
raise DriverConflict(message)
elif mode == "warn":
message += "; hierarchy will be flattened"
warnings.warn_explicit(message, DriverConflict, *memory.src_loc)
# Flatten hierarchy.
for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]):
self._merge_subfragment(subfrag)
@ -282,8 +259,7 @@ class Fragment:
return self._resolve_hierarchy_conflicts(hierarchy, mode)
# Nothing was flattened, we're done!
return (SignalSet(driver_subfrags.keys()),
set(memory_subfrags.keys()))
return SignalSet(driver_subfrags.keys())
def _propagate_domains_up(self, hierarchy=("top",)):
from .xfrm import DomainRenamer

View file

@ -3,13 +3,13 @@ from collections import OrderedDict
from .. import tracer
from .ast import *
from .ir import Elaboratable, Instance
from .ir import Elaboratable, Instance, Fragment
__all__ = ["Memory", "ReadPort", "WritePort", "DummyPort"]
class Memory:
class Memory(Elaboratable):
"""A word addressable storage.
Parameters
@ -58,6 +58,8 @@ class Memory:
.format(name or "memory", addr)))
self.init = init
self._read_ports = []
self._write_ports = []
@property
def init(self):
@ -116,6 +118,96 @@ class Memory:
"""Simulation only."""
return self._array[index]
def elaborate(self, platform):
"""Lower this memory to a Fragment containing a single ``$mem_v2`` cell.

The returned Fragment holds one combined ``$mem_v2`` Instance describing
the storage and all of its read/write ports, plus behavioral statements
driving ``port.data`` and the backing ``self._array`` signals —
presumably the simulation model of the memory (TODO confirm: the
simulator does not appear to interpret the cell itself).
"""
# Pack the init words into one wide constant: each word is rendered as a
# fixed-width binary string and the words are concatenated with word 0 in
# the least significant bits (hence reversed()). ``or "0"`` guards the
# empty-init case so int() always has digits to parse.
init = "".join(format(Const(elem, unsigned(self.width)).value, f"0{self.width}b") for elem in reversed(self.init))
init = Const(int(init or "0", 2), len(self.init) * self.width)
# Per-read-port clock list plus bitmask parameters for the cell.
rd_clk = []
rd_clk_enable = 0  # bit i set: read port i is synchronous (clocked)
rd_transparency_mask = 0  # bit (i * n_wr + j) set: read port i is transparent wrt write port j
for index, port in enumerate(self._read_ports):
if port.domain != "comb":
# Synchronous read port: give it a real clock, and mark it
# transparent with respect to write ports in the same domain.
rd_clk.append(ClockSignal(port.domain))
rd_clk_enable |= 1 << index
if port.transparent:
for write_index, write_port in enumerate(self._write_ports):
if port.domain == write_port.domain:
rd_transparency_mask |= 1 << (index * len(self._write_ports) + write_index)
else:
# Asynchronous (combinational) read port: its clock bit is tied to 0.
rd_clk.append(Const(0, 1))
# Parameters and ports follow the RTLIL $mem_v2 cell interface. Mask
# parameters that would otherwise be zero-width fall back to a 1-bit zero
# placeholder when the corresponding port list is empty.
f = Instance("$mem_v2",
*(("a", attr, value) for attr, value in self.attrs.items()),
p_SIZE=self.depth,
p_OFFSET=0,
p_ABITS=Shape.cast(range(self.depth)).width,
p_WIDTH=self.width,
p_INIT=init,
p_RD_PORTS=len(self._read_ports),
p_RD_CLK_ENABLE=Const(rd_clk_enable, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_CLK_POLARITY=Const(-1, unsigned(len(self._read_ports))) if self._read_ports else Const(0, 1),
p_RD_TRANSPARENCY_MASK=Const(rd_transparency_mask, max(1, len(self._read_ports) * len(self._write_ports))),
p_RD_COLLISION_X_MASK=Const(0, max(1, len(self._read_ports) * len(self._write_ports))),
p_RD_WIDE_CONTINUATION=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_CE_OVER_SRST=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_ARST_VALUE=Const(0, len(self._read_ports) * self.width),
p_RD_SRST_VALUE=Const(0, len(self._read_ports) * self.width),
p_RD_INIT_VALUE=Const(0, len(self._read_ports) * self.width),
p_WR_PORTS=len(self._write_ports),
p_WR_CLK_ENABLE=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
p_WR_CLK_POLARITY=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
p_WR_PRIORITY_MASK=Const(0, len(self._write_ports) * len(self._write_ports)) if self._write_ports else Const(0, 1),
p_WR_WIDE_CONTINUATION=Const(0, len(self._write_ports)) if self._write_ports else Const(0, 1),
i_RD_CLK=Cat(rd_clk),
i_RD_EN=Cat(port.en for port in self._read_ports),
i_RD_ARST=Const(0, len(self._read_ports)),
i_RD_SRST=Const(0, len(self._read_ports)),
i_RD_ADDR=Cat(port.addr for port in self._read_ports),
o_RD_DATA=Cat(port.data for port in self._read_ports),
i_WR_CLK=Cat(ClockSignal(port.domain) for port in self._write_ports),
# Each per-lane enable bit is widened to cover its full granularity lane.
i_WR_EN=Cat(Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in self._write_ports),
i_WR_ADDR=Cat(port.addr for port in self._write_ports),
i_WR_DATA=Cat(port.data for port in self._write_ports),
)
# Behavioral read model.
for port in self._read_ports:
port._MustUse__used = True  # mark the port object as used (MustUse name-mangled flag)
if port.domain == "comb":
# Asynchronous port
f.add_statements(port.data.eq(self._array[port.addr]))
f.add_driver(port.data)
else:
# Synchronous port
data = self._array[port.addr]
for write_port in self._write_ports:
if port.domain == write_port.domain and port.transparent:
if len(write_port.en) > 1:
# Granular write enable: forward each written lane
# individually when that lane is enabled and the
# addresses match.
parts = []
for index, en_bit in enumerate(write_port.en):
offset = index * write_port.granularity
bits = slice(offset, offset + write_port.granularity)
cond = en_bit & (port.addr == write_port.addr)
parts.append(Mux(cond, write_port.data[bits], data[bits]))
data = Cat(parts)
else:
# NOTE(review): unlike the granular branch above, this Mux
# does not compare port.addr == write_port.addr — confirm
# whether the address check is intentionally omitted here.
data = Mux(write_port.en, write_port.data, data)
f.add_statements(
Switch(port.en, {
1: port.data.eq(data)
})
)
f.add_driver(port.data, port.domain)
# Behavioral write model: one Switch per enable lane, updating only the
# corresponding slice of the addressed word.
for port in self._write_ports:
port._MustUse__used = True  # mark the port object as used (MustUse name-mangled flag)
if len(port.en) > 1:
for index, en_bit in enumerate(port.en):
offset = index * port.granularity
bits = slice(offset, offset + port.granularity)
write_data = self._array[port.addr][bits].eq(port.data[bits])
f.add_statements(Switch(en_bit, { 1: write_data }))
else:
write_data = self._array[port.addr].eq(port.data)
f.add_statements(Switch(port.en, { 1: write_data }))
# Every storage word is driven in this write port's domain.
for signal in self._array:
f.add_driver(signal, port.domain)
return f
class ReadPort(Elaboratable):
"""A memory read port.
@ -142,9 +234,7 @@ class ReadPort(Elaboratable):
data : Signal(memory.width), out
Read data.
en : Signal or Const, in
Read enable. If asserted, ``data`` is updated with the word stored at ``addr``. Note that
transparent ports cannot assign ``en`` (which is hardwired to 1 instead), as doing so is
currently not supported by Yosys.
Read enable. If asserted, ``data`` is updated with the word stored at ``addr``.
Exceptions
----------
@ -162,59 +252,19 @@ class ReadPort(Elaboratable):
name="{}_r_addr".format(memory.name), src_loc_at=1 + src_loc_at)
self.data = Signal(memory.width,
name="{}_r_data".format(memory.name), src_loc_at=1 + src_loc_at)
if self.domain != "comb" and not transparent:
if self.domain != "comb":
self.en = Signal(name="{}_r_en".format(memory.name), reset=1,
src_loc_at=1 + src_loc_at)
else:
self.en = Const(1)
memory._read_ports.append(self)
def elaborate(self, platform):
f = Instance("$memrd",
p_MEMID=self.memory,
p_ABITS=self.addr.width,
p_WIDTH=self.data.width,
p_CLK_ENABLE=self.domain != "comb",
p_CLK_POLARITY=1,
p_TRANSPARENT=self.transparent,
i_CLK=ClockSignal(self.domain) if self.domain != "comb" else Const(0),
i_EN=self.en,
i_ADDR=self.addr,
o_DATA=self.data,
)
if self.domain == "comb":
# Asynchronous port
f.add_statements(self.data.eq(self.memory._array[self.addr]))
f.add_driver(self.data)
elif not self.transparent:
# Synchronous, read-before-write port
f.add_statements(
Switch(self.en, {
1: self.data.eq(self.memory._array[self.addr])
})
)
f.add_driver(self.data, self.domain)
if self is self.memory._read_ports[0]:
return self.memory
else:
# Synchronous, write-through port
# This model is a bit unconventional. We model transparent ports as asynchronous ports
# that are latched when the clock is high. This isn't exactly correct, but it is very
# close to the correct behavior of a transparent port, and the difference should only
# be observable in pathological cases of clock gating. A register is injected to
# the address input to achieve the correct address-to-data latency. Also, the reset
# value of the data output is forcibly set to the 0th initial value, if any--note that
# many FPGAs do not guarantee this behavior!
if len(self.memory.init) > 0:
self.data.reset = operator.index(self.memory.init[0])
latch_addr = Signal.like(self.addr)
f.add_statements(
latch_addr.eq(self.addr),
Switch(ClockSignal(self.domain), {
0: self.data.eq(self.data),
1: self.data.eq(self.memory._array[latch_addr]),
}),
)
f.add_driver(latch_addr, self.domain)
f.add_driver(self.data)
return f
return Fragment()
class WritePort(Elaboratable):
@ -272,31 +322,13 @@ class WritePort(Elaboratable):
self.en = Signal(memory.width // granularity,
name="{}_w_en".format(memory.name), src_loc_at=1 + src_loc_at)
memory._write_ports.append(self)
def elaborate(self, platform):
f = Instance("$memwr",
p_MEMID=self.memory,
p_ABITS=self.addr.width,
p_WIDTH=self.data.width,
p_CLK_ENABLE=1,
p_CLK_POLARITY=1,
p_PRIORITY=0,
i_CLK=ClockSignal(self.domain),
i_EN=Cat(en_bit.replicate(self.granularity) for en_bit in self.en),
i_ADDR=self.addr,
i_DATA=self.data,
)
if len(self.en) > 1:
for index, en_bit in enumerate(self.en):
offset = index * self.granularity
bits = slice(offset, offset + self.granularity)
write_data = self.memory._array[self.addr][bits].eq(self.data[bits])
f.add_statements(Switch(en_bit, { 1: write_data }))
if not self.memory._read_ports and self is self.memory._write_ports[0]:
return self.memory
else:
write_data = self.memory._array[self.addr].eq(self.data)
f.add_statements(Switch(self.en, { 1: write_data }))
for signal in self.memory._array:
f.add_driver(signal, self.domain)
return f
return Fragment()
class DummyPort:

View file

@ -720,10 +720,14 @@ class EnableInserter(_ControlInserter):
def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment)
if isinstance(new_fragment, Instance) and new_fragment.type in ("$memrd", "$memwr"):
clk_port, clk_dir = new_fragment.named_ports["CLK"]
if isinstance(clk_port, ClockSignal) and clk_port.domain in self.controls:
en_port, en_dir = new_fragment.named_ports["EN"]
en_port = Mux(self.controls[clk_port.domain], en_port, Const(0, len(en_port)))
new_fragment.named_ports["EN"] = en_port, en_dir
if isinstance(new_fragment, Instance) and new_fragment.type == "$mem_v2":
for kind in ["RD", "WR"]:
clk_parts = new_fragment.named_ports[kind + "_CLK"][0].parts
en_parts = new_fragment.named_ports[kind + "_EN"][0].parts
new_en = []
for clk, en in zip(clk_parts, en_parts):
if isinstance(clk, ClockSignal) and clk.domain in self.controls:
en = Mux(self.controls[clk.domain], en, Const(0, len(en)))
new_en.append(en)
new_fragment.named_ports[kind + "_EN"] = Cat(new_en), "i"
return new_fragment

View file

@ -35,6 +35,7 @@ Apply the following changes to code written against Amaranth 0.3 to migrate it t
While code that uses the features listed as deprecated below will work in Amaranth 0.4, they will be removed in the next version.
Implemented RFCs
----------------
@ -78,6 +79,7 @@ Language changes
* Added: :meth:`Const.cast`. (`RFC 4`_)
* Added: :meth:`Value.matches` and ``with m.Case():`` accept any constant-castable objects. (`RFC 4`_)
* Added: :meth:`Value.replicate`, superseding :class:`Repl`. (`RFC 10`_)
* Added: :class:`Memory` supports transparent read ports with read enable.
* Changed: creating a :class:`Signal` with a shape that is a :class:`ShapeCastable` implementing :meth:`ShapeCastable.__call__` wraps the returned object using that method. (`RFC 15`_)
* Changed: :meth:`Value.cast` casts :class:`ValueCastable` objects recursively.
* Changed: :meth:`Value.cast` treats instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) as enumerations rather than integers.

View file

@ -678,43 +678,6 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
)
""")
def setUp_memory(self):
self.m = Memory(width=8, depth=4)
self.fr = self.m.read_port().elaborate(platform=None)
self.fw = self.m.write_port().elaborate(platform=None)
self.f1 = Fragment()
self.f2 = Fragment()
self.f2.add_subfragment(self.fr)
self.f1.add_subfragment(self.f2)
self.f3 = Fragment()
self.f3.add_subfragment(self.fw)
self.f1.add_subfragment(self.f3)
def test_conflict_memory(self):
self.setUp_memory()
self.f1._resolve_hierarchy_conflicts(mode="silent")
self.assertEqual(self.f1.subfragments, [
(self.fr, None),
(self.fw, None),
])
def test_conflict_memory_error(self):
self.setUp_memory()
with self.assertRaisesRegex(DriverConflict,
r"^Memory 'm' is accessed from multiple fragments: top\.<unnamed #0>, "
r"top\.<unnamed #1>$"):
self.f1._resolve_hierarchy_conflicts(mode="error")
def test_conflict_memory_warning(self):
self.setUp_memory()
with self.assertWarnsRegex(DriverConflict,
(r"^Memory 'm' is accessed from multiple fragments: top.<unnamed #0>, "
r"top.<unnamed #1>; hierarchy will be flattened$")):
self.f1._resolve_hierarchy_conflicts(mode="warn")
def test_explicit_flatten(self):
self.f1 = Fragment()
self.f2 = Fragment()

View file

@ -58,8 +58,8 @@ class MemoryTestCase(FHDLTestCase):
self.assertEqual(len(rdport.addr), 2)
self.assertEqual(len(rdport.data), 8)
self.assertEqual(len(rdport.en), 1)
self.assertIsInstance(rdport.en, Const)
self.assertEqual(rdport.en.value, 1)
self.assertIsInstance(rdport.en, Signal)
self.assertEqual(rdport.en.reset, 1)
def test_read_port_non_transparent(self):
mem = Memory(width=8, depth=4)

View file

@ -547,16 +547,18 @@ class EnableInserterTestCase(FHDLTestCase):
def test_enable_read_port(self):
mem = Memory(width=8, depth=4)
f = EnableInserter(self.c1)(mem.read_port(transparent=False)).elaborate(platform=None)
self.assertRepr(f.named_ports["EN"][0], """
(m (sig c1) (sig mem_r_en) (const 1'd0))
mem.read_port(transparent=False)
f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.named_ports["RD_EN"][0], """
(cat (m (sig c1) (sig mem_r_en) (const 1'd0)))
""")
def test_enable_write_port(self):
mem = Memory(width=8, depth=4)
f = EnableInserter(self.c1)(mem.write_port()).elaborate(platform=None)
self.assertRepr(f.named_ports["EN"][0], """
(m
mem.write_port()
f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.named_ports["WR_EN"][0], """
(cat (m
(sig c1)
(cat
(cat
@ -571,7 +573,7 @@ class EnableInserterTestCase(FHDLTestCase):
)
)
(const 8'd0)
)
))
""")

View file

@ -697,7 +697,6 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
self.setUp_memory()
with self.assertSimulation(self.m) as sim:
def process():
self.assertEqual((yield self.rdport.data), 0xaa)
yield self.rdport.addr.eq(1)
yield
yield
@ -807,6 +806,7 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
self.m.submodules.rdport = self.rdport = self.memory.read_port()
with self.assertSimulation(self.m) as sim:
def process():
yield
self.assertEqual((yield self.rdport.data), 0xaa)
yield self.rdport.addr.eq(1)
yield
@ -815,6 +815,51 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
sim.add_clock(1e-6)
sim.add_sync_process(process)
def test_memory_transparency(self):
m = Module()
init = [0x11111111, 0x22222222, 0x33333333, 0x44444444]
m.submodules.memory = memory = Memory(width=32, depth=4, init=init)
rdport = memory.read_port()
wrport = memory.write_port(granularity=8)
with self.assertSimulation(m) as sim:
def process():
yield rdport.addr.eq(0)
yield
yield Settle()
self.assertEqual((yield rdport.data), 0x11111111)
yield rdport.addr.eq(1)
yield
yield Settle()
self.assertEqual((yield rdport.data), 0x22222222)
yield wrport.addr.eq(0)
yield wrport.data.eq(0x44444444)
yield wrport.en.eq(1)
yield
yield Settle()
self.assertEqual((yield rdport.data), 0x22222222)
yield wrport.addr.eq(1)
yield wrport.data.eq(0x55555555)
yield wrport.en.eq(1)
yield
yield Settle()
self.assertEqual((yield rdport.data), 0x22222255)
yield wrport.addr.eq(1)
yield wrport.data.eq(0x66666666)
yield wrport.en.eq(2)
yield rdport.en.eq(0)
yield
yield Settle()
self.assertEqual((yield rdport.data), 0x22222255)
yield wrport.addr.eq(1)
yield wrport.data.eq(0x77777777)
yield wrport.en.eq(4)
yield rdport.en.eq(1)
yield
yield Settle()
self.assertEqual((yield rdport.data), 0x22776655)
sim.add_clock(1e-6)
sim.add_sync_process(process)
@_ignore_deprecated
def test_sample_helpers(self):
m = Module()