hdl.{_ast,_dsl}: factor out the pattern normalization logic.

This commit is contained in:
Wanda 2024-04-02 22:20:13 +02:00 committed by Catherine
parent 0e4c2de725
commit cd6cbd71ca
7 changed files with 69 additions and 97 deletions

View file

@ -1,9 +1,10 @@
from ._ast import SyntaxError, SyntaxWarning
from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike
from ._ast import Value, ValueCastable, ValueLike from ._ast import Value, ValueCastable, ValueLike
from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal
from ._ast import Format, Print, Assert, Assume, Cover from ._ast import Format, Print, Assert, Assume, Cover
from ._ast import IOValue, IOPort from ._ast import IOValue, IOPort
from ._dsl import SyntaxError, SyntaxWarning, Module from ._dsl import Module
from ._cd import DomainError, ClockDomain from ._cd import DomainError, ClockDomain
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
from ._ir import Instance, IOBufferInstance from ._ir import Instance, IOBufferInstance
@ -14,13 +15,14 @@ from ._xfrm import DomainRenamer, ResetInserter, EnableInserter
__all__ = [ __all__ = [
# _ast # _ast
"SyntaxError", "SyntaxWarning",
"Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike", "Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike",
"Value", "ValueCastable", "ValueLike", "Value", "ValueCastable", "ValueLike",
"Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal", "Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal",
"Format", "Print", "Assert", "Assume", "Cover", "Format", "Print", "Assert", "Assume", "Cover",
"IOValue", "IOPort", "IOValue", "IOPort",
# _dsl # _dsl
"SyntaxError", "SyntaxWarning", "Module", "Module",
# _cd # _cd
"DomainError", "ClockDomain", "DomainError", "ClockDomain",
# _ir # _ir

View file

@ -16,6 +16,7 @@ from .._unused import *
__all__ = [ __all__ = [
"SyntaxError", "SyntaxWarning",
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike", "Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
"Array", "ArrayProxy", "Array", "ArrayProxy",
@ -30,6 +31,14 @@ __all__ = [
] ]
class SyntaxError(Exception):
pass
class SyntaxWarning(Warning):
pass
class DUID: class DUID:
"""Deterministic Unique IDentifier.""" """Deterministic Unique IDentifier."""
__next_uid = 0 __next_uid = 0
@ -426,6 +435,37 @@ class ShapeLike(metaclass=_ShapeLikeMeta):
raise TypeError("ShapeLike is an abstract class and cannot be instantiated") raise TypeError("ShapeLike is an abstract class and cannot be instantiated")
def _normalize_patterns(patterns, shape, *, src_loc_at=1):
new_patterns = []
for pattern in patterns:
orig_pattern = pattern
if isinstance(pattern, str):
if any(bit not in "01- \t" for bit in pattern):
raise SyntaxError(f"Pattern '{pattern}' must consist of 0, 1, and - (don't "
f"care) bits, and may include whitespace")
pattern = "".join(pattern.split()) # remove whitespace
if len(pattern) != shape.width:
raise SyntaxError(f"Pattern '{orig_pattern}' must have the same width as "
f"match value (which is {shape.width})")
else:
try:
pattern = Const.cast(pattern)
except TypeError as e:
raise SyntaxError(f"Pattern must be a string or a constant-castable "
f"expression, not {pattern!r}") from e
cast_pattern = Const(pattern.value, shape)
if cast_pattern.value != pattern.value:
warnings.warn(f"Pattern '{orig_pattern!r}' "
f"({pattern.shape().width}'{pattern.value:b}) is not "
f"representable in match value shape "
f"({shape!r}); comparison will never be true",
SyntaxWarning, stacklevel=2 + src_loc_at)
continue
pattern = pattern.value
new_patterns.append(pattern)
return tuple(new_patterns)
def _overridable_by_reflected(method_name): def _overridable_by_reflected(method_name):
"""Allow overriding the decorated method. """Allow overriding the decorated method.
@ -1248,36 +1288,12 @@ class Value(metaclass=ABCMeta):
If a pattern has invalid syntax. If a pattern has invalid syntax.
""" """
matches = [] matches = []
# This code should accept exactly the same patterns as `with m.Case(...):`. for pattern in _normalize_patterns(patterns, self.shape()):
for pattern in patterns:
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace"
.format(pattern))
if (isinstance(pattern, str) and
len("".join(pattern.split())) != len(self)):
raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})"
.format(pattern, len(self)))
if isinstance(pattern, str): if isinstance(pattern, str):
pattern = "".join(pattern.split()) # remove whitespace mask = int("0" + pattern.replace("0", "1").replace("-", "0"), 2)
mask = int(pattern.replace("0", "1").replace("-", "0"), 2) pattern = int("0" + pattern.replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern) matches.append((self & mask) == pattern)
else: else:
try:
orig_pattern, pattern = pattern, Const.cast(pattern)
except TypeError as e:
raise SyntaxError("Match pattern must be a string or a constant-castable "
"expression, not {!r}"
.format(pattern)) from e
pattern_len = bits_for(pattern.value)
if pattern_len > len(self):
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
"(which has width {}); comparison will never be true"
.format(orig_pattern, pattern_len, pattern.value, len(self)),
SyntaxWarning, stacklevel=2)
continue
matches.append(self == pattern) matches.append(self == pattern)
if not matches: if not matches:
return Const(0) return Const(0)
@ -2770,17 +2786,9 @@ class Switch(Statement):
# Map: 2 -> "0010"; "0010" -> "0010" # Map: 2 -> "0010"; "0010" -> "0010"
new_keys = () new_keys = ()
key_mask = (1 << len(self.test)) - 1 key_mask = (1 << len(self.test)) - 1
for key in keys: for key in _normalize_patterns(keys, self._test.shape()):
if isinstance(key, str): if isinstance(key, int):
key = "".join(key.split()) # remove whitespace
elif isinstance(key, int):
key = to_binary(key & key_mask, len(self.test)) key = to_binary(key & key_mask, len(self.test))
elif isinstance(key, Enum):
key = to_binary(key.value & key_mask, len(self.test))
else:
raise TypeError("Object {!r} cannot be used as a switch key"
.format(key))
assert len(key) == len(self.test)
new_keys = (*new_keys, key) new_keys = (*new_keys, key)
if not isinstance(stmts, Iterable): if not isinstance(stmts, Iterable):
stmts = [stmts] stmts = [stmts]

