hdl.{_ast,_dsl}: factor out the pattern normalization logic.
This commit is contained in:
		
							parent
							
								
									0e4c2de725
								
							
						
					
					
						commit
						cd6cbd71ca
					
				|  | @ -1,9 +1,10 @@ | |||
| from ._ast import SyntaxError, SyntaxWarning | ||||
| from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike | ||||
| from ._ast import Value, ValueCastable, ValueLike | ||||
| from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal | ||||
| from ._ast import Format, Print, Assert, Assume, Cover | ||||
| from ._ast import IOValue, IOPort | ||||
| from ._dsl import SyntaxError, SyntaxWarning, Module | ||||
| from ._dsl import Module | ||||
| from ._cd import DomainError, ClockDomain | ||||
| from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment | ||||
| from ._ir import Instance, IOBufferInstance | ||||
|  | @ -14,13 +15,14 @@ from ._xfrm import DomainRenamer, ResetInserter, EnableInserter | |||
| 
 | ||||
| __all__ = [ | ||||
|     # _ast | ||||
|     "SyntaxError", "SyntaxWarning", | ||||
|     "Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike", | ||||
|     "Value", "ValueCastable", "ValueLike", | ||||
|     "Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal", | ||||
|     "Format", "Print", "Assert", "Assume", "Cover", | ||||
|     "IOValue", "IOPort", | ||||
|     # _dsl | ||||
|     "SyntaxError", "SyntaxWarning", "Module", | ||||
|     "Module", | ||||
|     # _cd | ||||
|     "DomainError", "ClockDomain", | ||||
|     # _ir | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ from .._unused import * | |||
| 
 | ||||
| 
 | ||||
| __all__ = [ | ||||
|     "SyntaxError", "SyntaxWarning", | ||||
|     "Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike", | ||||
|     "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", | ||||
|     "Array", "ArrayProxy", | ||||
|  | @ -30,6 +31,14 @@ __all__ = [ | |||
| ] | ||||
| 
 | ||||
| 
 | ||||
| class SyntaxError(Exception): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class SyntaxWarning(Warning): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class DUID: | ||||
|     """Deterministic Unique IDentifier.""" | ||||
|     __next_uid = 0 | ||||
|  | @ -426,6 +435,37 @@ class ShapeLike(metaclass=_ShapeLikeMeta): | |||
|         raise TypeError("ShapeLike is an abstract class and cannot be instantiated") | ||||
| 
 | ||||
| 
 | ||||
| def _normalize_patterns(patterns, shape, *, src_loc_at=1): | ||||
|     new_patterns = [] | ||||
|     for pattern in patterns: | ||||
|         orig_pattern = pattern | ||||
|         if isinstance(pattern, str): | ||||
|             if any(bit not in "01- \t" for bit in pattern): | ||||
|                 raise SyntaxError(f"Pattern '{pattern}' must consist of 0, 1, and - (don't " | ||||
|                                   f"care) bits, and may include whitespace") | ||||
|             pattern = "".join(pattern.split()) # remove whitespace | ||||
|             if len(pattern) != shape.width: | ||||
|                 raise SyntaxError(f"Pattern '{orig_pattern}' must have the same width as " | ||||
|                                   f"match value (which is {shape.width})") | ||||
|         else: | ||||
|             try: | ||||
|                 pattern = Const.cast(pattern) | ||||
|             except TypeError as e: | ||||
|                 raise SyntaxError(f"Pattern must be a string or a constant-castable " | ||||
|                                   f"expression, not {pattern!r}") from e | ||||
|             cast_pattern = Const(pattern.value, shape) | ||||
|             if cast_pattern.value != pattern.value: | ||||
|                 warnings.warn(f"Pattern '{orig_pattern!r}' " | ||||
|                               f"({pattern.shape().width}'{pattern.value:b}) is not " | ||||
|                               f"representable in match value shape " | ||||
|                               f"({shape!r}); comparison will never be true", | ||||
|                               SyntaxWarning, stacklevel=2 + src_loc_at) | ||||
|                 continue | ||||
|             pattern = pattern.value | ||||
|         new_patterns.append(pattern) | ||||
|     return tuple(new_patterns) | ||||
| 
 | ||||
| 
 | ||||
| def _overridable_by_reflected(method_name): | ||||
|     """Allow overriding the decorated method. | ||||
| 
 | ||||
|  | @ -1248,36 +1288,12 @@ class Value(metaclass=ABCMeta): | |||
|             If a pattern has invalid syntax. | ||||
|         """ | ||||
|         matches = [] | ||||
|         # This code should accept exactly the same patterns as `with m.Case(...):`. | ||||
|         for pattern in patterns: | ||||
|             if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): | ||||
|                 raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) " | ||||
|                                   "bits, and may include whitespace" | ||||
|                                   .format(pattern)) | ||||
|             if (isinstance(pattern, str) and | ||||
|                     len("".join(pattern.split())) != len(self)): | ||||
|                 raise SyntaxError("Match pattern '{}' must have the same width as match value " | ||||
|                                   "(which is {})" | ||||
|                                   .format(pattern, len(self))) | ||||
|         for pattern in _normalize_patterns(patterns, self.shape()): | ||||
|             if isinstance(pattern, str): | ||||
|                 pattern = "".join(pattern.split()) # remove whitespace | ||||
|                 mask    = int(pattern.replace("0", "1").replace("-", "0"), 2) | ||||
|                 pattern = int(pattern.replace("-", "0"), 2) | ||||
|                 mask    = int("0" + pattern.replace("0", "1").replace("-", "0"), 2) | ||||
|                 pattern = int("0" + pattern.replace("-", "0"), 2) | ||||
|                 matches.append((self & mask) == pattern) | ||||
|             else: | ||||
|                 try: | ||||
|                     orig_pattern, pattern = pattern, Const.cast(pattern) | ||||
|                 except TypeError as e: | ||||
|                     raise SyntaxError("Match pattern must be a string or a constant-castable " | ||||
|                                       "expression, not {!r}" | ||||
|                                       .format(pattern)) from e | ||||
|                 pattern_len = bits_for(pattern.value) | ||||
|                 if pattern_len > len(self): | ||||
|                     warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value " | ||||
|                                   "(which has width {}); comparison will never be true" | ||||
|                                   .format(orig_pattern, pattern_len, pattern.value, len(self)), | ||||
|                                   SyntaxWarning, stacklevel=2) | ||||
|                     continue | ||||
|                 matches.append(self == pattern) | ||||
|         if not matches: | ||||
|             return Const(0) | ||||
|  | @ -2770,17 +2786,9 @@ class Switch(Statement): | |||
|             # Map: 2 -> "0010"; "0010" -> "0010" | ||||
|             new_keys = () | ||||
|             key_mask = (1 << len(self.test)) - 1 | ||||
|             for key in keys: | ||||
|                 if isinstance(key, str): | ||||
|                     key = "".join(key.split()) # remove whitespace | ||||
|                 elif isinstance(key, int): | ||||
|             for key in _normalize_patterns(keys, self._test.shape()): | ||||
|                 if isinstance(key, int): | ||||
|                     key = to_binary(key & key_mask, len(self.test)) | ||||
|                 elif isinstance(key, Enum): | ||||
|                     key = to_binary(key.value & key_mask, len(self.test)) | ||||
|                 else: | ||||
|                     raise TypeError("Object {!r} cannot be used as a switch key" | ||||
|                                     .format(key)) | ||||
|                 assert len(key) == len(self.test) | ||||
|                 new_keys = (*new_keys, key) | ||||
|             if not isinstance(stmts, Iterable): | ||||
|                 stmts = [stmts] | ||||
|  |  | |||
|  | @ -9,7 +9,7 @@ from .._utils import flatten | |||
| from ..utils import bits_for | ||||
| from .. import tracer | ||||
| from ._ast import * | ||||
| from ._ast import _StatementList, _LateBoundStatement, Property, Print | ||||
| from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns | ||||
| from ._ir import * | ||||
| from ._cd import * | ||||
| from ._xfrm import * | ||||
|  | @ -18,14 +18,6 @@ from ._xfrm import * | |||
| __all__ = ["SyntaxError", "SyntaxWarning", "Module"] | ||||
| 
 | ||||
