hdl.mem: Switch to first-class IR representation for memories.

Fixes #611.
This commit is contained in:
Wanda 2024-01-16 21:19:03 +01:00 committed by Catherine
parent 2fecd1c78b
commit ae36b596bb
4 changed files with 170 additions and 92 deletions

View file

@ -822,10 +822,79 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
for port_name, (value, dir) in fragment.named_ports.items(): for port_name, (value, dir) in fragment.named_ports.items():
port_map[f"\\{port_name}"] = value port_map[f"\\{port_name}"] = value
params = OrderedDict(fragment.parameters)
if fragment.type[0] == "$": if fragment.type[0] == "$":
return fragment.type, port_map return fragment.type, port_map, params
else: else:
return f"\\{fragment.type}", port_map return f"\\{fragment.type}", port_map, params
if isinstance(fragment, mem.MemoryInstance):
memory = fragment.memory
init = "".join(format(ast.Const(elem, ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init))
init = ast.Const(int(init or "0", 2), memory.depth * memory.width)
rd_clk = []
rd_clk_enable = 0
rd_clk_polarity = 0
rd_transparency_mask = 0
for index, port in enumerate(fragment.read_ports):
if port.domain != "comb":
cd = fragment.domains[port.domain]
rd_clk.append(cd.clk)
if cd.clk_edge == "pos":
rd_clk_polarity |= 1 << index
rd_clk_enable |= 1 << index
if port.transparent:
for write_index, write_port in enumerate(fragment.write_ports):
if port.domain == write_port.domain:
rd_transparency_mask |= 1 << (index * len(fragment.write_ports) + write_index)
else:
rd_clk.append(ast.Const(0, 1))
wr_clk = []
wr_clk_enable = 0
wr_clk_polarity = 0
for index, port in enumerate(fragment.write_ports):
cd = fragment.domains[port.domain]
wr_clk.append(cd.clk)
wr_clk_enable |= 1 << index
if cd.clk_edge == "pos":
wr_clk_polarity |= 1 << index
params = {
"MEMID": builder._make_name(hierarchy[-1], local=False),
"SIZE": memory.depth,
"OFFSET": 0,
"ABITS": ast.Shape.cast(range(memory.depth)).width,
"WIDTH": memory.width,
"INIT": init,
"RD_PORTS": len(fragment.read_ports),
"RD_CLK_ENABLE": ast.Const(rd_clk_enable, max(1, len(fragment.read_ports))),
"RD_CLK_POLARITY": ast.Const(rd_clk_polarity, max(1, len(fragment.read_ports))),
"RD_TRANSPARENCY_MASK": ast.Const(rd_transparency_mask, max(1, len(fragment.read_ports) * len(fragment.write_ports))),
"RD_COLLISION_X_MASK": ast.Const(0, max(1, len(fragment.read_ports) * len(fragment.write_ports))),
"RD_WIDE_CONTINUATION": ast.Const(0, max(1, len(fragment.read_ports))),
"RD_CE_OVER_SRST": ast.Const(0, max(1, len(fragment.read_ports))),
"RD_ARST_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width),
"RD_SRST_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width),
"RD_INIT_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width),
"WR_PORTS": len(fragment.write_ports),
"WR_CLK_ENABLE": ast.Const(wr_clk_enable, max(1, len(fragment.write_ports))),
"WR_CLK_POLARITY": ast.Const(wr_clk_polarity, max(1, len(fragment.write_ports))),
"WR_PRIORITY_MASK": ast.Const(0, max(1, len(fragment.write_ports) * len(fragment.write_ports))),
"WR_WIDE_CONTINUATION": ast.Const(0, max(1, len(fragment.write_ports))),
}
port_map = {
"\\RD_CLK": ast.Cat(rd_clk),
"\\RD_EN": ast.Cat(port.en for port in fragment.read_ports),
"\\RD_ARST": ast.Const(0, len(fragment.read_ports)),
"\\RD_SRST": ast.Const(0, len(fragment.read_ports)),
"\\RD_ADDR": ast.Cat(port.addr for port in fragment.read_ports),
"\\RD_DATA": ast.Cat(port.data for port in fragment.read_ports),
"\\WR_CLK": ast.Cat(wr_clk),
"\\WR_EN": ast.Cat(ast.Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in fragment.write_ports),
"\\WR_ADDR": ast.Cat(port.addr for port in fragment.write_ports),
"\\WR_DATA": ast.Cat(port.data for port in fragment.write_ports),
}
return "$mem_v2", port_map, params
module_name = ".".join(name or "anonymous" for name in hierarchy) module_name = ".".join(name or "anonymous" for name in hierarchy)
module_attrs = OrderedDict() module_attrs = OrderedDict()
@ -860,9 +929,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# Transform all subfragments to their respective cells. Transforming signals connected # Transform all subfragments to their respective cells. Transforming signals connected
# to their ports into wires eagerly makes sure they get sensible (prefixed with submodule # to their ports into wires eagerly makes sure they get sensible (prefixed with submodule
# name) names. # name) names.
memories = OrderedDict()
for subfragment, sub_name in fragment.subfragments: for subfragment, sub_name in fragment.subfragments:
if not (subfragment.ports or subfragment.statements or subfragment.subfragments): if not (subfragment.ports or subfragment.statements or subfragment.subfragments or
isinstance(subfragment, (ir.Instance, mem.MemoryInstance))):
# If the fragment is completely empty, skip translating it, otherwise synthesis # If the fragment is completely empty, skip translating it, otherwise synthesis
# tools (including Yosys and Vivado) will treat it as a black box when it is # tools (including Yosys and Vivado) will treat it as a black box when it is
# loaded after conversion to Verilog. # loaded after conversion to Verilog.
@ -871,25 +940,22 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
if sub_name is None: if sub_name is None:
sub_name = module.anonymous() sub_name = module.anonymous()
sub_params = OrderedDict(getattr(subfragment, "parameters", {})) sub_type, sub_port_map, sub_params = \
sub_type, sub_port_map = \
_convert_fragment(builder, subfragment, name_map, _convert_fragment(builder, subfragment, name_map,
hierarchy=hierarchy + (sub_name,)) hierarchy=hierarchy + (sub_name,))
if sub_type == "$mem_v2" and "MEMID" not in sub_params:
sub_params["MEMID"] = builder._make_name(sub_name, local=False)
sub_ports = OrderedDict() sub_ports = OrderedDict()
for port, value in sub_port_map.items(): for port, value in sub_port_map.items():
if not isinstance(subfragment, ir.Instance): if not isinstance(subfragment, (ir.Instance, mem.MemoryInstance)):
for signal in value._rhs_signals(): for signal in value._rhs_signals():
compiler_state.resolve_curr(signal, prefix=sub_name) compiler_state.resolve_curr(signal, prefix=sub_name)
if len(value) > 0 or sub_type == "$mem_v2": if len(value) > 0:
sub_ports[port] = rhs_compiler(value) sub_ports[port] = rhs_compiler(value)
if isinstance(subfragment, ir.Instance): if isinstance(subfragment, ir.Instance):
src = _src(subfragment.src_loc) src = _src(subfragment.src_loc)
elif isinstance(subfragment, mem.MemoryInstance):
src = _src(subfragment.memory.src_loc)
else: else:
src = "" src = ""
@ -1005,7 +1071,7 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
wire_name = wire_name[1:] wire_name = wire_name[1:]
name_map[signal] = hierarchy + (wire_name,) name_map[signal] = hierarchy + (wire_name,)
return module.name, port_map return module.name, port_map, {}
def convert_fragment(fragment, name="top", *, emit_src=True): def convert_fragment(fragment, name="top", *, emit_src=True):

View file

@ -119,55 +119,7 @@ class Memory(Elaboratable):
return self._array[index] return self._array[index]
def elaborate(self, platform): def elaborate(self, platform):
init = "".join(format(Const(elem, unsigned(self.width)).value, f"0{self.width}b") for elem in reversed(self.init)) f = MemoryInstance(self, self._read_ports, self._write_ports)
init = Const(int(init or "0", 2), self.depth * self.width)
rd_clk = []
rd_clk_enable = 0
rd_transparency_mask = 0
for index, port in enumerate(self._read_ports):
if port.domain != "comb":
rd_clk.append(ClockSignal(port.domain))
rd_clk_enable |= 1 << index
if port.transparent:
for write_index, write_port in enumerate(self._write_ports):
if port.domain == write_port.domain:
rd_transparency_mask |= 1 << (index * len(self._write_ports) + write_index)
else:
rd_clk.append(Const(0, 1))
f = Instance("$mem_v2",
*(("a", attr, value) for attr, value in self.attrs.items()),
p_SIZE=self.depth,
p_OFFSET=0,
p_ABITS=Shape.cast(range(self.depth)).width,
p_WIDTH=self.width,
p_INIT=init,
p_RD_PORTS=len(self._read_ports),
p_RD_CLK_ENABLE=Const(rd_clk_enable, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_CLK_POLARITY=Const(-1, unsigned(len(self._read_ports))) if self._read_ports else Const(0, 1),
p_RD_TRANSPARENCY_MASK=Const(rd_transparency_mask, max(1, len(self._read_ports) * len(self._write_ports))),
p_RD_COLLISION_X_MASK=Const(0, max(1, len(self._read_ports) * len(self._write_ports))),
p_RD_WIDE_CONTINUATION=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_CE_OVER_SRST=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_ARST_VALUE=Const(0, len(self._read_ports) * self.width),
p_RD_SRST_VALUE=Const(0, len(self._read_ports) * self.width),
p_RD_INIT_VALUE=Const(0, len(self._read_ports) * self.width),
p_WR_PORTS=len(self._write_ports),
p_WR_CLK_ENABLE=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
p_WR_CLK_POLARITY=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
p_WR_PRIORITY_MASK=Const(0, len(self._write_ports) * len(self._write_ports)) if self._write_ports else Const(0, 1),
p_WR_WIDE_CONTINUATION=Const(0, len(self._write_ports)) if self._write_ports else Const(0, 1),
i_RD_CLK=Cat(rd_clk),
i_RD_EN=Cat(port.en for port in self._read_ports),
i_RD_ARST=Const(0, len(self._read_ports)),
i_RD_SRST=Const(0, len(self._read_ports)),
i_RD_ADDR=Cat(port.addr for port in self._read_ports),
o_RD_DATA=Cat(port.data for port in self._read_ports),
i_WR_CLK=Cat(ClockSignal(port.domain) for port in self._write_ports),
i_WR_EN=Cat(Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in self._write_ports),
i_WR_ADDR=Cat(port.addr for port in self._write_ports),
i_WR_DATA=Cat(port.data for port in self._write_ports),
src_loc=self.src_loc,
)
for port in self._read_ports: for port in self._read_ports:
port._MustUse__used = True port._MustUse__used = True
if port.domain == "comb": if port.domain == "comb":
@ -211,6 +163,7 @@ class Memory(Elaboratable):
f.add_driver(signal, port.domain) f.add_driver(signal, port.domain)
return f return f
class ReadPort(Elaboratable): class ReadPort(Elaboratable):
"""A memory read port. """A memory read port.
@ -354,3 +307,12 @@ class DummyPort:
name=f"{name}_data", src_loc_at=1) name=f"{name}_data", src_loc_at=1)
self.en = Signal(data_width // granularity, self.en = Signal(data_width // granularity,
name=f"{name}_en", src_loc_at=1) name=f"{name}_en", src_loc_at=1)
class MemoryInstance(Fragment):
def __init__(self, memory, read_ports, write_ports):
super().__init__()
self.memory = memory
self.read_ports = read_ports
self.write_ports = write_ports
self.attrs = memory.attrs

