hdl._mem: implement MemoryData._Row from RFC 62.

This commit is contained in:
Wanda 2024-04-03 17:12:56 +02:00 committed by Catherine
parent 93ef89626e
commit 767d69c703
11 changed files with 214 additions and 83 deletions

View file

@ -13,11 +13,90 @@ from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns
from ._ir import *
from ._cd import *
from ._xfrm import *
from ._mem import MemoryData
__all__ = ["SyntaxError", "SyntaxWarning", "Module"]
class _Visitor:
def __init__(self):
self.driven_signals = SignalSet()
def visit_stmt(self, stmt):
if isinstance(stmt, _StatementList):
for s in stmt:
self.visit_stmt(s)
elif isinstance(stmt, Assign):
self.visit_lhs(stmt.lhs)
self.visit_rhs(stmt.rhs)
elif isinstance(stmt, Print):
for chunk in stmt.message._chunks:
if not isinstance(chunk, str):
obj, format_spec = chunk
self.visit_rhs(obj)
elif isinstance(stmt, Property):
self.visit_rhs(stmt.test)
if stmt.message is not None:
for chunk in stmt.message._chunks:
if not isinstance(chunk, str):
obj, format_spec = chunk
self.visit_rhs(obj)
elif isinstance(stmt, Switch):
self.visit_rhs(stmt.test)
for _patterns, stmts, _src_loc in stmt.cases:
self.visit_stmt(stmts)
elif isinstance(stmt, _LateBoundStatement):
pass
else:
assert False # :nocov:
def visit_lhs(self, value):
if isinstance(value, Operator) and value.operator in ("u", "s"):
self.visit_lhs(value.operands[0])
elif isinstance(value, (Signal, ClockSignal, ResetSignal)):
self.driven_signals.add(value)
elif isinstance(value, Slice):
self.visit_lhs(value.value)
elif isinstance(value, Part):
self.visit_lhs(value.value)
self.visit_rhs(value.offset)
elif isinstance(value, Concat):
for part in value.parts:
self.visit_lhs(part)
elif isinstance(value, SwitchValue):
self.visit_rhs(value.test)
for _patterns, elem in value.cases:
self.visit_lhs(elem)
elif isinstance(value, MemoryData._Row):
raise ValueError(f"Value {value!r} can only be used in simulator processes")
else:
raise ValueError(f"Value {value!r} cannot be assigned to")
def visit_rhs(self, value):
if isinstance(value, (Const, Signal, ClockSignal, ResetSignal, Initial, AnyValue)):
pass
elif isinstance(value, Operator):
for op in value.operands:
self.visit_rhs(op)
elif isinstance(value, Slice):
self.visit_rhs(value.value)
elif isinstance(value, Part):
self.visit_rhs(value.value)
self.visit_rhs(value.offset)
elif isinstance(value, Concat):
for part in value.parts:
self.visit_rhs(part)
elif isinstance(value, SwitchValue):
self.visit_rhs(value.test)
for _patterns, elem in value.cases:
self.visit_rhs(elem)
elif isinstance(value, MemoryData._Row):
raise ValueError(f"Value {value!r} can only be used in simulator processes")
else:
assert False # :nocov:
class _ModuleBuilderProxy:
def __init__(self, builder, depth):
object.__setattr__(self, "_builder", builder)
@ -545,15 +624,16 @@ class Module(_ModuleBuilderRoot, Elaboratable):
stmt._MustUse__used = True
if isinstance(stmt, Assign):
for signal in stmt._lhs_signals():
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")
visitor = _Visitor()
visitor.visit_stmt(stmt)
for signal in visitor.driven_signals:
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")
self._statements.setdefault(domain, []).append(stmt)
@ -595,10 +675,14 @@ class Module(_ModuleBuilderRoot, Elaboratable):
for domain, statements in self._statements.items():
statements = resolve_statements(statements)
fragment.add_statements(domain, statements)
for signal in statements._lhs_signals():
visitor = _Visitor()
visitor.visit_stmt(statements)
for signal in visitor.driven_signals:
fragment.add_driver(signal, domain)
fragment.add_statements("comb", self._top_comb_statements)
for signal in self._top_comb_statements._lhs_signals():
visitor = _Visitor()
visitor.visit_stmt(self._top_comb_statements)
for signal in visitor.driven_signals:
fragment.add_driver(signal, "comb")
fragment.add_domains(self._domains.values())
fragment.generated.update(self._generated)