| 
 | ||||
| class SyntaxError(Exception): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class SyntaxWarning(Warning): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class _ModuleBuilderProxy: | ||||
|     def __init__(self, builder, depth): | ||||
|         object.__setattr__(self, "_builder", builder) | ||||
|  | @ -344,41 +336,10 @@ class Module(_ModuleBuilderRoot, Elaboratable): | |||
|         self._check_context("Case", context="Switch") | ||||
|         src_loc = tracer.get_src_loc(src_loc_at=1) | ||||
|         switch_data = self._get_ctrl("Switch") | ||||
|         new_patterns = () | ||||
|         if () in switch_data["cases"]: | ||||
|             warnings.warn("A case defined after the default case will never be active", | ||||
|                           SyntaxWarning, stacklevel=3) | ||||
|         # This code should accept exactly the same patterns as `v.matches(...)`. | ||||
|         for pattern in patterns: | ||||
|             if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): | ||||
|                 raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) " | ||||
|                                   "bits, and may include whitespace" | ||||
|                                   .format(pattern)) | ||||
|             if (isinstance(pattern, str) and | ||||
|                     len("".join(pattern.split())) != len(switch_data["test"])): | ||||
|                 raise SyntaxError("Case pattern '{}' must have the same width as switch value " | ||||
|                                   "(which is {})" | ||||
|                                   .format(pattern, len(switch_data["test"]))) | ||||
|             if isinstance(pattern, str): | ||||
|                 new_patterns = (*new_patterns, pattern) | ||||
|             else: | ||||
|                 try: | ||||
|                     orig_pattern, pattern = pattern, Const.cast(pattern) | ||||
|                 except TypeError as e: | ||||
|                     raise SyntaxError("Case pattern must be a string or a constant-castable " | ||||
|                                       "expression, not {!r}" | ||||
|                                       .format(pattern)) from e | ||||
|                 pattern_len = bits_for(pattern.value) | ||||
|                 if pattern.value == 0: | ||||
|                     pattern_len = 0 | ||||
|                 if pattern_len > len(switch_data["test"]): | ||||
|                     warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value " | ||||
|                                   "(which has width {}); comparison will never be true" | ||||
|                                   .format(orig_pattern, pattern_len, pattern.value, | ||||
|                                           len(switch_data["test"])), | ||||
|                                   SyntaxWarning, stacklevel=3) | ||||
|                     continue | ||||
|                 new_patterns = (*new_patterns, pattern.value) | ||||
|         new_patterns = _normalize_patterns(patterns, switch_data["test"].shape()) | ||||
|         try: | ||||
|             _outer_case, self._statements = self._statements, {} | ||||
|             self._ctrl_context = None | ||||
|  |  | |||
|  | @ -2,7 +2,7 @@ import enum as py_enum | |||
| import warnings | ||||
| import operator | ||||
| 
 | ||||