View file

@ -9,7 +9,7 @@ from .._utils import flatten
from ..utils import bits_for from ..utils import bits_for
from .. import tracer from .. import tracer
from ._ast import * from ._ast import *
from ._ast import _StatementList, _LateBoundStatement, Property, Print from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns
from ._ir import * from ._ir import *
from ._cd import * from ._cd import *
from ._xfrm import * from ._xfrm import *
@ -18,14 +18,6 @@ from ._xfrm import *
__all__ = ["SyntaxError", "SyntaxWarning", "Module"] __all__ = ["SyntaxError", "SyntaxWarning", "Module"]
class SyntaxError(Exception):
pass
class SyntaxWarning(Warning):
pass
class _ModuleBuilderProxy: class _ModuleBuilderProxy:
def __init__(self, builder, depth): def __init__(self, builder, depth):
object.__setattr__(self, "_builder", builder) object.__setattr__(self, "_builder", builder)
@ -344,41 +336,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._check_context("Case", context="Switch") self._check_context("Case", context="Switch")
src_loc = tracer.get_src_loc(src_loc_at=1) src_loc = tracer.get_src_loc(src_loc_at=1)
switch_data = self._get_ctrl("Switch") switch_data = self._get_ctrl("Switch")
new_patterns = ()
if () in switch_data["cases"]: if () in switch_data["cases"]:
warnings.warn("A case defined after the default case will never be active", warnings.warn("A case defined after the default case will never be active",
SyntaxWarning, stacklevel=3) SyntaxWarning, stacklevel=3)
# This code should accept exactly the same patterns as `v.matches(...)`. new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
for pattern in patterns:
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace"
.format(pattern))
if (isinstance(pattern, str) and
len("".join(pattern.split())) != len(switch_data["test"])):
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
"(which is {})"
.format(pattern, len(switch_data["test"])))
if isinstance(pattern, str):
new_patterns = (*new_patterns, pattern)
else:
try:
orig_pattern, pattern = pattern, Const.cast(pattern)
except TypeError as e:
raise SyntaxError("Case pattern must be a string or a constant-castable "
"expression, not {!r}"
.format(pattern)) from e
pattern_len = bits_for(pattern.value)
if pattern.value == 0:
pattern_len = 0
if pattern_len > len(switch_data["test"]):
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
"(which has width {}); comparison will never be true"
.format(orig_pattern, pattern_len, pattern.value,
len(switch_data["test"])),
SyntaxWarning, stacklevel=3)
continue
new_patterns = (*new_patterns, pattern.value)
try: try:
_outer_case, self._statements = self._statements, {} _outer_case, self._statements = self._statements, {}
self._ctrl_context = None self._ctrl_context = None

View file

