diff --git a/amaranth/tracer.py b/amaranth/tracer.py index 5a21dd0..bc61e2f 100644 --- a/amaranth/tracer.py +++ b/amaranth/tracer.py @@ -26,24 +26,40 @@ def get_var_name(depth=2, default=_raise_exception): break if call_opc not in ("CALL_FUNCTION", "CALL_FUNCTION_KW", "CALL_FUNCTION_EX", "CALL_METHOD", "CALL"): - return default + if default is _raise_exception: + raise NameNotFound + else: + return default index = call_index + 2 + imm = 0 while True: opc = opname[code.co_code[index]] - if opc in ("STORE_NAME", "STORE_ATTR"): - name_index = int(code.co_code[index + 1]) - return code.co_names[name_index] + if opc == 'EXTENDED_ARG': + imm |= int(code.co_code[index + 1]) + imm <<= 8 + index += 2 + elif opc in ("STORE_NAME", "STORE_ATTR"): + imm |= int(code.co_code[index + 1]) + return code.co_names[imm] elif opc == "STORE_FAST": - name_index = int(code.co_code[index + 1]) - return code.co_varnames[name_index] - elif opc == "STORE_DEREF": - name_index = int(code.co_code[index + 1]) + imm |= int(code.co_code[index + 1]) if sys.version_info >= (3, 11): - name_index -= code.co_nlocals - return code.co_cellvars[name_index] + return code._varname_from_oparg(imm) + else: + return code.co_varnames[imm] + elif opc == "STORE_DEREF": + imm |= int(code.co_code[index + 1]) + if sys.version_info >= (3, 11): + return code._varname_from_oparg(imm) + else: + if imm < len(code.co_cellvars): + return code.co_cellvars[imm] + else: + return code.co_freevars[imm - len(code.co_cellvars)] elif opc in ("LOAD_GLOBAL", "LOAD_NAME", "LOAD_ATTR", "LOAD_FAST", "LOAD_DEREF", "DUP_TOP", "BUILD_LIST", "CACHE", "COPY"): + imm = 0 index += 2 else: if default is _raise_exception: diff --git a/tests/test_tracer.py b/tests/test_tracer.py new file mode 100644 index 0000000..54c6994 --- /dev/null +++ b/tests/test_tracer.py @@ -0,0 +1,79 @@ +from amaranth.hdl.ast import * +from types import SimpleNamespace + +from .utils import * + +class TracerTestCase(FHDLTestCase): + def test_fast(self): + s1 = Signal() + self.assertEqual(s1.name, "s1") + s2 = Signal() + self.assertEqual(s2.name, "s2") + + def test_name(self): + class Dummy: + s1 = Signal() + self.assertEqual(s1.name, "s1") + s2 = Signal() + self.assertEqual(s2.name, "s2") + + def test_attr(self): + ns = SimpleNamespace() + ns.s1 = Signal() + self.assertEqual(ns.s1.name, "s1") + ns.s2 = Signal() + self.assertEqual(ns.s2.name, "s2") + + def test_index(self): + l = [None] + l[0] = Signal() + self.assertEqual(l[0].name, "$signal") + + def test_deref_cell(self): + s1 = Signal() + self.assertEqual(s1.name, "s1") + s2 = Signal() + self.assertEqual(s2.name, "s2") + + def dummy(): + return s1, s2 + + def test_deref_free(self): + def inner(): + nonlocal s3, s4 + s3 = Signal() + s4 = Signal() + return s1, s2 + + s1 = Signal() + s2 = Signal() + s3 = None + s4 = None + inner() + self.assertEqual(s1.name, "s1") + self.assertEqual(s2.name, "s2") + self.assertEqual(s3.name, "s3") + self.assertEqual(s4.name, "s4") + + def test_long(self): + test = "" + for i in range(100000): + test += f"dummy{i} = None\n" + test += "s1 = Signal()\n" + test += "s2 = Signal()\n" + ns = {"Signal": Signal} + exec(test, ns) + self.assertEqual(ns["s1"].name, "s1") + self.assertEqual(ns["s2"].name, "s2") + + def test_deref_fast(self): + def inner(s2): + s1 = Signal() + s2 = Signal() + self.assertEqual(s1.name, "s1") + self.assertEqual(s2.name, "s2") + + def dummy(): + return s1, s2 + + inner(None)