From 3a51b612844a23b08e744c4b3372ecb44bf9fe5d Mon Sep 17 00:00:00 2001 From: Jin Xue Date: Sat, 24 Sep 2022 18:22:47 +0800 Subject: [PATCH] sim._pyrtl: translate ArrayProxy to pattern matching when supported. Current the value compiler translates ArrayProxy into if-elif trees which can cause the compiler to crash due to deep recursion (#359). After this commit, it instead translates them into pattern matching when it is supported (on Python >= 3.10) to avoid this problem. --- amaranth/sim/_pyrtl.py | 61 +++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index b016c16..0e37847 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -1,6 +1,7 @@ import os import tempfile from contextlib import contextmanager +import sys from ..hdl import * from ..hdl.ast import SignalSet @@ -10,6 +11,7 @@ from ._base import BaseProcess __all__ = ["PyRTLProcess"] +_USE_PATTERN_MATCHING = (sys.version_info >= (3, 10)) class PyRTLProcess(BaseProcess): __slots__ = ("is_comb", "runnable", "passive", "run") @@ -228,16 +230,28 @@ class _RHSValueCompiler(_ValueCompiler): gen_index = self.emitter.def_var("rhs_index", f"{index_mask:#x} & {self(value.index)}") gen_value = self.emitter.gen_var("rhs_proxy") if value.elems: - for index, elem in enumerate(value.elems): - if index == 0: - self.emitter.append(f"if {index} == {gen_index}:") - else: - self.emitter.append(f"elif {index} == {gen_index}:") + if _USE_PATTERN_MATCHING: + self.emitter.append(f"match {gen_index}:") with self.emitter.indent(): - self.emitter.append(f"{gen_value} = {self(elem)}") - self.emitter.append(f"else:") - with self.emitter.indent(): - self.emitter.append(f"{gen_value} = {self(value.elems[-1])}") + for index, elem in enumerate(value.elems): + self.emitter.append(f"case {index}:") + with self.emitter.indent(): + self.emitter.append(f"{gen_value} = {self(elem)}") + self.emitter.append("case _:") + with self.emitter.indent(): + self.emitter.append(f"{gen_value} = {self(value.elems[-1])}") + else: + for index, elem in enumerate(value.elems): + if index == 0: + self.emitter.append(f"if {index} == {gen_index}:") + else: + self.emitter.append(f"elif {index} == {gen_index}:") + with self.emitter.indent(): + self.emitter.append(f"{gen_value} = {self(elem)}") + self.emitter.append(f"else:") + with self.emitter.indent(): + self.emitter.append(f"{gen_value} = {self(value.elems[-1])}") + return gen_value else: return f"0" @@ -319,16 +333,27 @@ class _LHSValueCompiler(_ValueCompiler): index_mask = (1 << len(value.index)) - 1 gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask:#x}") if value.elems: - for index, elem in enumerate(value.elems): - if index == 0: - self.emitter.append(f"if {index} == {gen_index}:") - else: - self.emitter.append(f"elif {index} == {gen_index}:") + if _USE_PATTERN_MATCHING: + self.emitter.append(f"match {gen_index}:") with self.emitter.indent(): - self(elem)(arg) - self.emitter.append(f"else:") - with self.emitter.indent(): - self(value.elems[-1])(arg) + for index, elem in enumerate(value.elems): + self.emitter.append(f"case {index}:") + with self.emitter.indent(): + self(elem)(arg) + self.emitter.append("case _:") + with self.emitter.indent(): + self(value.elems[-1])(arg) + else: + for index, elem in enumerate(value.elems): + if index == 0: + self.emitter.append(f"if {index} == {gen_index}:") + else: + self.emitter.append(f"elif {index} == {gen_index}:") + with self.emitter.indent(): + self(elem)(arg) + self.emitter.append(f"else:") + with self.emitter.indent(): + self(value.elems[-1])(arg) else: self.emitter.append(f"pass") return gen