hdl.{_ast,_dsl}: factor out the pattern normalization logic.
This commit is contained in:
parent
0e4c2de725
commit
cd6cbd71ca
|
@ -1,9 +1,10 @@
|
|||
from ._ast import SyntaxError, SyntaxWarning
|
||||
from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike
|
||||
from ._ast import Value, ValueCastable, ValueLike
|
||||
from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal
|
||||
from ._ast import Format, Print, Assert, Assume, Cover
|
||||
from ._ast import IOValue, IOPort
|
||||
from ._dsl import SyntaxError, SyntaxWarning, Module
|
||||
from ._dsl import Module
|
||||
from ._cd import DomainError, ClockDomain
|
||||
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
|
||||
from ._ir import Instance, IOBufferInstance
|
||||
|
@ -14,13 +15,14 @@ from ._xfrm import DomainRenamer, ResetInserter, EnableInserter
|
|||
|
||||
__all__ = [
|
||||
# _ast
|
||||
"SyntaxError", "SyntaxWarning",
|
||||
"Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike",
|
||||
"Value", "ValueCastable", "ValueLike",
|
||||
"Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal",
|
||||
"Format", "Print", "Assert", "Assume", "Cover",
|
||||
"IOValue", "IOPort",
|
||||
# _dsl
|
||||
"SyntaxError", "SyntaxWarning", "Module",
|
||||
"Module",
|
||||
# _cd
|
||||
"DomainError", "ClockDomain",
|
||||
# _ir
|
||||
|
|
|
@ -16,6 +16,7 @@ from .._unused import *
|
|||
|
||||
|
||||
__all__ = [
|
||||
"SyntaxError", "SyntaxWarning",
|
||||
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
|
||||
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
|
||||
"Array", "ArrayProxy",
|
||||
|
@ -30,6 +31,14 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
class SyntaxError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SyntaxWarning(Warning):
|
||||
pass
|
||||
|
||||
|
||||
class DUID:
|
||||
"""Deterministic Unique IDentifier."""
|
||||
__next_uid = 0
|
||||
|
@ -426,6 +435,37 @@ class ShapeLike(metaclass=_ShapeLikeMeta):
|
|||
raise TypeError("ShapeLike is an abstract class and cannot be instantiated")
|
||||
|
||||
|
||||
def _normalize_patterns(patterns, shape, *, src_loc_at=1):
|
||||
new_patterns = []
|
||||
for pattern in patterns:
|
||||
orig_pattern = pattern
|
||||
if isinstance(pattern, str):
|
||||
if any(bit not in "01- \t" for bit in pattern):
|
||||
raise SyntaxError(f"Pattern '{pattern}' must consist of 0, 1, and - (don't "
|
||||
f"care) bits, and may include whitespace")
|
||||
pattern = "".join(pattern.split()) # remove whitespace
|
||||
if len(pattern) != shape.width:
|
||||
raise SyntaxError(f"Pattern '{orig_pattern}' must have the same width as "
|
||||
f"match value (which is {shape.width})")
|
||||
else:
|
||||
try:
|
||||
pattern = Const.cast(pattern)
|
||||
except TypeError as e:
|
||||
raise SyntaxError(f"Pattern must be a string or a constant-castable "
|
||||
f"expression, not {pattern!r}") from e
|
||||
cast_pattern = Const(pattern.value, shape)
|
||||
if cast_pattern.value != pattern.value:
|
||||
warnings.warn(f"Pattern '{orig_pattern!r}' "
|
||||
f"({pattern.shape().width}'{pattern.value:b}) is not "
|
||||
f"representable in match value shape "
|
||||
f"({shape!r}); comparison will never be true",
|
||||
SyntaxWarning, stacklevel=2 + src_loc_at)
|
||||
continue
|
||||
pattern = pattern.value
|
||||
new_patterns.append(pattern)
|
||||
return tuple(new_patterns)
|
||||
|
||||
|
||||
def _overridable_by_reflected(method_name):
|
||||
"""Allow overriding the decorated method.
|
||||
|
||||
|
@ -1248,36 +1288,12 @@ class Value(metaclass=ABCMeta):
|
|||
If a pattern has invalid syntax.
|
||||
"""
|
||||
matches = []
|
||||
# This code should accept exactly the same patterns as `with m.Case(...):`.
|
||||
for pattern in patterns:
|
||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
||||
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
|
||||
"bits, and may include whitespace"
|
||||
.format(pattern))
|
||||
if (isinstance(pattern, str) and
|
||||
len("".join(pattern.split())) != len(self)):
|
||||
raise SyntaxError("Match pattern '{}' must have the same width as match value "
|
||||
"(which is {})"
|
||||
.format(pattern, len(self)))
|
||||
for pattern in _normalize_patterns(patterns, self.shape()):
|
||||
if isinstance(pattern, str):
|
||||
pattern = "".join(pattern.split()) # remove whitespace
|
||||
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
|
||||
pattern = int(pattern.replace("-", "0"), 2)
|
||||
mask = int("0" + pattern.replace("0", "1").replace("-", "0"), 2)
|
||||
pattern = int("0" + pattern.replace("-", "0"), 2)
|
||||
matches.append((self & mask) == pattern)
|
||||
else:
|
||||
try:
|
||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
||||
except TypeError as e:
|
||||
raise SyntaxError("Match pattern must be a string or a constant-castable "
|
||||
"expression, not {!r}"
|
||||
.format(pattern)) from e
|
||||
pattern_len = bits_for(pattern.value)
|
||||
if pattern_len > len(self):
|
||||
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(orig_pattern, pattern_len, pattern.value, len(self)),
|
||||
SyntaxWarning, stacklevel=2)
|
||||
continue
|
||||
matches.append(self == pattern)
|
||||
if not matches:
|
||||
return Const(0)
|
||||
|
@ -2770,17 +2786,9 @@ class Switch(Statement):
|
|||
# Map: 2 -> "0010"; "0010" -> "0010"
|
||||
new_keys = ()
|
||||
key_mask = (1 << len(self.test)) - 1
|
||||
for key in keys:
|
||||
if isinstance(key, str):
|
||||
key = "".join(key.split()) # remove whitespace
|
||||
elif isinstance(key, int):
|
||||
for key in _normalize_patterns(keys, self._test.shape()):
|
||||
if isinstance(key, int):
|
||||
key = to_binary(key & key_mask, len(self.test))
|
||||
elif isinstance(key, Enum):
|
||||
key = to_binary(key.value & key_mask, len(self.test))
|
||||
else:
|
||||
raise TypeError("Object {!r} cannot be used as a switch key"
|
||||
.format(key))
|
||||
assert len(key) == len(self.test)
|
||||
new_keys = (*new_keys, key)
|
||||
if not isinstance(stmts, Iterable):
|
||||
stmts = [stmts]
|
||||
|
|
|
@ -9,7 +9,7 @@ from .._utils import flatten
|
|||
from ..utils import bits_for
|
||||
from .. import tracer
|
||||
from ._ast import *
|
||||
from ._ast import _StatementList, _LateBoundStatement, Property, Print
|
||||
from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns
|
||||
from ._ir import *
|
||||
from ._cd import *
|
||||
from ._xfrm import *
|
||||
|
@ -18,14 +18,6 @@ from ._xfrm import *
|
|||
__all__ = ["SyntaxError", "SyntaxWarning", "Module"]
|
||||
|
||||
|
||||
class SyntaxError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SyntaxWarning(Warning):
|
||||
pass
|
||||
|
||||
|
||||
class _ModuleBuilderProxy:
|
||||
def __init__(self, builder, depth):
|
||||
object.__setattr__(self, "_builder", builder)
|
||||
|
@ -344,41 +336,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
self._check_context("Case", context="Switch")
|
||||
src_loc = tracer.get_src_loc(src_loc_at=1)
|
||||
switch_data = self._get_ctrl("Switch")
|
||||
new_patterns = ()
|
||||
if () in switch_data["cases"]:
|
||||
warnings.warn("A case defined after the default case will never be active",
|
||||
SyntaxWarning, stacklevel=3)
|
||||
# This code should accept exactly the same patterns as `v.matches(...)`.
|
||||
for pattern in patterns:
|
||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
||||
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
|
||||
"bits, and may include whitespace"
|
||||
.format(pattern))
|
||||
if (isinstance(pattern, str) and
|
||||
len("".join(pattern.split())) != len(switch_data["test"])):
|
||||
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
|
||||
"(which is {})"
|
||||
.format(pattern, len(switch_data["test"])))
|
||||
if isinstance(pattern, str):
|
||||
new_patterns = (*new_patterns, pattern)
|
||||
else:
|
||||
try:
|
||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
||||
except TypeError as e:
|
||||
raise SyntaxError("Case pattern must be a string or a constant-castable "
|
||||
"expression, not {!r}"
|
||||
.format(pattern)) from e
|
||||
pattern_len = bits_for(pattern.value)
|
||||
if pattern.value == 0:
|
||||
pattern_len = 0
|
||||
if pattern_len > len(switch_data["test"]):
|
||||
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(orig_pattern, pattern_len, pattern.value,
|
||||
len(switch_data["test"])),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
new_patterns = (*new_patterns, pattern.value)
|
||||
new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
|
||||
try:
|
||||
_outer_case, self._statements = self._statements, {}
|
||||
self._ctrl_context = None
|
||||
|
|
|
@ -2,7 +2,7 @@ import enum as py_enum
|
|||
import warnings
|
||||
import operator
|
||||
|
||||
from ..hdl._ast import Value, ValueCastable, Shape, ShapeCastable, Const
|
||||
from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning
|
||||
from ..hdl._repr import *
|
||||
|
||||
|
||||
|
|
|
@ -795,7 +795,7 @@ class OperatorTestCase(FHDLTestCase):
|
|||
def test_matches_enum(self):
|
||||
s = Signal(SignedEnum)
|
||||
self.assertRepr(s.matches(SignedEnum.FOO), """
|
||||
(== (sig s) (const 2'sd-1))
|
||||
(== (sig s) (const 1'sd-1))
|
||||
""")
|
||||
|
||||
def test_matches_const_castable(self):
|
||||
|
@ -807,28 +807,28 @@ class OperatorTestCase(FHDLTestCase):
|
|||
def test_matches_width_wrong(self):
|
||||
s = Signal(4)
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
r"^Match pattern '--' must have the same width as match value \(which is 4\)$"):
|
||||
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
|
||||
s.matches("--")
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
r"^Match pattern '22' \(5'10110\) is wider than match value \(which has "
|
||||
r"width 4\); comparison will never be true$"):
|
||||
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
|
||||
r"\(unsigned\(4\)\); comparison will never be true$"):
|
||||
s.matches(0b10110)
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider "
|
||||
r"than match value \(which has width 4\); comparison will never be true$"):
|
||||
r"^Pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is not "
|
||||
r"representable in match value shape \(unsigned\(4\)\); comparison will never be true$"):
|
||||
s.matches(Cat(0, C(0b1011, 4)))
|
||||
|
||||
def test_matches_bits_wrong(self):
|
||||
s = Signal(4)
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||
r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||
r"and may include whitespace$"):
|
||||
s.matches("abc")
|
||||
|
||||
def test_matches_pattern_wrong(self):
|
||||
s = Signal(4)
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"):
|
||||
r"^Pattern must be a string or a constant-castable expression, not 1\.0$"):
|
||||
s.matches(1.0)
|
||||
|
||||
def test_hash(self):
|
||||
|
@ -1695,7 +1695,7 @@ class SwitchTestCase(FHDLTestCase):
|
|||
self.assertEqual(s.cases, {("00001010",): []})
|
||||
|
||||
def test_int_neg_case(self):
|
||||
s = Switch(Const(0, 8), {-10: []})
|
||||
s = Switch(Const(0, signed(8)), {-10: []})
|
||||
self.assertEqual(s.cases, {("11110110",): []})
|
||||
|
||||
def test_int_zero_width(self):
|
||||
|
|
|
@ -487,17 +487,17 @@ class DSLTestCase(FHDLTestCase):
|
|||
dummy = Signal()
|
||||
with m.Switch(self.w1):
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"):
|
||||
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
|
||||
with m.Case("--"):
|
||||
m.d.comb += dummy.eq(0)
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has "
|
||||
r"width 4\); comparison will never be true$"):
|
||||
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
|
||||
r"\(unsigned\(4\)\); comparison will never be true$"):
|
||||
with m.Case(0b10110):
|
||||
m.d.comb += dummy.eq(0)
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value "
|
||||
r"\(which has width 4\); comparison will never be true$"):
|
||||
r"^Pattern '<Color.RED: 170>' \(8'10101010\) is not representable in "
|
||||
r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
|
||||
with m.Case(Color.RED):
|
||||
m.d.comb += dummy.eq(0)
|
||||
self.assertEqual(m._statements, {})
|
||||
|
@ -521,7 +521,7 @@ class DSLTestCase(FHDLTestCase):
|
|||
m = Module()
|
||||
with m.Switch(self.w1):
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
(r"^Case pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||
(r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||
r"and may include whitespace$")):
|
||||
with m.Case("abc"):
|
||||
pass
|
||||
|
@ -530,7 +530,7 @@ class DSLTestCase(FHDLTestCase):
|
|||
m = Module()
|
||||
with m.Switch(self.w1):
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
r"^Case pattern must be a string or a constant-castable expression, "
|
||||
r"^Pattern must be a string or a constant-castable expression, "
|
||||
r"not 1\.0$"):
|
||||
with m.Case(1.0):
|
||||
pass
|
||||
|
|
|
@ -4,6 +4,7 @@ import sys
|
|||
import unittest
|
||||
|
||||
from amaranth import *
|
||||
from amaranth.hdl import *
|
||||
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView
|
||||
|
||||
from .utils import *
|
||||
|
|
Loading…
Reference in a new issue