From fc9369b8e19dd57e88a1b09396ff9a4a771e0170 Mon Sep 17 00:00:00 2001 From: Wanda Date: Sun, 3 Mar 2024 19:32:15 +0100 Subject: [PATCH] hdl._xfrm: Simplify `EnableInserter` logic. --- amaranth/hdl/_xfrm.py | 21 ++++++--------------- tests/test_hdl_xfrm.py | 15 +++++---------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index e60be23..7f61b9d 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -575,23 +575,14 @@ class ResetInserter(_ControlInserter): fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc)) -class _PropertyEnableInserter(StatementTransformer): - def __init__(self, en): - self.en = en - - def on_Property(self, stmt): - return Switch( - self.en, - {1: [stmt]}, - src_loc=stmt.src_loc, - ) - - class EnableInserter(_ControlInserter): def _insert_control(self, fragment, domain, signals): - stmts = [s.eq(s) for s in signals] - fragment.add_statements(domain, Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc)) - fragment.statements[domain] = _PropertyEnableInserter(self.controls[domain])(fragment.statements[domain]) + if domain in fragment.statements: + fragment.statements[domain] = _StatementList([Switch( + self.controls[domain], + {1: fragment.statements[domain]}, + src_loc=self.src_loc, + )]) def on_fragment(self, fragment): new_fragment = super().on_fragment(fragment) diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index b13fad2..450939b 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -337,9 +337,8 @@ class EnableInserterTestCase(FHDLTestCase): f = EnableInserter(self.c1)(f) self.assertRepr(f.statements["sync"], """ ( - (eq (sig s1) (const 1'd1)) (switch (sig c1) - (case 0 (eq (sig s1) (sig s1))) + (case 1 (eq (sig s1) (const 1'd1))) ) ) """) @@ -359,9 +358,8 @@ class EnableInserterTestCase(FHDLTestCase): """) self.assertRepr(f.statements["pix"], """ ( - (eq (sig s2) (const 1'd0)) (switch (sig c1) - (case 0 (eq (sig s2) (sig s2))) + (case 1 (eq (sig s2) (const 1'd0))) ) ) """) @@ -380,17 +378,15 @@ class EnableInserterTestCase(FHDLTestCase): (f2, _, _), = f1.subfragments self.assertRepr(f1.statements["sync"], """ ( - (eq (sig s1) (const 1'd1)) (switch (sig c1) - (case 0 (eq (sig s1) (sig s1))) + (case 1 (eq (sig s1) (const 1'd1))) ) ) """) self.assertRepr(f2.statements["sync"], """ ( - (eq (sig s2) (const 1'd1)) (switch (sig c1) - (case 0 (eq (sig s2) (sig s2))) + (case 1 (eq (sig s2) (const 1'd1))) ) ) """) @@ -451,9 +447,8 @@ class TransformedElaboratableTestCase(FHDLTestCase): f = Fragment.get(te2, None) self.assertRepr(f.statements["sync"], """ ( - (eq (sig s1) (const 1'd1)) (switch (sig c1) - (case 0 (eq (sig s1) (sig s1))) + (case 1 (eq (sig s1) (const 1'd1))) ) (switch (sig c2) (case 1 (eq (sig s1) (const 1'd0)))