hdl.{_ast,_dsl}: factor out the pattern normalization logic.

This commit is contained in:
Wanda 2024-04-02 22:20:13 +02:00 committed by Catherine
parent 0e4c2de725
commit cd6cbd71ca
7 changed files with 69 additions and 97 deletions

View file

@ -1,9 +1,10 @@
from ._ast import SyntaxError, SyntaxWarning
from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike
from ._ast import Value, ValueCastable, ValueLike
from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal
from ._ast import Format, Print, Assert, Assume, Cover
from ._ast import IOValue, IOPort
from ._dsl import SyntaxError, SyntaxWarning, Module
from ._dsl import Module
from ._cd import DomainError, ClockDomain
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
from ._ir import Instance, IOBufferInstance
@ -14,13 +15,14 @@ from ._xfrm import DomainRenamer, ResetInserter, EnableInserter
__all__ = [
# _ast
"SyntaxError", "SyntaxWarning",
"Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike",
"Value", "ValueCastable", "ValueLike",
"Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal",
"Format", "Print", "Assert", "Assume", "Cover",
"IOValue", "IOPort",
# _dsl
"SyntaxError", "SyntaxWarning", "Module",
"Module",
# _cd
"DomainError", "ClockDomain",
# _ir

View file

@ -16,6 +16,7 @@ from .._unused import *
__all__ = [
"SyntaxError", "SyntaxWarning",
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
"Array", "ArrayProxy",
@ -30,6 +31,14 @@ __all__ = [
]
class SyntaxError(Exception):
pass
class SyntaxWarning(Warning):
pass
class DUID:
"""Deterministic Unique IDentifier."""
__next_uid = 0
@ -426,6 +435,37 @@ class ShapeLike(metaclass=_ShapeLikeMeta):
raise TypeError("ShapeLike is an abstract class and cannot be instantiated")
def _normalize_patterns(patterns, shape, *, src_loc_at=1):
new_patterns = []
for pattern in patterns:
orig_pattern = pattern
if isinstance(pattern, str):
if any(bit not in "01- \t" for bit in pattern):
raise SyntaxError(f"Pattern '{pattern}' must consist of 0, 1, and - (don't "
f"care) bits, and may include whitespace")
pattern = "".join(pattern.split()) # remove whitespace
if len(pattern) != shape.width:
raise SyntaxError(f"Pattern '{orig_pattern}' must have the same width as "
f"match value (which is {shape.width})")
else:
try:
pattern = Const.cast(pattern)
except TypeError as e:
raise SyntaxError(f"Pattern must be a string or a constant-castable "
f"expression, not {pattern!r}") from e
cast_pattern = Const(pattern.value, shape)
if cast_pattern.value != pattern.value:
warnings.warn(f"Pattern '{orig_pattern!r}' "
f"({pattern.shape().width}'{pattern.value:b}) is not "
f"representable in match value shape "
f"({shape!r}); comparison will never be true",
SyntaxWarning, stacklevel=2 + src_loc_at)
continue
pattern = pattern.value
new_patterns.append(pattern)
return tuple(new_patterns)
def _overridable_by_reflected(method_name):
"""Allow overriding the decorated method.
@ -1248,36 +1288,12 @@ class Value(metaclass=ABCMeta):
If a pattern has invalid syntax.
"""
matches = []
# This code should accept exactly the same patterns as `with m.Case(...):`.
for pattern in patterns:
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace"
.format(pattern))
if (isinstance(pattern, str) and
len("".join(pattern.split())) != len(self)):
raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})"
.format(pattern, len(self)))
for pattern in _normalize_patterns(patterns, self.shape()):
if isinstance(pattern, str):
pattern = "".join(pattern.split()) # remove whitespace
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2)
mask = int("0" + pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int("0" + pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern)
else:
try:
orig_pattern, pattern = pattern, Const.cast(pattern)
except TypeError as e:
raise SyntaxError("Match pattern must be a string or a constant-castable "
"expression, not {!r}"
.format(pattern)) from e
pattern_len = bits_for(pattern.value)
if pattern_len > len(self):
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
"(which has width {}); comparison will never be true"
.format(orig_pattern, pattern_len, pattern.value, len(self)),
SyntaxWarning, stacklevel=2)
continue
matches.append(self == pattern)
if not matches:
return Const(0)
@ -2770,17 +2786,9 @@ class Switch(Statement):
# Map: 2 -> "0010"; "0010" -> "0010"
new_keys = ()
key_mask = (1 << len(self.test)) - 1
for key in keys:
if isinstance(key, str):
key = "".join(key.split()) # remove whitespace
elif isinstance(key, int):
for key in _normalize_patterns(keys, self._test.shape()):
if isinstance(key, int):
key = to_binary(key & key_mask, len(self.test))
elif isinstance(key, Enum):
key = to_binary(key.value & key_mask, len(self.test))
else:
raise TypeError("Object {!r} cannot be used as a switch key"
.format(key))
assert len(key) == len(self.test)
new_keys = (*new_keys, key)
if not isinstance(stmts, Iterable):
stmts = [stmts]

View file