View file

@ -102,6 +102,28 @@ class MemoryData:
return f"MemoryData.Init({self._elems!r}, shape={self._shape!r}, depth={self._depth})"
@final
class _Row(Value):
def __init__(self, memory, index, *, src_loc_at=0):
assert isinstance(memory, MemoryData)
self._memory = memory
self._index = operator.index(index)
assert self._index in range(memory.depth)
super().__init__(src_loc_at=src_loc_at)
def shape(self):
return Shape.cast(self._memory.shape)
def _lhs_signals(self):
# This value cannot ever appear in a design.
raise NotImplementedError # :nocov:
_rhs_signals = _lhs_signals
def __repr__(self):
return f"(memory-row {self._memory!r} {self._index})"
def __init__(self, *, shape, depth, init, src_loc_at=0):
# shape and depth validation is performed in MemoryData.Init()
self._shape = shape
@ -137,26 +159,14 @@ class MemoryData:
return f"(memory-data {self.name})"
def __getitem__(self, index):
"""Simulation only."""
return MemorySimRead(self, index)
class MemorySimRead:
def __init__(self, memory, addr):
assert isinstance(memory, MemoryData)
self._memory = memory
self._addr = Value.cast(addr)
def eq(self, value):
return MemorySimWrite(self._memory, self._addr, value)
class MemorySimWrite:
def __init__(self, memory, addr, data):
assert isinstance(memory, MemoryData)
self._memory = memory
self._addr = Value.cast(addr)
self._data = Value.cast(data)
index = operator.index(index)
if index not in range(self.depth):
raise IndexError(f"Index {index} is out of bounds (memory has {self.depth} rows)")
row = MemoryData._Row(self, index)
if isinstance(self.shape, ShapeCastable):
return self.shape(row)
else:
return row
class MemoryInstance(Fragment):
@ -312,8 +322,7 @@ class Memory(Elaboratable):
return WritePort(self, src_loc_at=1 + src_loc_at, **kwargs)
def __getitem__(self, index):
"""Simulation only."""
return MemorySimRead(self._data, index)
return self._data[index]
def elaborate(self, platform):
f = MemoryInstance(data=self._data, attrs=self.attrs, src_loc=self.src_loc)

View file