@ -2,7 +2,7 @@ import enum as py_enum
import warnings import warnings
import operator import operator
from ..hdl._ast import Value, ValueCastable, Shape, ShapeCastable, Const from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning
from ..hdl._repr import * from ..hdl._repr import *

View file

@ -795,7 +795,7 @@ class OperatorTestCase(FHDLTestCase):
def test_matches_enum(self): def test_matches_enum(self):
s = Signal(SignedEnum) s = Signal(SignedEnum)
self.assertRepr(s.matches(SignedEnum.FOO), """ self.assertRepr(s.matches(SignedEnum.FOO), """
(== (sig s) (const 2'sd-1)) (== (sig s) (const 1'sd-1))
""") """)
def test_matches_const_castable(self): def test_matches_const_castable(self):
@ -807,28 +807,28 @@ class OperatorTestCase(FHDLTestCase):
def test_matches_width_wrong(self): def test_matches_width_wrong(self):
s = Signal(4) s = Signal(4)
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Match pattern '--' must have the same width as match value \(which is 4\)$"): r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
s.matches("--") s.matches("--")
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Match pattern '22' \(5'10110\) is wider than match value \(which has " r"^Pattern '22' \(5'10110\) is not representable in match value shape "
r"width 4\); comparison will never be true$"): r"\(unsigned\(4\)\); comparison will never be true$"):
s.matches(0b10110) s.matches(0b10110)
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider " r"^Pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is not "
r"than match value \(which has width 4\); comparison will never be true$"): r"representable in match value shape \(unsigned\(4\)\); comparison will never be true$"):
s.matches(Cat(0, C(0b1011, 4))) s.matches(Cat(0, C(0b1011, 4)))
def test_matches_bits_wrong(self): def test_matches_bits_wrong(self):
s = Signal(4) s = Signal(4)
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
r"and may include whitespace$"): r"and may include whitespace$"):
s.matches("abc") s.matches("abc")
def test_matches_pattern_wrong(self): def test_matches_pattern_wrong(self):
s = Signal(4) s = Signal(4)
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"): r"^Pattern must be a string or a constant-castable expression, not 1\.0$"):
s.matches(1.0) s.matches(1.0)
def test_hash(self): def test_hash(self):
@ -1695,7 +1695,7 @@ class SwitchTestCase(FHDLTestCase):
self.assertEqual(s.cases, {("00001010",): []}) self.assertEqual(s.cases, {("00001010",): []})
def test_int_neg_case(self): def test_int_neg_case(self):
s = Switch(Const(0, 8), {-10: []}) s = Switch(Const(0, signed(8)), {-10: []})
self.assertEqual(s.cases, {("11110110",): []}) self.assertEqual(s.cases, {("11110110",): []})
def test_int_zero_width(self): def test_int_zero_width(self):

View file

@ -487,17 +487,17 @@ class DSLTestCase(FHDLTestCase):
dummy = Signal() dummy = Signal()
with m.Switch(self.w1): with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"): r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
with m.Case("--"): with m.Case("--"):
m.d.comb += dummy.eq(0) m.d.comb += dummy.eq(0)
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has " r"^Pattern '22' \(5'10110\) is not representable in match value shape "
r"width 4\); comparison will never be true$"): r"\(unsigned\(4\)\); comparison will never be true$"):
with m.Case(0b10110): with m.Case(0b10110):
m.d.comb += dummy.eq(0) m.d.comb += dummy.eq(0)
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value " r"^Pattern '<Color.RED: 170>' \(8'10101010\) is not representable in "
r"\(which has width 4\); comparison will never be true$"): r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
with m.Case(Color.RED): with m.Case(Color.RED):
m.d.comb += dummy.eq(0) m.d.comb += dummy.eq(0)
self.assertEqual(m._statements, {}) self.assertEqual(m._statements, {})
@ -521,7 +521,7 @@ class DSLTestCase(FHDLTestCase):
m = Module() m = Module()
with m.Switch(self.w1): with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
(r"^Case pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " (r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
r"and may include whitespace$")): r"and may include whitespace$")):
with m.Case("abc"): with m.Case("abc"):
pass pass
@ -530,7 +530,7 @@ class DSLTestCase(FHDLTestCase):
m = Module() m = Module()
with m.Switch(self.w1): with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Case pattern must be a string or a constant-castable expression, " r"^Pattern must be a string or a constant-castable expression, "
r"not 1\.0$"): r"not 1\.0$"):
with m.Case(1.0): with m.Case(1.0):
pass pass

View file

@ -4,6 +4,7 @@ import sys
import unittest import unittest
from amaranth import * from amaranth import *
from amaranth.hdl import *
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView
from .utils import * from .utils import *