View file

@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable from collections.abc import Iterable
from copy import copy
from .._utils import flatten, _ignore_deprecated from .._utils import flatten, _ignore_deprecated
from .. import tracer from .. import tracer
@ -8,6 +9,7 @@ from .ast import *
from .ast import _StatementList from .ast import _StatementList
from .cd import * from .cd import *
from .ir import * from .ir import *
from .mem import MemoryInstance
__all__ = ["ValueVisitor", "ValueTransformer", __all__ = ["ValueVisitor", "ValueTransformer",
@ -261,8 +263,30 @@ class FragmentTransformer:
for domain, signal in fragment.iter_drivers(): for domain, signal in fragment.iter_drivers():
new_fragment.add_driver(signal, domain) new_fragment.add_driver(signal, domain)
def map_memory_ports(self, fragment, new_fragment):
new_fragment.read_ports = [
copy(port)
for port in fragment.read_ports
]
new_fragment.write_ports = [
copy(port)
for port in fragment.write_ports
]
if hasattr(self, "on_value"):
for port in new_fragment.read_ports:
port.en = self.on_value(port.en)
port.addr = self.on_value(port.addr)
port.data = self.on_value(port.data)
for port in new_fragment.write_ports:
port.en = self.on_value(port.en)
port.addr = self.on_value(port.addr)
port.data = self.on_value(port.data)
def on_fragment(self, fragment): def on_fragment(self, fragment):
if isinstance(fragment, Instance): if isinstance(fragment, MemoryInstance):
new_fragment = MemoryInstance(fragment.memory, [], [])
self.map_memory_ports(fragment, new_fragment)
elif isinstance(fragment, Instance):
new_fragment = Instance(fragment.type, src_loc=fragment.src_loc) new_fragment = Instance(fragment.type, src_loc=fragment.src_loc)
new_fragment.parameters = OrderedDict(fragment.parameters) new_fragment.parameters = OrderedDict(fragment.parameters)
self.map_named_ports(fragment, new_fragment) self.map_named_ports(fragment, new_fragment)
@ -381,6 +405,19 @@ class DomainCollector(ValueVisitor, StatementVisitor):
self.on_statement(stmt) self.on_statement(stmt)
def on_fragment(self, fragment): def on_fragment(self, fragment):
if isinstance(fragment, MemoryInstance):
for port in fragment.read_ports:
self.on_value(port.addr)
self.on_value(port.data)
self.on_value(port.en)
if port.domain != "comb":
self._add_used_domain(port.domain)
for port in fragment.write_ports:
self.on_value(port.addr)
self.on_value(port.data)
self.on_value(port.en)
self._add_used_domain(port.domain)
if isinstance(fragment, Instance): if isinstance(fragment, Instance):
for name, (value, dir) in fragment.named_ports.items(): for name, (value, dir) in fragment.named_ports.items():
self.on_value(value) self.on_value(value)
@ -444,6 +481,15 @@ class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer)
for signal in signals: for signal in signals:
new_fragment.add_driver(self.on_value(signal), domain) new_fragment.add_driver(self.on_value(signal), domain)
def map_memory_ports(self, fragment, new_fragment):
super().map_memory_ports(fragment, new_fragment)
for port in new_fragment.read_ports:
if port.domain in self.domain_map:
port.domain = self.domain_map[port.domain]
for port in new_fragment.write_ports:
if port.domain in self.domain_map:
port.domain = self.domain_map[port.domain]
class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer): class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
def __init__(self, domains=None): def __init__(self, domains=None):
@ -630,14 +676,11 @@ class EnableInserter(_ControlInserter):
def on_fragment(self, fragment): def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment) new_fragment = super().on_fragment(fragment)
if isinstance(new_fragment, Instance) and new_fragment.type == "$mem_v2": if isinstance(new_fragment, MemoryInstance):
for kind in ["RD", "WR"]: for port in new_fragment.read_ports:
clk_parts = new_fragment.named_ports[kind + "_CLK"][0].parts if port.domain in self.controls:
en_parts = new_fragment.named_ports[kind + "_EN"][0].parts port.en = port.en & self.controls[port.domain]
new_en = [] for port in new_fragment.write_ports:
for clk, en in zip(clk_parts, en_parts): if port.domain in self.controls:
if isinstance(clk, ClockSignal) and clk.domain in self.controls: port.en = Mux(self.controls[port.domain], port.en, Const(0, len(port.en)))
en = Mux(self.controls[clk.domain], en, Const(0, len(en)))
new_en.append(en)
new_fragment.named_ports[kind + "_EN"] = Cat(new_en), "i"
return new_fragment return new_fragment

