diff --git a/amaranth/sim/core.py b/amaranth/sim/core.py index 668870d..4ea4993 100644 --- a/amaranth/sim/core.py +++ b/amaranth/sim/core.py @@ -35,14 +35,15 @@ class Simulator: self._engine = engine(self._design) self._clocked = set() - def _check_process(self, process): - if not (inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process)): - raise TypeError("Cannot add a process {!r} because it is not a generator function" - .format(process)) - return process + def _check_function(self, function, *, kind): + if not (inspect.isgeneratorfunction(function) or inspect.iscoroutinefunction(function)): + raise TypeError( + f"Cannot add a {kind} {function!r} because it is not an async function or " + f"generator function") + return function def add_process(self, process): - process = self._check_process(process) + process = self._check_function(process, kind="process") if inspect.iscoroutinefunction(process): self._engine.add_async_process(self, process) else: @@ -54,16 +55,17 @@ class Simulator: wrap_process = coro_wrapper(wrapper, testbench=False) self._engine.add_async_process(self, wrap_process) - def add_testbench(self, process, *, background=False): - if inspect.iscoroutinefunction(process): - self._engine.add_async_testbench(self, process, background=background) + def add_testbench(self, testbench, *, background=False): + testbench = self._check_function(testbench, kind="testbench") + if inspect.iscoroutinefunction(testbench): + self._engine.add_async_testbench(self, testbench, background=background) else: - process = coro_wrapper(process, testbench=True) - self._engine.add_async_testbench(self, process, background=background) + testbench = coro_wrapper(testbench, testbench=True) + self._engine.add_async_testbench(self, testbench, background=background) @deprecated("The `add_sync_process` method is deprecated per RFC 27. Use `add_process` or `add_testbench` instead.") def add_sync_process(self, process, *, domain="sync"): - process = self._check_process(process) + process = self._check_function(process, kind="process") def wrapper(): # Only start a sync process after the first clock edge (or reset edge, if the domain # uses an asynchronous reset). This matches the behavior of synchronous FFs. diff --git a/tests/test_sim.py b/tests/test_sim.py index 678e557..8859a19 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -718,17 +718,35 @@ class SimulatorIntegrationTestCase(FHDLTestCase): def test_add_process_wrong(self): with self.assertSimulation(Module()) as sim: with self.assertRaisesRegex(TypeError, - r"^Cannot add a process 1 because it is not a generator function$"): + r"^Cannot add a process 1 because it is not an async function or " + r"generator function$"): sim.add_process(1) def test_add_process_wrong_generator(self): with self.assertSimulation(Module()) as sim: with self.assertRaisesRegex(TypeError, - r"^Cannot add a process <.+?> because it is not a generator function$"): + r"^Cannot add a process <.+?> because it is not an async function or " + r"generator function$"): def process(): yield Delay() sim.add_process(process()) + def test_add_testbench_wrong(self): + with self.assertSimulation(Module()) as sim: + with self.assertRaisesRegex(TypeError, + r"^Cannot add a testbench 1 because it is not an async function or " + r"generator function$"): + sim.add_testbench(1) + + def test_add_testbench_wrong_generator(self): + with self.assertSimulation(Module()) as sim: + with self.assertRaisesRegex(TypeError, + r"^Cannot add a testbench <.+?> because it is not an async function or " + r"generator function$"): + def testbench(): + yield Delay() + sim.add_testbench(testbench()) + def test_add_clock_wrong_twice(self): m = Module() s = Signal() @@ -1935,3 +1953,12 @@ class SimulatorRegressionTestCase(FHDLTestCase): self.assertTrue(reached_tb) self.assertTrue(reached_proc) + + def test_bug_1363(self): + sim = Simulator(Module()) + with self.assertRaisesRegex(TypeError, + r"^Cannot add a testbench <.+?> because it is not an async function or " + r"generator function$"): + async def testbench(): + yield Delay() + sim.add_testbench(testbench())