| from ..hdl._ast import Value, ValueCastable, Shape, ShapeCastable, Const | ||||
| from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning | ||||
| from ..hdl._repr import * | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -795,7 +795,7 @@ class OperatorTestCase(FHDLTestCase): | |||
|     def test_matches_enum(self): | ||||
|         s = Signal(SignedEnum) | ||||
|         self.assertRepr(s.matches(SignedEnum.FOO), """ | ||||
|         (== (sig s) (const 2'sd-1)) | ||||
|         (== (sig s) (const 1'sd-1)) | ||||
|         """) | ||||
| 
 | ||||
|     def test_matches_const_castable(self): | ||||
|  | @ -807,28 +807,28 @@ class OperatorTestCase(FHDLTestCase): | |||
|     def test_matches_width_wrong(self): | ||||
|         s = Signal(4) | ||||
|         with self.assertRaisesRegex(SyntaxError, | ||||
|                 r"^Match pattern '--' must have the same width as match value \(which is 4\)$"): | ||||
|                 r"^Pattern '--' must have the same width as match value \(which is 4\)$"): | ||||
|             s.matches("--") | ||||
|         with self.assertWarnsRegex(SyntaxWarning, | ||||
|                 r"^Match pattern '22' \(5'10110\) is wider than match value \(which has " | ||||
|                 r"width 4\); comparison will never be true$"): | ||||
|                 r"^Pattern '22' \(5'10110\) is not representable in match value shape " | ||||
|                 r"\(unsigned\(4\)\); comparison will never be true$"): | ||||
|             s.matches(0b10110) | ||||
|         with self.assertWarnsRegex(SyntaxWarning, | ||||
|                 r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider " | ||||
|                 r"than match value \(which has width 4\); comparison will never be true$"): | ||||
|                 r"^Pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is not " | ||||
|                 r"representable in match value shape \(unsigned\(4\)\); comparison will never be true$"): | ||||
|             s.matches(Cat(0, C(0b1011, 4))) | ||||
| 
 | ||||
