hdl: implement constant-castable expressions.

See #755 and amaranth-lang/rfcs#4.
This commit is contained in:
Catherine 2023-02-21 10:07:55 +00:00
parent bef2052c1e
commit 58721ee4fe
5 changed files with 203 additions and 106 deletions

View file

@ -93,13 +93,21 @@ class Shape:
bits_for(obj.stop - obj.step, signed))
return Shape(width, signed)
elif isinstance(obj, type) and issubclass(obj, Enum):
min_value = min(member.value for member in obj)
max_value = max(member.value for member in obj)
if not isinstance(min_value, int) or not isinstance(max_value, int):
raise TypeError("Only enumerations with integer values can be used "
"as value shapes")
signed = min_value < 0 or max_value < 0
width = max(bits_for(min_value, signed), bits_for(max_value, signed))
signed = False
width = 0
for member in obj:
try:
member_shape = Const.cast(member.value).shape()
except TypeError as e:
raise TypeError("Only enumerations whose members have constant-castable "
"values can be used in Amaranth code")
if not signed and member_shape.signed:
signed = True
width = max(width + 1, member_shape.width)
elif signed and not member_shape.signed:
width = max(width, member_shape.width + 1)
else:
width = max(width, member_shape.width)
return Shape(width, signed)
elif isinstance(obj, ShapeCastable):
new_obj = obj.as_shape()
@ -402,11 +410,8 @@ class Value(metaclass=ABCMeta):
``1`` if any pattern matches the value, ``0`` otherwise.
"""
matches = []
# This code should accept exactly the same patterns as `with m.Case(...):`.
for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace"
@ -416,23 +421,26 @@ class Value(metaclass=ABCMeta):
raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})"
.format(pattern, len(self)))
if isinstance(pattern, int) and bits_for(pattern) > len(self):
warnings.warn("Match pattern '{:b}' is wider than match value "
"(which has width {}); comparison will never be true"
.format(pattern, len(self)),
SyntaxWarning, stacklevel=3)
continue
if isinstance(pattern, str):
pattern = "".join(pattern.split()) # remove whitespace
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern)
elif isinstance(pattern, int):
matches.append(self == pattern)
elif isinstance(pattern, Enum):
matches.append(self == pattern.value)
else:
assert False
try:
orig_pattern, pattern = pattern, Const.cast(pattern)
except TypeError as e:
raise SyntaxError("Match pattern must be a string or a constant-castable "
"expression, not {!r}"
.format(pattern)) from e
pattern_len = bits_for(pattern.value)
if pattern_len > len(self):
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
"(which has width {}); comparison will never be true"
.format(orig_pattern, pattern_len, pattern.value, len(self)),
SyntaxWarning, stacklevel=2)
continue
matches.append(self == pattern)
if not matches:
return Const(0)
elif len(matches) == 1:
@ -560,9 +568,6 @@ class Value(metaclass=ABCMeta):
def _rhs_signals(self):
pass # :nocov:
def _as_const(self):
raise TypeError("Value {!r} cannot be evaluated as constant".format(self))
__hash__ = None
@ -595,6 +600,28 @@ class Const(Value):
value |= ~mask
return value
@staticmethod
def cast(obj):
"""Converts ``obj`` to an Amaranth constant.
First, ``obj`` is converted to a value using :meth:`Value.cast`. If it is a constant, it
is returned. If it is a constant-castable expression, it is evaluated and returned.
Otherwise, :exn:`TypeError` is raised.
"""
obj = Value.cast(obj)
if type(obj) is Const:
return obj
elif type(obj) is Cat:
value = 0
width = 0
for part in obj.parts:
const = Const.cast(part)
value |= const.value << width
width += len(const)
return Const(value, width)
else:
raise TypeError("Value {!r} cannot be converted to an Amaranth constant".format(obj))
def __init__(self, value, shape=None, *, src_loc_at=0):
# We deliberately do not call Value.__init__ here.
self.value = int(value)
@ -617,9 +644,6 @@ class Const(Value):
def _rhs_signals(self):
return SignalSet()
def _as_const(self):
return self.value
def __repr__(self):
return "(const {}'{}d{})".format(self.width, "s" if self.signed else "", self.value)
@ -858,13 +882,6 @@ class Cat(Value):
def _rhs_signals(self):
return union((part._rhs_signals() for part in self.parts), start=SignalSet())
def _as_const(self):
value = 0
for part in reversed(self.parts):
value <<= len(part)
value |= part._as_const()
return value
def __repr__(self):
return "(cat {})".format(" ".join(map(repr, self.parts)))

View file

@ -305,11 +305,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
src_loc = tracer.get_src_loc(src_loc_at=1)
switch_data = self._get_ctrl("Switch")
new_patterns = ()
# This code should accept exactly the same patterns as `v.matches(...)`.
for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace"
@ -319,20 +316,24 @@ class Module(_ModuleBuilderRoot, Elaboratable):
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
"(which is {})"
.format(pattern, len(switch_data["test"])))
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
warnings.warn("Case pattern '{:b}' is wider than switch value "
"(which has width {}); comparison will never be true"
.format(pattern, len(switch_data["test"])),
SyntaxWarning, stacklevel=3)
continue
if isinstance(pattern, Enum) and bits_for(pattern.value) > len(switch_data["test"]):
warnings.warn("Case pattern '{:b}' ({}.{}) is wider than switch value "
"(which has width {}); comparison will never be true"
.format(pattern.value, pattern.__class__.__name__, pattern.name,
len(switch_data["test"])),
SyntaxWarning, stacklevel=3)
continue
new_patterns = (*new_patterns, pattern)
if isinstance(pattern, str):
new_patterns = (*new_patterns, pattern)
else:
try:
orig_pattern, pattern = pattern, Const.cast(pattern)
except TypeError as e:
raise SyntaxError("Case pattern must be a string or a constant-castable "
"expression, not {!r}"
.format(pattern)) from e
pattern_len = bits_for(pattern.value)
if pattern_len > len(switch_data["test"]):
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
"(which has width {}); comparison will never be true"
.format(orig_pattern, pattern_len, pattern.value,
len(switch_data["test"])),
SyntaxWarning, stacklevel=3)
continue
new_patterns = (*new_patterns, pattern.value)
try:
_outer_case, self._statements = self._statements, []
self._ctrl_context = None