@ -3,7 +3,7 @@ from collections import OrderedDict
from collections.abc import MutableSequence
from ..hdl import MemoryData, MemoryInstance, Shape, ShapeCastable, Const
from ..hdl._mem import MemorySimRead, FrozenError
from ..hdl._mem import FrozenError
from ..utils import ceil_log2
from .._utils import final
from .. import tracer
@ -194,10 +194,6 @@ class Memory(wiring.Component):
transparent_for=transparent_for)
return instance
def __getitem__(self, index):
"""Simulation only."""
return self._data[index]
class ReadPort:
"""A read memory port.

View file

@ -35,7 +35,7 @@ class BaseMemoryState:
def read(self, addr):
raise NotImplementedError # :nocov:
def write(self, addr, value):
def write(self, addr, value, mask=None):
raise NotImplementedError # :nocov:

View file

@ -2,7 +2,6 @@ import inspect
from ..hdl import *
from ..hdl._ast import Statement, Assign, SignalSet, ValueCastable
from ..hdl._mem import MemorySimRead, MemorySimWrite
from .core import Tick, Settle, Delay, Passive, Active
from ._base import BaseProcess, BaseMemoryState
from ._pyeval import eval_value, eval_assign
@ -123,23 +122,6 @@ class PyCoroProcess(BaseProcess):
elif type(command) is Active:
self.passive = False
elif type(command) is MemorySimRead:
addr = eval_value(self.state, command._addr)
index = self.state.get_memory(command._memory)
state = self.state.slots[index]
assert isinstance(state, BaseMemoryState)
response = state.read(addr)
elif type(command) is MemorySimWrite:
addr = eval_value(self.state, command._addr)
data = eval_value(self.state, command._data)
index = self.state.get_memory(command._memory)
state = self.state.slots[index]
assert isinstance(state, BaseMemoryState)
state.write(addr, data)
if self.testbench:
return True # assignment; run a delta cycle
elif command is None: # only possible if self.default_cmd is None
raise TypeError("Received default command from process {!r} that was added "
"with add_process(); did you mean to use Tick() instead?"

View file

@ -1,4 +1,5 @@
from amaranth.hdl._ast import *
from amaranth.hdl._mem import MemoryData
def _eval_matches(test, patterns):
@ -118,6 +119,9 @@ def eval_value(sim, value):
elif isinstance(value, Signal):
slot = sim.get_signal(value)
return sim.slots[slot].curr
elif isinstance(value, MemoryData._Row):
slot = sim.get_memory(value._memory)
return sim.slots[slot].read(value._index)
elif isinstance(value, (ResetSignal, ClockSignal, AnyValue, Initial)):
raise ValueError(f"Value {value!r} cannot be used in simulation")
else:
@ -142,6 +146,15 @@ def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len):
if lhs._signed and (value & (1 << (len(lhs) - 1))):
value |= -1 << (len(lhs) - 1)
sim.slots[slot].set(value)
elif isinstance(lhs, MemoryData._Row):
lhs_stop = lhs_start + rhs_len
if lhs_stop > len(lhs):
lhs_stop = len(lhs)
if lhs_start >= len(lhs):
return
slot = sim.get_memory(lhs._memory)
mask = (1 << lhs_stop) - (1 << lhs_start)
sim.slots[slot].write(lhs._index, rhs << lhs_start, mask)
elif isinstance(lhs, Slice):
_eval_assign_inner(sim, lhs.value, lhs_start + lhs.start, rhs, rhs_len)
elif isinstance(lhs, Concat):

View file

@ -5,6 +5,7 @@ from .._utils import deprecated
from ..hdl._cd import *
from ..hdl._ir import *
from ..hdl._ast import Value, ValueLike
from ..hdl._mem import MemoryData
from ._base import BaseEngine
@ -242,6 +243,8 @@ class Simulator:
for trace in traces:
if isinstance(trace, ValueLike):
trace_cast = Value.cast(trace)
if isinstance(trace_cast, MemoryData._Row):
continue
for trace_signal in trace_cast._rhs_signals():
if trace_signal.name == "":
if trace_signal is trace:

View file

@ -82,16 +82,28 @@ class _VCDWriter:
for trace in traces:
if isinstance(trace, ValueLike):
trace = Value.cast(trace)
for trace_signal in trace._rhs_signals():
if trace_signal not in signal_names:
if trace_signal.name not in assigned_names:
name = trace_signal.name
if isinstance(trace, MemoryData._Row):
memory = trace._memory
if not memory in memories:
if memory.name not in assigned_names:
name = memory.name
else:
name = f"{trace_signal.name}${len(assigned_names)}"
name = f"{memory.name}${len(assigned_names)}"
assert name not in assigned_names
trace_names[trace_signal] = {("bench", name)}
memories[memory] = ("bench", name)
assigned_names.add(name)
self.traces.append(trace_signal)
self.traces.append(trace)
else:
for trace_signal in trace._rhs_signals():
if trace_signal not in signal_names:
if trace_signal.name not in assigned_names:
name = trace_signal.name
else:
name = f"{trace_signal.name}${len(assigned_names)}"
assert name not in assigned_names
trace_names[trace_signal] = {("bench", name)}
assigned_names.add(name)
self.traces.append(trace_signal)
elif isinstance(trace, MemoryData):
if not trace in memories:
if trace.name not in assigned_names:
@ -223,13 +235,16 @@ class _VCDWriter:
self.gtkw_save.dumpfile_size(self.vcd_file.tell())
self.gtkw_save.treeopen("top")
for signal in self.traces:
if isinstance(signal, Signal):
for name in self.gtkw_signal_names[signal]:
for trace in self.traces:
if isinstance(trace, Signal):
for name in self.gtkw_signal_names[trace]:
self.gtkw_save.trace(name)
elif isinstance(signal, MemoryIdentity):
for name in self.gtkw_memory_names[signal]:
elif isinstance(trace, MemoryData):
for name in self.gtkw_memory_names[trace]:
self.gtkw_save.trace(name)
elif isinstance(trace, MemoryData._Row):
name = self.gtkw_memory_names[trace._memory][trace._index]
self.gtkw_save.trace(name)
else:
assert False # :nocov:

View file

@ -69,6 +69,7 @@ Implemented RFCs
* `RFC 51`_: Add ``ShapeCastable.from_bits`` and ``amaranth.lib.data.Const``
* `RFC 53`_: Low-level I/O primitives
* `RFC 59`_: Get rid of upwards propagation of clock domains
* `RFC 62`_: The `MemoryData`` class
Language changes

View file

@ -1,12 +1,38 @@
# amaranth: UnusedElaboratable=no
from amaranth.hdl._ast import *
from amaranth.hdl import *
from amaranth.hdl._mem import *
from amaranth._utils import _ignore_deprecated
from .utils import *
class MemoryDataTestCase(FHDLTestCase):
def test_repr(self):
data = MemoryData(shape=8, depth=4, init=[])
self.assertRepr(data, "(memory-data data)")
def test_row(self):
data = MemoryData(shape=8, depth=4, init=[])
self.assertRepr(data[2], "(memory-row (memory-data data) 2)")
def test_row_wrong(self):
data = MemoryData(shape=8, depth=4, init=[])
with self.assertRaisesRegex(IndexError,
r"^Index 4 is out of bounds \(memory has 4 rows\)$"):
data[4]
def test_row_elab(self):
data = MemoryData(shape=8, depth=4, init=[])
m = Module()
a = Signal(8)
with self.assertRaisesRegex(ValueError,
r"^Value \(memory-row \(memory-data data\) 0\) can only be used in simulator processes$"):
m.d.comb += a.eq(data[0])
with self.assertRaisesRegex(ValueError,
r"^Value \(memory-row \(memory-data data\) 0\) can only be used in simulator processes$"):
m.d.comb += data[0].eq(1)
class MemoryTestCase(FHDLTestCase):
def test_name(self):
with _ignore_deprecated():

View file

@ -1043,17 +1043,19 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
self.setUp_memory()
with self.assertSimulation(self.m) as sim:
def process():
self.assertEqual((yield self.memory[1]), 0x55)
self.assertEqual((yield self.memory[Const(1)]), 0x55)
self.assertEqual((yield self.memory[Const(2)]), 0x00)
yield self.memory[Const(1)].eq(Const(0x33))
self.assertEqual((yield self.memory[Const(1)]), 0x33)
self.assertEqual((yield self.memory.data[1]), 0x55)
self.assertEqual((yield self.memory.data[1]), 0x55)
self.assertEqual((yield self.memory.data[2]), 0x00)
yield self.memory.data[1].eq(Const(0x33))
self.assertEqual((yield self.memory.data[1]), 0x33)
yield self.memory.data[1][2:5].eq(Const(0x7))
self.assertEqual((yield self.memory.data[1]), 0x3f)
yield self.wrport.addr.eq(3)
yield self.wrport.data.eq(0x22)
yield self.wrport.en.eq(1)
self.assertEqual((yield self.memory[Const(3)]), 0)
self.assertEqual((yield self.memory.data[3]), 0)
yield Tick()
self.assertEqual((yield self.memory[Const(3)]), 0x22)
self.assertEqual((yield self.memory.data[3]), 0x22)
sim.add_clock(1e-6)
sim.add_testbench(process)
@ -1062,13 +1064,13 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
self.setUp_memory()
with self.assertSimulation(self.m) as sim:
def process():
self.assertEqual((yield self.memory[1]), 0x55)
self.assertEqual((yield self.memory[Const(1)]), 0x55)
self.assertEqual((yield self.memory[Const(2)]), 0x00)
yield self.memory[Const(1)].eq(Const(0x33))
self.assertEqual((yield self.memory[Const(1)]), 0x55)
self.assertEqual((yield self.memory.data[1]), 0x55)
self.assertEqual((yield self.memory.data[1]), 0x55)
self.assertEqual((yield self.memory.data[2]), 0x00)
yield self.memory.data[1].eq(Const(0x33))
self.assertEqual((yield self.memory.data[1]), 0x55)
yield Tick()
self.assertEqual((yield self.memory[Const(1)]), 0x33)
self.assertEqual((yield self.memory.data[1]), 0x33)
sim.add_clock(1e-6)
sim.add_process(process)