@ -9,7 +9,7 @@ from .._utils import flatten
from ..utils import bits_for
from .. import tracer
from ._ast import *
from ._ast import _StatementList, _LateBoundStatement, Property, Print
from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns
from ._ir import *
from ._cd import *
from ._xfrm import *
@ -18,14 +18,6 @@ from ._xfrm import *
__all__ = ["SyntaxError", "SyntaxWarning", "Module"]
class SyntaxError(Exception):
pass
class SyntaxWarning(Warning):
pass
class _ModuleBuilderProxy:
def __init__(self, builder, depth):
object.__setattr__(self, "_builder", builder)
@ -344,41 +336,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._check_context("Case", context="Switch")
src_loc = tracer.get_src_loc(src_loc_at=1)
switch_data = self._get_ctrl("Switch")
new_patterns = ()
if () in switch_data["cases"]:
warnings.warn("A case defined after the default case will never be active",
SyntaxWarning, stacklevel=3)
# This code should accept exactly the same patterns as `v.matches(...)`.
for pattern in patterns:
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace"
.format(pattern))
if (isinstance(pattern, str) and
len("".join(pattern.split())) != len(switch_data["test"])):
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
"(which is {})"
.format(pattern, len(switch_data["test"])))
if isinstance(pattern, str):
new_patterns = (*new_patterns, pattern)
else:
try:
orig_pattern, pattern = pattern, Const.cast(pattern)
except TypeError as e:
raise SyntaxError("Case pattern must be a string or a constant-castable "
"expression, not {!r}"
.format(pattern)) from e
pattern_len = bits_for(pattern.value)
if pattern.value == 0:
pattern_len = 0
if pattern_len > len(switch_data["test"]):
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
"(which has width {}); comparison will never be true"
.format(orig_pattern, pattern_len, pattern.value,
len(switch_data["test"])),
SyntaxWarning, stacklevel=3)
continue
new_patterns = (*new_patterns, pattern.value)
new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
try:
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None

View file

@ -2,7 +2,7 @@ import enum as py_enum
import warnings
import operator
from ..hdl._ast import Value, ValueCastable, Shape, ShapeCastable, Const
from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning
from ..hdl._repr import *

View file

@ -795,7 +795,7 @@ class OperatorTestCase(FHDLTestCase):
def test_matches_enum(self):
s = Signal(SignedEnum)
self.assertRepr(s.matches(SignedEnum.FOO), """
(== (sig s) (const 2'sd-1))
(== (sig s) (const 1'sd-1))
""")
def test_matches_const_castable(self):
@ -807,28 +807,28 @@ class OperatorTestCase(FHDLTestCase):
def test_matches_width_wrong(self):
s = Signal(4)
with self.assertRaisesRegex(SyntaxError,
r"^Match pattern '--' must have the same width as match value \(which is 4\)$"):
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
s.matches("--")
with self.assertWarnsRegex(SyntaxWarning,
r"^Match pattern '22' \(5'10110\) is wider than match value \(which has "
r"width 4\); comparison will never be true$"):
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
r"\(unsigned\(4\)\); comparison will never be true$"):
s.matches(0b10110)
with self.assertWarnsRegex(SyntaxWarning,
r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider "
r"than match value \(which has width 4\); comparison will never be true$"):
r"^Pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is not "
r"representable in match value shape \(unsigned\(4\)\); comparison will never be true$"):
s.matches(Cat(0, C(0b1011, 4)))
def test_matches_bits_wrong(self):
s = Signal(4)
with self.assertRaisesRegex(SyntaxError,
r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
r"and may include whitespace$"):
s.matches("abc")
def test_matches_pattern_wrong(self):
s = Signal(4)
with self.assertRaisesRegex(SyntaxError,
r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"):
r"^Pattern must be a string or a constant-castable expression, not 1\.0$"):
s.matches(1.0)
def test_hash(self):
@ -1695,7 +1695,7 @@ class SwitchTestCase(FHDLTestCase):
self.assertEqual(s.cases, {("00001010",): []})
def test_int_neg_case(self):
s = Switch(Const(0, 8), {-10: []})
s = Switch(Const(0, signed(8)), {-10: []})
self.assertEqual(s.cases, {("11110110",): []})
def test_int_zero_width(self):

View file

@ -487,17 +487,17 @@ class DSLTestCase(FHDLTestCase):
dummy = Signal()
with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError,
r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"):
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
with m.Case("--"):
m.d.comb += dummy.eq(0)
with self.assertWarnsRegex(SyntaxWarning,
r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has "
r"width 4\); comparison will never be true$"):
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
r"\(unsigned\(4\)\); comparison will never be true$"):
with m.Case(0b10110):
m.d.comb += dummy.eq(0)
with self.assertWarnsRegex(SyntaxWarning,
r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value "
r"\(which has width 4\); comparison will never be true$"):
r"^Pattern '<Color.RED: 170>' \(8'10101010\) is not representable in "
r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
with m.Case(Color.RED):
m.d.comb += dummy.eq(0)
self.assertEqual(m._statements, {})
@ -521,7 +521,7 @@ class DSLTestCase(FHDLTestCase):
m = Module()
with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError,
(r"^Case pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
(r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
r"and may include whitespace$")):
with m.Case("abc"):
pass
@ -530,7 +530,7 @@ class DSLTestCase(FHDLTestCase):
m = Module()
with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError,
r"^Case pattern must be a string or a constant-castable expression, "
r"^Pattern must be a string or a constant-castable expression, "
r"not 1\.0$"):
with m.Case(1.0):
pass

View file

@ -4,6 +4,7 @@ import sys
import unittest
from amaranth import *
from amaranth.hdl import *
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView
from .utils import *