hdl.ast: add Value.matches(), accepting same language as Case().

Fixes #202.
This commit is contained in:
whitequark 2019-09-14 21:06:12 +00:00
parent f292a1977c
commit e8f79c5539
2 changed files with 89 additions and 4 deletions

View file

@ -191,9 +191,9 @@ class Value(metaclass=ABCMeta):
Parameters
----------
offset : Value, in
index of first selected bit
Index of first selected bit.
width : int
number of selected bits
Number of selected bits.
Returns
-------
@ -211,9 +211,9 @@ class Value(metaclass=ABCMeta):
Parameters
----------
offset : Value, in
index of first selected word
Index of first selected word.
width : int
number of selected bits
Number of selected bits.
Returns
-------
@ -222,6 +222,56 @@ class Value(metaclass=ABCMeta):
"""
return Part(self, offset, width, stride=width, src_loc_at=1)
def matches(self, *patterns):
"""Pattern matching.
Matches against a set of patterns, which may be integers or bit strings, recognizing
the same grammar as ``Case()``.
Parameters
----------
patterns : int or str
Patterns to match against.
Returns
-------
Value, out
``1`` if any pattern matches the value, ``0`` otherwise.
"""
matches = []
for pattern in patterns:
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits"
.format(pattern))
if isinstance(pattern, str) and len(pattern) != len(self):
raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})"
.format(pattern, len(self)))
if not isinstance(pattern, (int, str)):
raise SyntaxError("Match pattern must be an integer or a string, not {}"
.format(pattern))
if isinstance(pattern, int) and bits_for(pattern) > len(self):
warnings.warn("Match pattern '{:b}' is wider than match value "
"(which has width {}); comparison will never be true"
.format(pattern, len(self)),
SyntaxWarning, stacklevel=3)
continue
if isinstance(pattern, int):
matches.append(self == pattern)
elif isinstance(pattern, str):
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern)
else:
assert False
if not matches:
return Const(0)
elif len(matches) == 1:
return matches[0]
else:
return Cat(*matches).any()
def eq(self, value):
"""Assignment.

View file

@ -263,6 +263,41 @@ class OperatorTestCase(FHDLTestCase):
v = Const(0b101).xor()
self.assertEqual(repr(v), "(r^ (const 3'd5))")
def test_matches(self):
s = Signal(4)
self.assertRepr(s.matches(), "(const 1'd0)")
self.assertRepr(s.matches(1), """
(== (sig s) (const 1'd1))
""")
self.assertRepr(s.matches(0, 1), """
(r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
""")
self.assertRepr(s.matches("10--"), """
(== (& (sig s) (const 4'd12)) (const 4'd8))
""")
def test_matches_width_wrong(self):
s = Signal(4)
with self.assertRaises(SyntaxError,
msg="Match pattern '--' must have the same width as match value (which is 4)"):
s.matches("--")
with self.assertWarns(SyntaxWarning,
msg="Match pattern '10110' is wider than match value (which has width 4); "
"comparison will never be true"):
s.matches(0b10110)
def test_matches_bits_wrong(self):
s = Signal(4)
with self.assertRaises(SyntaxError,
msg="Match pattern 'abc' must consist of 0, 1, and - (don't care) bits"):
s.matches("abc")
def test_matches_pattern_wrong(self):
s = Signal(4)
with self.assertRaises(SyntaxError,
msg="Match pattern must be an integer or a string, not 1.0"):
s.matches(1.0)
def test_hash(self):
with self.assertRaises(TypeError):
hash(Const(0) + Const(0))