hdl.{_ast,_dsl}: factor out the pattern normalization logic.
This commit is contained in:
parent
0e4c2de725
commit
cd6cbd71ca
|
@ -1,9 +1,10 @@
|
||||||
|
from ._ast import SyntaxError, SyntaxWarning
|
||||||
from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike
|
from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike
|
||||||
from ._ast import Value, ValueCastable, ValueLike
|
from ._ast import Value, ValueCastable, ValueLike
|
||||||
from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal
|
from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal
|
||||||
from ._ast import Format, Print, Assert, Assume, Cover
|
from ._ast import Format, Print, Assert, Assume, Cover
|
||||||
from ._ast import IOValue, IOPort
|
from ._ast import IOValue, IOPort
|
||||||
from ._dsl import SyntaxError, SyntaxWarning, Module
|
from ._dsl import Module
|
||||||
from ._cd import DomainError, ClockDomain
|
from ._cd import DomainError, ClockDomain
|
||||||
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
|
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
|
||||||
from ._ir import Instance, IOBufferInstance
|
from ._ir import Instance, IOBufferInstance
|
||||||
|
@ -14,13 +15,14 @@ from ._xfrm import DomainRenamer, ResetInserter, EnableInserter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# _ast
|
# _ast
|
||||||
|
"SyntaxError", "SyntaxWarning",
|
||||||
"Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike",
|
"Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike",
|
||||||
"Value", "ValueCastable", "ValueLike",
|
"Value", "ValueCastable", "ValueLike",
|
||||||
"Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal",
|
"Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal",
|
||||||
"Format", "Print", "Assert", "Assume", "Cover",
|
"Format", "Print", "Assert", "Assume", "Cover",
|
||||||
"IOValue", "IOPort",
|
"IOValue", "IOPort",
|
||||||
# _dsl
|
# _dsl
|
||||||
"SyntaxError", "SyntaxWarning", "Module",
|
"Module",
|
||||||
# _cd
|
# _cd
|
||||||
"DomainError", "ClockDomain",
|
"DomainError", "ClockDomain",
|
||||||
# _ir
|
# _ir
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .._unused import *
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"SyntaxError", "SyntaxWarning",
|
||||||
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
|
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
|
||||||
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
|
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
|
||||||
"Array", "ArrayProxy",
|
"Array", "ArrayProxy",
|
||||||
|
@ -30,6 +31,14 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SyntaxError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SyntaxWarning(Warning):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DUID:
|
class DUID:
|
||||||
"""Deterministic Unique IDentifier."""
|
"""Deterministic Unique IDentifier."""
|
||||||
__next_uid = 0
|
__next_uid = 0
|
||||||
|
@ -426,6 +435,37 @@ class ShapeLike(metaclass=_ShapeLikeMeta):
|
||||||
raise TypeError("ShapeLike is an abstract class and cannot be instantiated")
|
raise TypeError("ShapeLike is an abstract class and cannot be instantiated")
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_patterns(patterns, shape, *, src_loc_at=1):
|
||||||
|
new_patterns = []
|
||||||
|
for pattern in patterns:
|
||||||
|
orig_pattern = pattern
|
||||||
|
if isinstance(pattern, str):
|
||||||
|
if any(bit not in "01- \t" for bit in pattern):
|
||||||
|
raise SyntaxError(f"Pattern '{pattern}' must consist of 0, 1, and - (don't "
|
||||||
|
f"care) bits, and may include whitespace")
|
||||||
|
pattern = "".join(pattern.split()) # remove whitespace
|
||||||
|
if len(pattern) != shape.width:
|
||||||
|
raise SyntaxError(f"Pattern '{orig_pattern}' must have the same width as "
|
||||||
|
f"match value (which is {shape.width})")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
pattern = Const.cast(pattern)
|
||||||
|
except TypeError as e:
|
||||||
|
raise SyntaxError(f"Pattern must be a string or a constant-castable "
|
||||||
|
f"expression, not {pattern!r}") from e
|
||||||
|
cast_pattern = Const(pattern.value, shape)
|
||||||
|
if cast_pattern.value != pattern.value:
|
||||||
|
warnings.warn(f"Pattern '{orig_pattern!r}' "
|
||||||
|
f"({pattern.shape().width}'{pattern.value:b}) is not "
|
||||||
|
f"representable in match value shape "
|
||||||
|
f"({shape!r}); comparison will never be true",
|
||||||
|
SyntaxWarning, stacklevel=2 + src_loc_at)
|
||||||
|
continue
|
||||||
|
pattern = pattern.value
|
||||||
|
new_patterns.append(pattern)
|
||||||
|
return tuple(new_patterns)
|
||||||
|
|
||||||
|
|
||||||
def _overridable_by_reflected(method_name):
|
def _overridable_by_reflected(method_name):
|
||||||
"""Allow overriding the decorated method.
|
"""Allow overriding the decorated method.
|
||||||
|
|
||||||
|
@ -1248,36 +1288,12 @@ class Value(metaclass=ABCMeta):
|
||||||
If a pattern has invalid syntax.
|
If a pattern has invalid syntax.
|
||||||
"""
|
"""
|
||||||
matches = []
|
matches = []
|
||||||
# This code should accept exactly the same patterns as `with m.Case(...):`.
|
for pattern in _normalize_patterns(patterns, self.shape()):
|
||||||
for pattern in patterns:
|
|
||||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
|
||||||
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
|
|
||||||
"bits, and may include whitespace"
|
|
||||||
.format(pattern))
|
|
||||||
if (isinstance(pattern, str) and
|
|
||||||
len("".join(pattern.split())) != len(self)):
|
|
||||||
raise SyntaxError("Match pattern '{}' must have the same width as match value "
|
|
||||||
"(which is {})"
|
|
||||||
.format(pattern, len(self)))
|
|
||||||
if isinstance(pattern, str):
|
if isinstance(pattern, str):
|
||||||
pattern = "".join(pattern.split()) # remove whitespace
|
mask = int("0" + pattern.replace("0", "1").replace("-", "0"), 2)
|
||||||
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
|
pattern = int("0" + pattern.replace("-", "0"), 2)
|
||||||
pattern = int(pattern.replace("-", "0"), 2)
|
|
||||||
matches.append((self & mask) == pattern)
|
matches.append((self & mask) == pattern)
|
||||||
else:
|
else:
|
||||||
try:
|
|
||||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
|
||||||
except TypeError as e:
|
|
||||||
raise SyntaxError("Match pattern must be a string or a constant-castable "
|
|
||||||
"expression, not {!r}"
|
|
||||||
.format(pattern)) from e
|
|
||||||
pattern_len = bits_for(pattern.value)
|
|
||||||
if pattern_len > len(self):
|
|
||||||
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
|
|
||||||
"(which has width {}); comparison will never be true"
|
|
||||||
.format(orig_pattern, pattern_len, pattern.value, len(self)),
|
|
||||||
SyntaxWarning, stacklevel=2)
|
|
||||||
continue
|
|
||||||
matches.append(self == pattern)
|
matches.append(self == pattern)
|
||||||
if not matches:
|
if not matches:
|
||||||
return Const(0)
|
return Const(0)
|
||||||
|
@ -2770,17 +2786,9 @@ class Switch(Statement):
|
||||||
# Map: 2 -> "0010"; "0010" -> "0010"
|
# Map: 2 -> "0010"; "0010" -> "0010"
|
||||||
new_keys = ()
|
new_keys = ()
|
||||||
key_mask = (1 << len(self.test)) - 1
|
key_mask = (1 << len(self.test)) - 1
|
||||||
for key in keys:
|
for key in _normalize_patterns(keys, self._test.shape()):
|
||||||
if isinstance(key, str):
|
if isinstance(key, int):
|
||||||
key = "".join(key.split()) # remove whitespace
|
|
||||||
elif isinstance(key, int):
|
|
||||||
key = to_binary(key & key_mask, len(self.test))
|
key = to_binary(key & key_mask, len(self.test))
|
||||||
elif isinstance(key, Enum):
|
|
||||||
key = to_binary(key.value & key_mask, len(self.test))
|
|
||||||
else:
|
|
||||||
raise TypeError("Object {!r} cannot be used as a switch key"
|
|
||||||
.format(key))
|
|
||||||
assert len(key) == len(self.test)
|
|
||||||
new_keys = (*new_keys, key)
|
new_keys = (*new_keys, key)
|
||||||
if not isinstance(stmts, Iterable):
|
if not isinstance(stmts, Iterable):
|
||||||
stmts = [stmts]
|
stmts = [stmts]
|
||||||
|
|
|
@ -9,7 +9,7 @@ from .._utils import flatten
|
||||||
from ..utils import bits_for
|
from ..utils import bits_for
|
||||||
from .. import tracer
|
from .. import tracer
|
||||||
from ._ast import *
|
from ._ast import *
|
||||||
from ._ast import _StatementList, _LateBoundStatement, Property, Print
|
from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns
|
||||||
from ._ir import *
|
from ._ir import *
|
||||||
from ._cd import *
|
from ._cd import *
|
||||||
from ._xfrm import *
|
from ._xfrm import *
|
||||||
|
@ -18,14 +18,6 @@ from ._xfrm import *
|
||||||
__all__ = ["SyntaxError", "SyntaxWarning", "Module"]
|
__all__ = ["SyntaxError", "SyntaxWarning", "Module"]
|
||||||
|
|
||||||
|
|
||||||
class SyntaxError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SyntaxWarning(Warning):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _ModuleBuilderProxy:
|
class _ModuleBuilderProxy:
|
||||||
def __init__(self, builder, depth):
|
def __init__(self, builder, depth):
|
||||||
object.__setattr__(self, "_builder", builder)
|
object.__setattr__(self, "_builder", builder)
|
||||||
|
@ -344,41 +336,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
||||||
self._check_context("Case", context="Switch")
|
self._check_context("Case", context="Switch")
|
||||||
src_loc = tracer.get_src_loc(src_loc_at=1)
|
src_loc = tracer.get_src_loc(src_loc_at=1)
|
||||||
switch_data = self._get_ctrl("Switch")
|
switch_data = self._get_ctrl("Switch")
|
||||||
new_patterns = ()
|
|
||||||
if () in switch_data["cases"]:
|
if () in switch_data["cases"]:
|
||||||
warnings.warn("A case defined after the default case will never be active",
|
warnings.warn("A case defined after the default case will never be active",
|
||||||
SyntaxWarning, stacklevel=3)
|
SyntaxWarning, stacklevel=3)
|
||||||
# This code should accept exactly the same patterns as `v.matches(...)`.
|
new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
|
||||||
for pattern in patterns:
|
|
||||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
|
||||||
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
|
|
||||||
"bits, and may include whitespace"
|
|
||||||
.format(pattern))
|
|
||||||
if (isinstance(pattern, str) and
|
|
||||||
len("".join(pattern.split())) != len(switch_data["test"])):
|
|
||||||
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
|
|
||||||
"(which is {})"
|
|
||||||
.format(pattern, len(switch_data["test"])))
|
|
||||||
if isinstance(pattern, str):
|
|
||||||
new_patterns = (*new_patterns, pattern)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
|
||||||
except TypeError as e:
|
|
||||||
raise SyntaxError("Case pattern must be a string or a constant-castable "
|
|
||||||
"expression, not {!r}"
|
|
||||||
.format(pattern)) from e
|
|
||||||
pattern_len = bits_for(pattern.value)
|
|
||||||
if pattern.value == 0:
|
|
||||||
pattern_len = 0
|
|
||||||
if pattern_len > len(switch_data["test"]):
|
|
||||||
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
|
|
||||||
"(which has width {}); comparison will never be true"
|
|
||||||
.format(orig_pattern, pattern_len, pattern.value,
|
|
||||||
len(switch_data["test"])),
|
|
||||||
SyntaxWarning, stacklevel=3)
|
|
||||||
continue
|
|
||||||
new_patterns = (*new_patterns, pattern.value)
|
|
||||||
try:
|
try:
|
||||||
_outer_case, self._statements = self._statements, {}
|
_outer_case, self._statements = self._statements, {}
|
||||||
self._ctrl_context = None
|
self._ctrl_context = None
|
||||||
|
|
|
@ -2,7 +2,7 @@ import enum as py_enum
|
||||||
import warnings
|
import warnings
|
||||||
import operator
|
import operator
|
||||||
|
|
||||||
from ..hdl._ast import Value, ValueCastable, Shape, ShapeCastable, Const
|
from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning
|
||||||
from ..hdl._repr import *
|
from ..hdl._repr import *
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -795,7 +795,7 @@ class OperatorTestCase(FHDLTestCase):
|
||||||
def test_matches_enum(self):
|
def test_matches_enum(self):
|
||||||
s = Signal(SignedEnum)
|
s = Signal(SignedEnum)
|
||||||
self.assertRepr(s.matches(SignedEnum.FOO), """
|
self.assertRepr(s.matches(SignedEnum.FOO), """
|
||||||
(== (sig s) (const 2'sd-1))
|
(== (sig s) (const 1'sd-1))
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_matches_const_castable(self):
|
def test_matches_const_castable(self):
|
||||||
|
@ -807,28 +807,28 @@ class OperatorTestCase(FHDLTestCase):
|
||||||
def test_matches_width_wrong(self):
|
def test_matches_width_wrong(self):
|
||||||
s = Signal(4)
|
s = Signal(4)
|
||||||
with self.assertRaisesRegex(SyntaxError,
|
with self.assertRaisesRegex(SyntaxError,
|
||||||
r"^Match pattern '--' must have the same width as match value \(which is 4\)$"):
|
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
|
||||||
s.matches("--")
|
s.matches("--")
|
||||||
with self.assertWarnsRegex(SyntaxWarning,
|
with self.assertWarnsRegex(SyntaxWarning,
|
||||||
r"^Match pattern '22' \(5'10110\) is wider than match value \(which has "
|
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
|
||||||
r"width 4\); comparison will never be true$"):
|
r"\(unsigned\(4\)\); comparison will never be true$"):
|
||||||
s.matches(0b10110)
|
s.matches(0b10110)
|
||||||
with self.assertWarnsRegex(SyntaxWarning,
|
with self.assertWarnsRegex(SyntaxWarning,
|
||||||
r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider "
|
r"^Pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is not "
|
||||||
r"than match value \(which has width 4\); comparison will never be true$"):
|
r"representable in match value shape \(unsigned\(4\)\); comparison will never be true$"):
|
||||||
s.matches(Cat(0, C(0b1011, 4)))
|
s.matches(Cat(0, C(0b1011, 4)))
|
||||||
|
|
||||||
def test_matches_bits_wrong(self):
|
def test_matches_bits_wrong(self):
|
||||||
s = Signal(4)
|
s = Signal(4)
|
||||||
with self.assertRaisesRegex(SyntaxError,
|
with self.assertRaisesRegex(SyntaxError,
|
||||||
r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||||
r"and may include whitespace$"):
|
r"and may include whitespace$"):
|
||||||
s.matches("abc")
|
s.matches("abc")
|
||||||
|
|
||||||
def test_matches_pattern_wrong(self):
|
def test_matches_pattern_wrong(self):
|
||||||
s = Signal(4)
|
s = Signal(4)
|
||||||
with self.assertRaisesRegex(SyntaxError,
|
with self.assertRaisesRegex(SyntaxError,
|
||||||
r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"):
|
r"^Pattern must be a string or a constant-castable expression, not 1\.0$"):
|
||||||
s.matches(1.0)
|
s.matches(1.0)
|
||||||
|
|
||||||
def test_hash(self):
|
def test_hash(self):
|
||||||
|
@ -1695,7 +1695,7 @@ class SwitchTestCase(FHDLTestCase):
|
||||||
self.assertEqual(s.cases, {("00001010",): []})
|
self.assertEqual(s.cases, {("00001010",): []})
|
||||||
|
|
||||||
def test_int_neg_case(self):
|
def test_int_neg_case(self):
|
||||||
s = Switch(Const(0, 8), {-10: []})
|
s = Switch(Const(0, signed(8)), {-10: []})
|
||||||
self.assertEqual(s.cases, {("11110110",): []})
|
self.assertEqual(s.cases, {("11110110",): []})
|
||||||
|
|
||||||
def test_int_zero_width(self):
|
def test_int_zero_width(self):
|
||||||
|
|
|
@ -487,17 +487,17 @@ class DSLTestCase(FHDLTestCase):
|
||||||
dummy = Signal()
|
dummy = Signal()
|
||||||
with m.Switch(self.w1):
|
with m.Switch(self.w1):
|
||||||
with self.assertRaisesRegex(SyntaxError,
|
with self.assertRaisesRegex(SyntaxError,
|
||||||
r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"):
|
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
|
||||||
with m.Case("--"):
|
with m.Case("--"):
|
||||||
m.d.comb += dummy.eq(0)
|
m.d.comb += dummy.eq(0)
|
||||||
with self.assertWarnsRegex(SyntaxWarning,
|
with self.assertWarnsRegex(SyntaxWarning,
|
||||||
r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has "
|
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
|
||||||
r"width 4\); comparison will never be true$"):
|
r"\(unsigned\(4\)\); comparison will never be true$"):
|
||||||
with m.Case(0b10110):
|
with m.Case(0b10110):
|
||||||
m.d.comb += dummy.eq(0)
|
m.d.comb += dummy.eq(0)
|
||||||
with self.assertWarnsRegex(SyntaxWarning,
|
with self.assertWarnsRegex(SyntaxWarning,
|
||||||
r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value "
|
r"^Pattern '<Color.RED: 170>' \(8'10101010\) is not representable in "
|
||||||
r"\(which has width 4\); comparison will never be true$"):
|
r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
|
||||||
with m.Case(Color.RED):
|
with m.Case(Color.RED):
|
||||||
m.d.comb += dummy.eq(0)
|
m.d.comb += dummy.eq(0)
|
||||||
self.assertEqual(m._statements, {})
|
self.assertEqual(m._statements, {})
|
||||||
|
@ -521,7 +521,7 @@ class DSLTestCase(FHDLTestCase):
|
||||||
m = Module()
|
m = Module()
|
||||||
with m.Switch(self.w1):
|
with m.Switch(self.w1):
|
||||||
with self.assertRaisesRegex(SyntaxError,
|
with self.assertRaisesRegex(SyntaxError,
|
||||||
(r"^Case pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
(r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||||
r"and may include whitespace$")):
|
r"and may include whitespace$")):
|
||||||
with m.Case("abc"):
|
with m.Case("abc"):
|
||||||
pass
|
pass
|
||||||
|
@ -530,7 +530,7 @@ class DSLTestCase(FHDLTestCase):
|
||||||
m = Module()
|
m = Module()
|
||||||
with m.Switch(self.w1):
|
with m.Switch(self.w1):
|
||||||
with self.assertRaisesRegex(SyntaxError,
|
with self.assertRaisesRegex(SyntaxError,
|
||||||
r"^Case pattern must be a string or a constant-castable expression, "
|
r"^Pattern must be a string or a constant-castable expression, "
|
||||||
r"not 1\.0$"):
|
r"not 1\.0$"):
|
||||||
with m.Case(1.0):
|
with m.Case(1.0):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -4,6 +4,7 @@ import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from amaranth import *
|
from amaranth import *
|
||||||
|
from amaranth.hdl import *
|
||||||
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView
|
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView
|
||||||
|
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
|
Loading…
Reference in a new issue