|     def test_matches_bits_wrong(self): | ||||
|         s = Signal(4) | ||||
|         with self.assertRaisesRegex(SyntaxError, | ||||
|                 r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " | ||||
|                 r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " | ||||
|                 r"and may include whitespace$"): | ||||
|             s.matches("abc") | ||||
| 
 | ||||
|     def test_matches_pattern_wrong(self): | ||||
|         s = Signal(4) | ||||
|         with self.assertRaisesRegex(SyntaxError, | ||||
|                 r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"): | ||||
|                 r"^Pattern must be a string or a constant-castable expression, not 1\.0$"): | ||||
|             s.matches(1.0) | ||||
| 
 | ||||
|     def test_hash(self): | ||||
|  | @ -1695,7 +1695,7 @@ class SwitchTestCase(FHDLTestCase): | |||
|         self.assertEqual(s.cases, {("00001010",): []}) | ||||
| 
 | ||||
|     def test_int_neg_case(self): | ||||
|         s = Switch(Const(0, 8), {-10: []}) | ||||
|         s = Switch(Const(0, signed(8)), {-10: []}) | ||||
|         self.assertEqual(s.cases, {("11110110",): []}) | ||||
| 
 | ||||
|     def test_int_zero_width(self): | ||||
|  |  | |||
|  | @ -487,17 +487,17 @@ class DSLTestCase(FHDLTestCase): | |||
|         dummy = Signal() | ||||
|         with m.Switch(self.w1): | ||||
|             with self.assertRaisesRegex(SyntaxError, | ||||
|                     r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"): | ||||
|                     r"^Pattern '--' must have the same width as match value \(which is 4\)$"): | ||||
|                 with m.Case("--"): | ||||
|                     m.d.comb += dummy.eq(0) | ||||
|             with self.assertWarnsRegex(SyntaxWarning, | ||||
|                     r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has " | ||||
|                     r"width 4\); comparison will never be true$"): | ||||
|                     r"^Pattern '22' \(5'10110\) is not representable in match value shape " | ||||
|                     r"\(unsigned\(4\)\); comparison will never be true$"): | ||||
|                 with m.Case(0b10110): | ||||
|                     m.d.comb += dummy.eq(0) | ||||
|             with self.assertWarnsRegex(SyntaxWarning, | ||||
|                     r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value " | ||||
|                     r"\(which has width 4\); comparison will never be true$"): | ||||
|                     r"^Pattern '<Color.RED: 170>' \(8'10101010\) is not representable in " | ||||
|                     r"match value shape \(unsigned\(4\)\); comparison will never be true$"): | ||||
|                 with m.Case(Color.RED): | ||||
|                     m.d.comb += dummy.eq(0) | ||||
|         self.assertEqual(m._statements, {}) | ||||
|  | @ -521,7 +521,7 @@ class DSLTestCase(FHDLTestCase): | |||
|         m = Module() | ||||
|         with m.Switch(self.w1): | ||||
|             with self.assertRaisesRegex(SyntaxError, | ||||
|                     (r"^Case pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " | ||||
|                     (r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " | ||||
|                         r"and may include whitespace$")): | ||||
|                 with m.Case("abc"): | ||||
|                     pass | ||||
|  | @ -530,7 +530,7 @@ class DSLTestCase(FHDLTestCase): | |||
|         m = Module() | ||||
|         with m.Switch(self.w1): | ||||
|             with self.assertRaisesRegex(SyntaxError, | ||||
|                     r"^Case pattern must be a string or a constant-castable expression, " | ||||
|                     r"^Pattern must be a string or a constant-castable expression, " | ||||
|                     r"not 1\.0$"): | ||||
|                 with m.Case(1.0): | ||||
|                     pass | ||||
|  |  | |||
|  | @ -4,6 +4,7 @@ import sys | |||
| import unittest | ||||
| 
 | ||||
| from amaranth import * | ||||
| from amaranth.hdl import * | ||||
| from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView | ||||
| 
 | ||||
| from .utils import * | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue
	
	 Wanda
						Wanda