From cd6cbd71caa64f3065502e40e97888cb365deef4 Mon Sep 17 00:00:00 2001 From: Wanda Date: Tue, 2 Apr 2024 22:20:13 +0200 Subject: [PATCH] hdl.{_ast,_dsl}: factor out the pattern normalization logic. --- amaranth/hdl/__init__.py | 6 ++- amaranth/hdl/_ast.py | 82 ++++++++++++++++++++++------------------ amaranth/hdl/_dsl.py | 43 +-------------------- amaranth/lib/enum.py | 2 +- tests/test_hdl_ast.py | 18 ++++----- tests/test_hdl_dsl.py | 14 +++---- tests/test_lib_enum.py | 1 + 7 files changed, 69 insertions(+), 97 deletions(-) diff --git a/amaranth/hdl/__init__.py b/amaranth/hdl/__init__.py index bc0ebbf..43852da 100644 --- a/amaranth/hdl/__init__.py +++ b/amaranth/hdl/__init__.py @@ -1,9 +1,10 @@ +from ._ast import SyntaxError, SyntaxWarning from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike from ._ast import Value, ValueCastable, ValueLike from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal from ._ast import Format, Print, Assert, Assume, Cover from ._ast import IOValue, IOPort -from ._dsl import SyntaxError, SyntaxWarning, Module +from ._dsl import Module from ._cd import DomainError, ClockDomain from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment from ._ir import Instance, IOBufferInstance @@ -14,13 +15,14 @@ from ._xfrm import DomainRenamer, ResetInserter, EnableInserter __all__ = [ # _ast + "SyntaxError", "SyntaxWarning", "Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike", "Value", "ValueCastable", "ValueLike", "Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal", "Format", "Print", "Assert", "Assume", "Cover", "IOValue", "IOPort", # _dsl - "SyntaxError", "SyntaxWarning", "Module", + "Module", # _cd "DomainError", "ClockDomain", # _ir diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 15188c8..5b48ee0 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -16,6 +16,7 @@ from .._unused import * __all__ = [ + "SyntaxError", "SyntaxWarning", "Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "Array", "ArrayProxy", @@ -30,6 +31,14 @@ __all__ = [ ] +class SyntaxError(Exception): + pass + + +class SyntaxWarning(Warning): + pass + + class DUID: """Deterministic Unique IDentifier.""" __next_uid = 0 @@ -426,6 +435,37 @@ class ShapeLike(metaclass=_ShapeLikeMeta): raise TypeError("ShapeLike is an abstract class and cannot be instantiated") +def _normalize_patterns(patterns, shape, *, src_loc_at=1): + new_patterns = [] + for pattern in patterns: + orig_pattern = pattern + if isinstance(pattern, str): + if any(bit not in "01- \t" for bit in pattern): + raise SyntaxError(f"Pattern '{pattern}' must consist of 0, 1, and - (don't " + f"care) bits, and may include whitespace") + pattern = "".join(pattern.split()) # remove whitespace + if len(pattern) != shape.width: + raise SyntaxError(f"Pattern '{orig_pattern}' must have the same width as " + f"match value (which is {shape.width})") + else: + try: + pattern = Const.cast(pattern) + except TypeError as e: + raise SyntaxError(f"Pattern must be a string or a constant-castable " + f"expression, not {pattern!r}") from e + cast_pattern = Const(pattern.value, shape) + if cast_pattern.value != pattern.value: + warnings.warn(f"Pattern '{orig_pattern!r}' " + f"({pattern.shape().width}'{pattern.value:b}) is not " + f"representable in match value shape " + f"({shape!r}); comparison will never be true", + SyntaxWarning, stacklevel=2 + src_loc_at) + continue + pattern = pattern.value + new_patterns.append(pattern) + return tuple(new_patterns) + + def _overridable_by_reflected(method_name): """Allow overriding the decorated method. @@ -1248,36 +1288,12 @@ class Value(metaclass=ABCMeta): If a pattern has invalid syntax. """ matches = [] - # This code should accept exactly the same patterns as `with m.Case(...):`. - for pattern in patterns: - if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): - raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) " - "bits, and may include whitespace" - .format(pattern)) - if (isinstance(pattern, str) and - len("".join(pattern.split())) != len(self)): - raise SyntaxError("Match pattern '{}' must have the same width as match value " - "(which is {})" - .format(pattern, len(self))) + for pattern in _normalize_patterns(patterns, self.shape()): if isinstance(pattern, str): - pattern = "".join(pattern.split()) # remove whitespace - mask = int(pattern.replace("0", "1").replace("-", "0"), 2) - pattern = int(pattern.replace("-", "0"), 2) + mask = int("0" + pattern.replace("0", "1").replace("-", "0"), 2) + pattern = int("0" + pattern.replace("-", "0"), 2) matches.append((self & mask) == pattern) else: - try: - orig_pattern, pattern = pattern, Const.cast(pattern) - except TypeError as e: - raise SyntaxError("Match pattern must be a string or a constant-castable " - "expression, not {!r}" - .format(pattern)) from e - pattern_len = bits_for(pattern.value) - if pattern_len > len(self): - warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value " - "(which has width {}); comparison will never be true" - .format(orig_pattern, pattern_len, pattern.value, len(self)), - SyntaxWarning, stacklevel=2) - continue matches.append(self == pattern) if not matches: return Const(0) @@ -2770,17 +2786,9 @@ class Switch(Statement): # Map: 2 -> "0010"; "0010" -> "0010" new_keys = () key_mask = (1 << len(self.test)) - 1 - for key in keys: - if isinstance(key, str): - key = "".join(key.split()) # remove whitespace - elif isinstance(key, int): + for key in _normalize_patterns(keys, self._test.shape()): + if isinstance(key, int): key = to_binary(key & key_mask, len(self.test)) - elif isinstance(key, Enum): - key = to_binary(key.value & key_mask, len(self.test)) - else: - raise TypeError("Object {!r} cannot be used as a switch key" - .format(key)) - assert len(key) == len(self.test) new_keys = (*new_keys, key) if not isinstance(stmts, Iterable): stmts = [stmts] diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index 2322a83..af636b8 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -9,7 +9,7 @@ from .._utils import flatten from ..utils import bits_for from .. import tracer from ._ast import * -from ._ast import _StatementList, _LateBoundStatement, Property, Print +from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns from ._ir import * from ._cd import * from ._xfrm import * @@ -18,14 +18,6 @@ from ._xfrm import * __all__ = ["SyntaxError", "SyntaxWarning", "Module"] -class SyntaxError(Exception): - pass - - -class SyntaxWarning(Warning): - pass - - class _ModuleBuilderProxy: def __init__(self, builder, depth): object.__setattr__(self, "_builder", builder) @@ -344,41 +336,10 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._check_context("Case", context="Switch") src_loc = tracer.get_src_loc(src_loc_at=1) switch_data = self._get_ctrl("Switch") - new_patterns = () if () in switch_data["cases"]: warnings.warn("A case defined after the default case will never be active", SyntaxWarning, stacklevel=3) - # This code should accept exactly the same patterns as `v.matches(...)`. - for pattern in patterns: - if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): - raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) " - "bits, and may include whitespace" - .format(pattern)) - if (isinstance(pattern, str) and - len("".join(pattern.split())) != len(switch_data["test"])): - raise SyntaxError("Case pattern '{}' must have the same width as switch value " - "(which is {})" - .format(pattern, len(switch_data["test"]))) - if isinstance(pattern, str): - new_patterns = (*new_patterns, pattern) - else: - try: - orig_pattern, pattern = pattern, Const.cast(pattern) - except TypeError as e: - raise SyntaxError("Case pattern must be a string or a constant-castable " - "expression, not {!r}" - .format(pattern)) from e - pattern_len = bits_for(pattern.value) - if pattern.value == 0: - pattern_len = 0 - if pattern_len > len(switch_data["test"]): - warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value " - "(which has width {}); comparison will never be true" - .format(orig_pattern, pattern_len, pattern.value, - len(switch_data["test"])), - SyntaxWarning, stacklevel=3) - continue - new_patterns = (*new_patterns, pattern.value) + new_patterns = _normalize_patterns(patterns, switch_data["test"].shape()) try: _outer_case, self._statements = self._statements, {} self._ctrl_context = None diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index adf94d6..6a2fce7 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -2,7 +2,7 @@ import enum as py_enum import warnings import operator -from ..hdl._ast import Value, ValueCastable, Shape, ShapeCastable, Const +from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning from ..hdl._repr import * diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index d3a20fa..b173698 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -795,7 +795,7 @@ class OperatorTestCase(FHDLTestCase): def test_matches_enum(self): s = Signal(SignedEnum) self.assertRepr(s.matches(SignedEnum.FOO), """ - (== (sig s) (const 2'sd-1)) + (== (sig s) (const 1'sd-1)) """) def test_matches_const_castable(self): @@ -807,28 +807,28 @@ class OperatorTestCase(FHDLTestCase): def test_matches_width_wrong(self): s = Signal(4) with self.assertRaisesRegex(SyntaxError, - r"^Match pattern '--' must have the same width as match value \(which is 4\)$"): + r"^Pattern '--' must have the same width as match value \(which is 4\)$"): s.matches("--") with self.assertWarnsRegex(SyntaxWarning, - r"^Match pattern '22' \(5'10110\) is wider than match value \(which has " - r"width 4\); comparison will never be true$"): + r"^Pattern '22' \(5'10110\) is not representable in match value shape " + r"\(unsigned\(4\)\); comparison will never be true$"): s.matches(0b10110) with self.assertWarnsRegex(SyntaxWarning, - r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider " - r"than match value \(which has width 4\); comparison will never be true$"): + r"^Pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is not " + r"representable in match value shape \(unsigned\(4\)\); comparison will never be true$"): s.matches(Cat(0, C(0b1011, 4))) def test_matches_bits_wrong(self): s = Signal(4) with self.assertRaisesRegex(SyntaxError, - r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " + r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " r"and may include whitespace$"): s.matches("abc") def test_matches_pattern_wrong(self): s = Signal(4) with self.assertRaisesRegex(SyntaxError, - r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"): + r"^Pattern must be a string or a constant-castable expression, not 1\.0$"): s.matches(1.0) def test_hash(self): @@ -1695,7 +1695,7 @@ class SwitchTestCase(FHDLTestCase): self.assertEqual(s.cases, {("00001010",): []}) def test_int_neg_case(self): - s = Switch(Const(0, 8), {-10: []}) + s = Switch(Const(0, signed(8)), {-10: []}) self.assertEqual(s.cases, {("11110110",): []}) def test_int_zero_width(self): diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index 8048e54..ff7cfa0 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -487,17 +487,17 @@ class DSLTestCase(FHDLTestCase): dummy = Signal() with m.Switch(self.w1): with self.assertRaisesRegex(SyntaxError, - r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"): + r"^Pattern '--' must have the same width as match value \(which is 4\)$"): with m.Case("--"): m.d.comb += dummy.eq(0) with self.assertWarnsRegex(SyntaxWarning, - r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has " - r"width 4\); comparison will never be true$"): + r"^Pattern '22' \(5'10110\) is not representable in match value shape " + r"\(unsigned\(4\)\); comparison will never be true$"): with m.Case(0b10110): m.d.comb += dummy.eq(0) with self.assertWarnsRegex(SyntaxWarning, - r"^Case pattern '' \(8'10101010\) is wider than switch value " - r"\(which has width 4\); comparison will never be true$"): + r"^Pattern '' \(8'10101010\) is not representable in " + r"match value shape \(unsigned\(4\)\); comparison will never be true$"): with m.Case(Color.RED): m.d.comb += dummy.eq(0) self.assertEqual(m._statements, {}) @@ -521,7 +521,7 @@ class DSLTestCase(FHDLTestCase): m = Module() with m.Switch(self.w1): with self.assertRaisesRegex(SyntaxError, - (r"^Case pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " + (r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " r"and may include whitespace$")): with m.Case("abc"): pass @@ -530,7 +530,7 @@ class DSLTestCase(FHDLTestCase): m = Module() with m.Switch(self.w1): with self.assertRaisesRegex(SyntaxError, - r"^Case pattern must be a string or a constant-castable expression, " + r"^Pattern must be a string or a constant-castable expression, " r"not 1\.0$"): with m.Case(1.0): pass diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 9d8d0cf..af4b006 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -4,6 +4,7 @@ import sys import unittest from amaranth import * +from amaranth.hdl import * from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView from .utils import *