From 51e02627108292be67a724e4fd42715ec00607d7 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Wed, 22 May 2024 10:40:34 +0200 Subject: [PATCH] sim: group signal traces according to their function. --- amaranth/sim/core.py | 20 ++++++++---- amaranth/sim/pysim.py | 72 ++++++++++++++++++++++++++++++------------- tests/test_sim.py | 45 ++++++++++++++++++++++++++- 3 files changed, 109 insertions(+), 28 deletions(-) diff --git a/amaranth/sim/core.py b/amaranth/sim/core.py index 4ea4993..d3f5581 100644 --- a/amaranth/sim/core.py +++ b/amaranth/sim/core.py @@ -212,17 +212,25 @@ class Simulator: file.close() raise ValueError("Cannot start writing waveforms after advancing simulation time") - for trace in traces: - if isinstance(trace, ValueLike): - trace_cast = Value.cast(trace) + def traverse_traces(traces): + if isinstance(traces, ValueLike): + trace_cast = Value.cast(traces) if isinstance(trace_cast, MemoryData._Row): - continue + return for trace_signal in trace_cast._rhs_signals(): if trace_signal.name == "": - if trace_signal is trace: + if trace_signal is traces: raise TypeError("Cannot trace signal with private name") else: - raise TypeError(f"Cannot trace signal with private name (within {trace!r})") + raise TypeError(f"Cannot trace signal with private name (within {traces!r})") + elif isinstance(traces, (list, tuple)): + for trace in traces: + traverse_traces(trace) + elif isinstance(traces, dict): + for trace in traces.values(): + traverse_traces(trace) + + traverse_traces(traces) return self._engine.write_vcd(vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces, fs_per_delta=fs_per_delta) diff --git a/amaranth/sim/pysim.py b/amaranth/sim/pysim.py index da49aa0..1a9c876 100644 --- a/amaranth/sim/pysim.py +++ b/amaranth/sim/pysim.py @@ -5,7 +5,9 @@ import os.path import enum as py_enum from ..hdl import * +from ..hdl._mem import MemoryInstance from ..hdl._ast import SignalDict +from ..lib import data, wiring from ._base import * from ._async import * from ._pyeval import eval_format, eval_value, eval_assign @@ -49,7 +51,7 @@ class _VCDWriter: self.gtkw_file = gtkw_file self.gtkw_save = gtkw_file and vcd.gtkw.GTKWSave(self.gtkw_file) - self.traces = [] + self.traces = traces signal_names = SignalDict() memories = {} @@ -64,9 +66,9 @@ class _VCDWriter: trace_names = SignalDict() assigned_names = set() - for trace in traces: - if isinstance(trace, ValueLike): - trace = Value.cast(trace) + def traverse_traces(traces): + if isinstance(traces, ValueLike): + trace = Value.cast(traces) if isinstance(trace, MemoryData._Row): memory = trace._memory if not memory in memories: @@ -77,7 +79,6 @@ class _VCDWriter: assert name not in assigned_names memories[memory] = ("bench", name) assigned_names.add(name) - self.traces.append(trace) else: for trace_signal in trace._rhs_signals(): if trace_signal not in signal_names: @@ -88,19 +89,27 @@ class _VCDWriter: assert name not in assigned_names trace_names[trace_signal] = {("bench", name)} assigned_names.add(name) - self.traces.append(trace_signal) - elif isinstance(trace, MemoryData): - if not trace in memories: - if trace.name not in assigned_names: - name = trace.name + elif isinstance(traces, MemoryData): + if not traces in memories: + if traces.name not in assigned_names: + name = traces.name else: - name = f"{trace.name}${len(assigned_names)}" + name = f"{traces.name}${len(assigned_names)}" assert name not in assigned_names - memories[trace] = ("bench", name) + memories[traces] = ("bench", name) assigned_names.add(name) - self.traces.append(trace) + elif hasattr(traces, "signature") and isinstance(traces.signature, wiring.Signature): + for name in traces.signature.members: + traverse_traces(getattr(traces, name)) + elif isinstance(traces, list) or isinstance(traces, tuple): + for trace in traces: + traverse_traces(trace) + elif isinstance(traces, dict): + for trace in traces.values(): + traverse_traces(trace) else: - raise TypeError(f"{trace!r} is not a traceable object") + raise TypeError(f"{traces!r} is not a traceable object") + traverse_traces(traces) if self.vcd_writer is None: return @@ -277,19 +286,40 @@ class _VCDWriter: self.gtkw_save.dumpfile_size(self.vcd_file.tell()) self.gtkw_save.treeopen("top") - for trace in self.traces: - if isinstance(trace, Signal): - for name in self.gtkw_signal_names[trace]: + + def traverse_traces(traces): + if isinstance(traces, Signal): + for name in self.gtkw_signal_names[traces]: self.gtkw_save.trace(name) - elif isinstance(trace, MemoryData): - for row_names in self.gtkw_memory_names[trace]: + elif isinstance(traces, data.View): + with self.gtkw_save.group("view"): + trace = Value.cast(traces) + for trace_signal in trace._rhs_signals(): + for name in self.gtkw_signal_names[trace_signal]: + self.gtkw_save.trace(name) + elif isinstance(traces, ValueLike): + traverse_traces(Value.cast(traces)) + elif isinstance(traces, MemoryData): + for row_names in self.gtkw_memory_names[traces]: for name in row_names: self.gtkw_save.trace(name) - elif isinstance(trace, MemoryData._Row): - for name in self.gtkw_memory_names[trace._memory][trace._index]: + elif isinstance(traces, MemoryData._Row): + for name in self.gtkw_memory_names[traces._memory][traces._index]: self.gtkw_save.trace(name) + elif hasattr(traces, "signature") and isinstance(traces.signature, wiring.Signature): + with self.gtkw_save.group("interface"): + for _, _, member in traces.signature.flatten(traces): + traverse_traces(member) + elif isinstance(traces, list) or isinstance(traces, tuple): + for trace in traces: + traverse_traces(trace) + elif isinstance(traces, dict): + for name, trace in traces.items(): + with self.gtkw_save.group(name): + traverse_traces(trace) else: assert False # :nocov: + traverse_traces(self.traces) if self.close_vcd: self.vcd_file.close() diff --git a/tests/test_sim.py b/tests/test_sim.py index 8859a19..4179a17 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -16,7 +16,7 @@ from amaranth.hdl._ir import * from amaranth.sim import * from amaranth.sim._pyeval import eval_format from amaranth.lib.memory import Memory -from amaranth.lib import enum, data +from amaranth.lib import enum, data, wiring from .utils import * from amaranth._utils import _ignore_deprecated @@ -1393,6 +1393,49 @@ class SimulatorIntegrationTestCase(FHDLTestCase): sim.add_testbench(testbench) +class SimulatorTracesTestCase(FHDLTestCase): + def assertDef(self, traces, flat_traces): + frag = Fragment() + + def process(): + yield Delay(1e-6) + + sim = Simulator(frag) + sim.add_testbench(process) + with sim.write_vcd("test.vcd", "test.gtkw", traces=traces): + sim.run() + + def test_signal(self): + a = Signal() + self.assertDef(a, [a]) + + def test_list(self): + a = Signal() + self.assertDef([a], [a]) + + def test_tuple(self): + a = Signal() + self.assertDef((a,), [a]) + + def test_dict(self): + a = Signal() + self.assertDef({"a": a}, [a]) + + def test_struct_view(self): + a = Signal(data.StructLayout({"a": 1, "b": 3})) + self.assertDef(a, [a]) + + def test_interface(self): + sig = wiring.Signature({ + "a": wiring.In(1), + "b": wiring.Out(3), + "c": wiring.Out(2).array(4), + "d": wiring.In(wiring.Signature({"e": wiring.In(5)})) + }) + a = sig.create() + self.assertDef(a, [a]) + + class SimulatorRegressionTestCase(FHDLTestCase): def test_bug_325(self): dut = Module()