View file

@ -4,9 +4,11 @@ import warnings
from amaranth.hdl.ast import * from amaranth.hdl.ast import *
from amaranth.hdl.cd import * from amaranth.hdl.cd import *
from amaranth.hdl.dsl import *
from amaranth.hdl.ir import * from amaranth.hdl.ir import *
from amaranth.hdl.xfrm import * from amaranth.hdl.xfrm import *
from amaranth.hdl.mem import * from amaranth.hdl.mem import *
from amaranth.hdl.mem import MemoryInstance
from .utils import * from .utils import *
from amaranth._utils import _ignore_deprecated from amaranth._utils import _ignore_deprecated
@ -113,6 +115,22 @@ class DomainRenamerTestCase(FHDLTestCase):
"pix": cd_pix, "pix": cd_pix,
}) })
def test_rename_mem_ports(self):
m = Module()
mem = Memory(depth=4, width=16)
m.submodules.mem = mem
mem.read_port(domain="a")
mem.read_port(domain="b")
mem.write_port(domain="c")
f = Fragment.get(m, None)
f = DomainRenamer({"a": "d", "c": "e"})(f)
mem = f.subfragments[0][0]
self.assertIsInstance(mem, MemoryInstance)
self.assertEqual(mem.read_ports[0].domain, "d")
self.assertEqual(mem.read_ports[1].domain, "b")
self.assertEqual(mem.write_ports[0].domain, "e")
def test_rename_wrong_to_comb(self): def test_rename_wrong_to_comb(self):
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
r"^Domain 'sync' may not be renamed to 'comb'$"): r"^Domain 'sync' may not be renamed to 'comb'$"):
@ -501,31 +519,20 @@ class EnableInserterTestCase(FHDLTestCase):
mem = Memory(width=8, depth=4) mem = Memory(width=8, depth=4)
mem.read_port(transparent=False) mem.read_port(transparent=False)
f = EnableInserter(self.c1)(mem).elaborate(platform=None) f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.named_ports["RD_EN"][0], """ self.assertRepr(f.read_ports[0].en, """
(cat (m (sig c1) (sig mem_r_en) (const 1'd0))) (& (sig mem_r_en) (sig c1))
""") """)
def test_enable_write_port(self): def test_enable_write_port(self):
mem = Memory(width=8, depth=4) mem = Memory(width=8, depth=4)
mem.write_port() mem.write_port(granularity=2)
f = EnableInserter(self.c1)(mem).elaborate(platform=None) f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.named_ports["WR_EN"][0], """ self.assertRepr(f.write_ports[0].en, """
(cat (m (m
(sig c1) (sig c1)
(cat (sig mem_w_en)
(cat (const 4'd0)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
) )
)
(const 8'd0)
))
""") """)