diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index d92255f..837aba2 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -13,11 +13,90 @@ from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns from ._ir import * from ._cd import * from ._xfrm import * +from ._mem import MemoryData __all__ = ["SyntaxError", "SyntaxWarning", "Module"] +class _Visitor: + def __init__(self): + self.driven_signals = SignalSet() + + def visit_stmt(self, stmt): + if isinstance(stmt, _StatementList): + for s in stmt: + self.visit_stmt(s) + elif isinstance(stmt, Assign): + self.visit_lhs(stmt.lhs) + self.visit_rhs(stmt.rhs) + elif isinstance(stmt, Print): + for chunk in stmt.message._chunks: + if not isinstance(chunk, str): + obj, format_spec = chunk + self.visit_rhs(obj) + elif isinstance(stmt, Property): + self.visit_rhs(stmt.test) + if stmt.message is not None: + for chunk in stmt.message._chunks: + if not isinstance(chunk, str): + obj, format_spec = chunk + self.visit_rhs(obj) + elif isinstance(stmt, Switch): + self.visit_rhs(stmt.test) + for _patterns, stmts, _src_loc in stmt.cases: + self.visit_stmt(stmts) + elif isinstance(stmt, _LateBoundStatement): + pass + else: + assert False # :nocov: + + def visit_lhs(self, value): + if isinstance(value, Operator) and value.operator in ("u", "s"): + self.visit_lhs(value.operands[0]) + elif isinstance(value, (Signal, ClockSignal, ResetSignal)): + self.driven_signals.add(value) + elif isinstance(value, Slice): + self.visit_lhs(value.value) + elif isinstance(value, Part): + self.visit_lhs(value.value) + self.visit_rhs(value.offset) + elif isinstance(value, Concat): + for part in value.parts: + self.visit_lhs(part) + elif isinstance(value, SwitchValue): + self.visit_rhs(value.test) + for _patterns, elem in value.cases: + self.visit_lhs(elem) + elif isinstance(value, MemoryData._Row): + raise ValueError(f"Value {value!r} can only be used in simulator processes") + else: + raise ValueError(f"Value {value!r} cannot be assigned to") + + def visit_rhs(self, value): + if isinstance(value, (Const, Signal, ClockSignal, ResetSignal, Initial, AnyValue)): + pass + elif isinstance(value, Operator): + for op in value.operands: + self.visit_rhs(op) + elif isinstance(value, Slice): + self.visit_rhs(value.value) + elif isinstance(value, Part): + self.visit_rhs(value.value) + self.visit_rhs(value.offset) + elif isinstance(value, Concat): + for part in value.parts: + self.visit_rhs(part) + elif isinstance(value, SwitchValue): + self.visit_rhs(value.test) + for _patterns, elem in value.cases: + self.visit_rhs(elem) + elif isinstance(value, MemoryData._Row): + raise ValueError(f"Value {value!r} can only be used in simulator processes") + else: + assert False # :nocov: + + class _ModuleBuilderProxy: def __init__(self, builder, depth): object.__setattr__(self, "_builder", builder) @@ -545,15 +624,16 @@ class Module(_ModuleBuilderRoot, Elaboratable): stmt._MustUse__used = True - if isinstance(stmt, Assign): - for signal in stmt._lhs_signals(): - if signal not in self._driving: - self._driving[signal] = domain - elif self._driving[signal] != domain: - cd_curr = self._driving[signal] - raise SyntaxError( - f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is " - f"already driven from d.{cd_curr}") + visitor = _Visitor() + visitor.visit_stmt(stmt) + for signal in visitor.driven_signals: + if signal not in self._driving: + self._driving[signal] = domain + elif self._driving[signal] != domain: + cd_curr = self._driving[signal] + raise SyntaxError( + f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is " + f"already driven from d.{cd_curr}") self._statements.setdefault(domain, []).append(stmt) @@ -595,10 +675,14 @@ class Module(_ModuleBuilderRoot, Elaboratable): for domain, statements in self._statements.items(): statements = resolve_statements(statements) fragment.add_statements(domain, statements) - for signal in statements._lhs_signals(): + visitor = _Visitor() + visitor.visit_stmt(statements) + for signal in visitor.driven_signals: fragment.add_driver(signal, domain) fragment.add_statements("comb", self._top_comb_statements) - for signal in self._top_comb_statements._lhs_signals(): + visitor = _Visitor() + visitor.visit_stmt(self._top_comb_statements) + for signal in visitor.driven_signals: fragment.add_driver(signal, "comb") fragment.add_domains(self._domains.values()) fragment.generated.update(self._generated) diff --git a/amaranth/hdl/_mem.py b/amaranth/hdl/_mem.py index 7b32ca1..79104f2 100644 --- a/amaranth/hdl/_mem.py +++ b/amaranth/hdl/_mem.py @@ -102,6 +102,28 @@ class MemoryData: return f"MemoryData.Init({self._elems!r}, shape={self._shape!r}, depth={self._depth})" + @final + class _Row(Value): + def __init__(self, memory, index, *, src_loc_at=0): + assert isinstance(memory, MemoryData) + self._memory = memory + self._index = operator.index(index) + assert self._index in range(memory.depth) + super().__init__(src_loc_at=src_loc_at) + + def shape(self): + return Shape.cast(self._memory.shape) + + def _lhs_signals(self): + # This value cannot ever appear in a design. + raise NotImplementedError # :nocov: + + _rhs_signals = _lhs_signals + + def __repr__(self): + return f"(memory-row {self._memory!r} {self._index})" + + def __init__(self, *, shape, depth, init, src_loc_at=0): # shape and depth validation is performed in MemoryData.Init() self._shape = shape @@ -137,26 +159,14 @@ class MemoryData: return f"(memory-data {self.name})" def __getitem__(self, index): - """Simulation only.""" - return MemorySimRead(self, index) - - -class MemorySimRead: - def __init__(self, memory, addr): - assert isinstance(memory, MemoryData) - self._memory = memory - self._addr = Value.cast(addr) - - def eq(self, value): - return MemorySimWrite(self._memory, self._addr, value) - - -class MemorySimWrite: - def __init__(self, memory, addr, data): - assert isinstance(memory, MemoryData) - self._memory = memory - self._addr = Value.cast(addr) - self._data = Value.cast(data) + index = operator.index(index) + if index not in range(self.depth): + raise IndexError(f"Index {index} is out of bounds (memory has {self.depth} rows)") + row = MemoryData._Row(self, index) + if isinstance(self.shape, ShapeCastable): + return self.shape(row) + else: + return row class MemoryInstance(Fragment): @@ -312,8 +322,7 @@ class Memory(Elaboratable): return WritePort(self, src_loc_at=1 + src_loc_at, **kwargs) def __getitem__(self, index): - """Simulation only.""" - return MemorySimRead(self._data, index) + return self._data[index] def elaborate(self, platform): f = MemoryInstance(data=self._data, attrs=self.attrs, src_loc=self.src_loc) diff --git a/amaranth/lib/memory.py b/amaranth/lib/memory.py index 7ed9f15..e52aa56 100644 --- a/amaranth/lib/memory.py +++ b/amaranth/lib/memory.py @@ -3,7 +3,7 @@ from collections import OrderedDict from collections.abc import MutableSequence from ..hdl import MemoryData, MemoryInstance, Shape, ShapeCastable, Const -from ..hdl._mem import MemorySimRead, FrozenError +from ..hdl._mem import FrozenError from ..utils import ceil_log2 from .._utils import final from .. import tracer @@ -194,10 +194,6 @@ class Memory(wiring.Component): transparent_for=transparent_for) return instance - def __getitem__(self, index): - """Simulation only.""" - return self._data[index] - class ReadPort: """A read memory port. diff --git a/amaranth/sim/_base.py b/amaranth/sim/_base.py index 9b8f291..046f690 100644 --- a/amaranth/sim/_base.py +++ b/amaranth/sim/_base.py @@ -35,7 +35,7 @@ class BaseMemoryState: def read(self, addr): raise NotImplementedError # :nocov: - def write(self, addr, value): + def write(self, addr, value, mask=None): raise NotImplementedError # :nocov: diff --git a/amaranth/sim/_pycoro.py b/amaranth/sim/_pycoro.py index 0c705a9..8c929b5 100644 --- a/amaranth/sim/_pycoro.py +++ b/amaranth/sim/_pycoro.py @@ -2,7 +2,6 @@ import inspect from ..hdl import * from ..hdl._ast import Statement, Assign, SignalSet, ValueCastable -from ..hdl._mem import MemorySimRead, MemorySimWrite from .core import Tick, Settle, Delay, Passive, Active from ._base import BaseProcess, BaseMemoryState from ._pyeval import eval_value, eval_assign @@ -123,23 +122,6 @@ class PyCoroProcess(BaseProcess): elif type(command) is Active: self.passive = False - elif type(command) is MemorySimRead: - addr = eval_value(self.state, command._addr) - index = self.state.get_memory(command._memory) - state = self.state.slots[index] - assert isinstance(state, BaseMemoryState) - response = state.read(addr) - - elif type(command) is MemorySimWrite: - addr = eval_value(self.state, command._addr) - data = eval_value(self.state, command._data) - index = self.state.get_memory(command._memory) - state = self.state.slots[index] - assert isinstance(state, BaseMemoryState) - state.write(addr, data) - if self.testbench: - return True # assignment; run a delta cycle - elif command is None: # only possible if self.default_cmd is None raise TypeError("Received default command from process {!r} that was added " "with add_process(); did you mean to use Tick() instead?" diff --git a/amaranth/sim/_pyeval.py b/amaranth/sim/_pyeval.py index 510b0e9..32a4d0a 100644 --- a/amaranth/sim/_pyeval.py +++ b/amaranth/sim/_pyeval.py @@ -1,4 +1,5 @@ from amaranth.hdl._ast import * +from amaranth.hdl._mem import MemoryData def _eval_matches(test, patterns): @@ -118,6 +119,9 @@ def eval_value(sim, value): elif isinstance(value, Signal): slot = sim.get_signal(value) return sim.slots[slot].curr + elif isinstance(value, MemoryData._Row): + slot = sim.get_memory(value._memory) + return sim.slots[slot].read(value._index) elif isinstance(value, (ResetSignal, ClockSignal, AnyValue, Initial)): raise ValueError(f"Value {value!r} cannot be used in simulation") else: @@ -142,6 +146,15 @@ def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len): if lhs._signed and (value & (1 << (len(lhs) - 1))): value |= -1 << (len(lhs) - 1) sim.slots[slot].set(value) + elif isinstance(lhs, MemoryData._Row): + lhs_stop = lhs_start + rhs_len + if lhs_stop > len(lhs): + lhs_stop = len(lhs) + if lhs_start >= len(lhs): + return + slot = sim.get_memory(lhs._memory) + mask = (1 << lhs_stop) - (1 << lhs_start) + sim.slots[slot].write(lhs._index, rhs << lhs_start, mask) elif isinstance(lhs, Slice): _eval_assign_inner(sim, lhs.value, lhs_start + lhs.start, rhs, rhs_len) elif isinstance(lhs, Concat): diff --git a/amaranth/sim/core.py b/amaranth/sim/core.py index b54792e..bfd249c 100644 --- a/amaranth/sim/core.py +++ b/amaranth/sim/core.py @@ -5,6 +5,7 @@ from .._utils import deprecated from ..hdl._cd import * from ..hdl._ir import * from ..hdl._ast import Value, ValueLike +from ..hdl._mem import MemoryData from ._base import BaseEngine @@ -242,6 +243,8 @@ class Simulator: for trace in traces: if isinstance(trace, ValueLike): trace_cast = Value.cast(trace) + if isinstance(trace_cast, MemoryData._Row): + continue for trace_signal in trace_cast._rhs_signals(): if trace_signal.name == "": if trace_signal is trace: diff --git a/amaranth/sim/pysim.py b/amaranth/sim/pysim.py index 384c09d..298fe99 100644 --- a/amaranth/sim/pysim.py +++ b/amaranth/sim/pysim.py @@ -82,16 +82,28 @@ class _VCDWriter: for trace in traces: if isinstance(trace, ValueLike): trace = Value.cast(trace) - for trace_signal in trace._rhs_signals(): - if trace_signal not in signal_names: - if trace_signal.name not in assigned_names: - name = trace_signal.name + if isinstance(trace, MemoryData._Row): + memory = trace._memory + if not memory in memories: + if memory.name not in assigned_names: + name = memory.name else: - name = f"{trace_signal.name}${len(assigned_names)}" + name = f"{memory.name}${len(assigned_names)}" assert name not in assigned_names - trace_names[trace_signal] = {("bench", name)} + memories[memory] = ("bench", name) assigned_names.add(name) - self.traces.append(trace_signal) + self.traces.append(trace) + else: + for trace_signal in trace._rhs_signals(): + if trace_signal not in signal_names: + if trace_signal.name not in assigned_names: + name = trace_signal.name + else: + name = f"{trace_signal.name}${len(assigned_names)}" + assert name not in assigned_names + trace_names[trace_signal] = {("bench", name)} + assigned_names.add(name) + self.traces.append(trace_signal) elif isinstance(trace, MemoryData): if not trace in memories: if trace.name not in assigned_names: @@ -223,13 +235,16 @@ class _VCDWriter: self.gtkw_save.dumpfile_size(self.vcd_file.tell()) self.gtkw_save.treeopen("top") - for signal in self.traces: - if isinstance(signal, Signal): - for name in self.gtkw_signal_names[signal]: + for trace in self.traces: + if isinstance(trace, Signal): + for name in self.gtkw_signal_names[trace]: self.gtkw_save.trace(name) - elif isinstance(signal, MemoryIdentity): - for name in self.gtkw_memory_names[signal]: + elif isinstance(trace, MemoryData): + for name in self.gtkw_memory_names[trace]: self.gtkw_save.trace(name) + elif isinstance(trace, MemoryData._Row): + name = self.gtkw_memory_names[trace._memory][trace._index] + self.gtkw_save.trace(name) else: assert False # :nocov: diff --git a/docs/changes.rst b/docs/changes.rst index 54e9bd5..b289733 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -69,6 +69,7 @@ Implemented RFCs * `RFC 51`_: Add ``ShapeCastable.from_bits`` and ``amaranth.lib.data.Const`` * `RFC 53`_: Low-level I/O primitives * `RFC 59`_: Get rid of upwards propagation of clock domains +* `RFC 62`_: The `MemoryData`` class Language changes diff --git a/tests/test_hdl_mem.py b/tests/test_hdl_mem.py index 716cb9d..f165622 100644 --- a/tests/test_hdl_mem.py +++ b/tests/test_hdl_mem.py @@ -1,12 +1,38 @@ # amaranth: UnusedElaboratable=no -from amaranth.hdl._ast import * +from amaranth.hdl import * from amaranth.hdl._mem import * from amaranth._utils import _ignore_deprecated from .utils import * +class MemoryDataTestCase(FHDLTestCase): + def test_repr(self): + data = MemoryData(shape=8, depth=4, init=[]) + self.assertRepr(data, "(memory-data data)") + + def test_row(self): + data = MemoryData(shape=8, depth=4, init=[]) + self.assertRepr(data[2], "(memory-row (memory-data data) 2)") + + def test_row_wrong(self): + data = MemoryData(shape=8, depth=4, init=[]) + with self.assertRaisesRegex(IndexError, + r"^Index 4 is out of bounds \(memory has 4 rows\)$"): + data[4] + + def test_row_elab(self): + data = MemoryData(shape=8, depth=4, init=[]) + m = Module() + a = Signal(8) + with self.assertRaisesRegex(ValueError, + r"^Value \(memory-row \(memory-data data\) 0\) can only be used in simulator processes$"): + m.d.comb += a.eq(data[0]) + with self.assertRaisesRegex(ValueError, + r"^Value \(memory-row \(memory-data data\) 0\) can only be used in simulator processes$"): + m.d.comb += data[0].eq(1) + class MemoryTestCase(FHDLTestCase): def test_name(self): with _ignore_deprecated(): diff --git a/tests/test_sim.py b/tests/test_sim.py index 0f139b7..e6590a1 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -1043,17 +1043,19 @@ class SimulatorIntegrationTestCase(FHDLTestCase): self.setUp_memory() with self.assertSimulation(self.m) as sim: def process(): - self.assertEqual((yield self.memory[1]), 0x55) - self.assertEqual((yield self.memory[Const(1)]), 0x55) - self.assertEqual((yield self.memory[Const(2)]), 0x00) - yield self.memory[Const(1)].eq(Const(0x33)) - self.assertEqual((yield self.memory[Const(1)]), 0x33) + self.assertEqual((yield self.memory.data[1]), 0x55) + self.assertEqual((yield self.memory.data[1]), 0x55) + self.assertEqual((yield self.memory.data[2]), 0x00) + yield self.memory.data[1].eq(Const(0x33)) + self.assertEqual((yield self.memory.data[1]), 0x33) + yield self.memory.data[1][2:5].eq(Const(0x7)) + self.assertEqual((yield self.memory.data[1]), 0x3f) yield self.wrport.addr.eq(3) yield self.wrport.data.eq(0x22) yield self.wrport.en.eq(1) - self.assertEqual((yield self.memory[Const(3)]), 0) + self.assertEqual((yield self.memory.data[3]), 0) yield Tick() - self.assertEqual((yield self.memory[Const(3)]), 0x22) + self.assertEqual((yield self.memory.data[3]), 0x22) sim.add_clock(1e-6) sim.add_testbench(process) @@ -1062,13 +1064,13 @@ class SimulatorIntegrationTestCase(FHDLTestCase): self.setUp_memory() with self.assertSimulation(self.m) as sim: def process(): - self.assertEqual((yield self.memory[1]), 0x55) - self.assertEqual((yield self.memory[Const(1)]), 0x55) - self.assertEqual((yield self.memory[Const(2)]), 0x00) - yield self.memory[Const(1)].eq(Const(0x33)) - self.assertEqual((yield self.memory[Const(1)]), 0x55) + self.assertEqual((yield self.memory.data[1]), 0x55) + self.assertEqual((yield self.memory.data[1]), 0x55) + self.assertEqual((yield self.memory.data[2]), 0x00) + yield self.memory.data[1].eq(Const(0x33)) + self.assertEqual((yield self.memory.data[1]), 0x55) yield Tick() - self.assertEqual((yield self.memory[Const(1)]), 0x33) + self.assertEqual((yield self.memory.data[1]), 0x33) sim.add_clock(1e-6) sim.add_process(process)