hdl: implement constant-castable expressions.
See #755 and amaranth-lang/rfcs#4.
This commit is contained in:
parent
bef2052c1e
commit
58721ee4fe
5 changed files with 203 additions and 106 deletions
|
|
@ -93,13 +93,21 @@ class Shape:
|
|||
bits_for(obj.stop - obj.step, signed))
|
||||
return Shape(width, signed)
|
||||
elif isinstance(obj, type) and issubclass(obj, Enum):
|
||||
min_value = min(member.value for member in obj)
|
||||
max_value = max(member.value for member in obj)
|
||||
if not isinstance(min_value, int) or not isinstance(max_value, int):
|
||||
raise TypeError("Only enumerations with integer values can be used "
|
||||
"as value shapes")
|
||||
signed = min_value < 0 or max_value < 0
|
||||
width = max(bits_for(min_value, signed), bits_for(max_value, signed))
|
||||
signed = False
|
||||
width = 0
|
||||
for member in obj:
|
||||
try:
|
||||
member_shape = Const.cast(member.value).shape()
|
||||
except TypeError as e:
|
||||
raise TypeError("Only enumerations whose members have constant-castable "
|
||||
"values can be used in Amaranth code")
|
||||
if not signed and member_shape.signed:
|
||||
signed = True
|
||||
width = max(width + 1, member_shape.width)
|
||||
elif signed and not member_shape.signed:
|
||||
width = max(width, member_shape.width + 1)
|
||||
else:
|
||||
width = max(width, member_shape.width)
|
||||
return Shape(width, signed)
|
||||
elif isinstance(obj, ShapeCastable):
|
||||
new_obj = obj.as_shape()
|
||||
|
|
@ -402,11 +410,8 @@ class Value(metaclass=ABCMeta):
|
|||
``1`` if any pattern matches the value, ``0`` otherwise.
|
||||
"""
|
||||
matches = []
|
||||
# This code should accept exactly the same patterns as `with m.Case(...):`.
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, (int, str, Enum)):
|
||||
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
|
||||
"not {!r}"
|
||||
.format(pattern))
|
||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
||||
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
|
||||
"bits, and may include whitespace"
|
||||
|
|
@ -416,23 +421,26 @@ class Value(metaclass=ABCMeta):
|
|||
raise SyntaxError("Match pattern '{}' must have the same width as match value "
|
||||
"(which is {})"
|
||||
.format(pattern, len(self)))
|
||||
if isinstance(pattern, int) and bits_for(pattern) > len(self):
|
||||
warnings.warn("Match pattern '{:b}' is wider than match value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(pattern, len(self)),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
if isinstance(pattern, str):
|
||||
pattern = "".join(pattern.split()) # remove whitespace
|
||||
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
|
||||
pattern = int(pattern.replace("-", "0"), 2)
|
||||
matches.append((self & mask) == pattern)
|
||||
elif isinstance(pattern, int):
|
||||
matches.append(self == pattern)
|
||||
elif isinstance(pattern, Enum):
|
||||
matches.append(self == pattern.value)
|
||||
else:
|
||||
assert False
|
||||
try:
|
||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
||||
except TypeError as e:
|
||||
raise SyntaxError("Match pattern must be a string or a constant-castable "
|
||||
"expression, not {!r}"
|
||||
.format(pattern)) from e
|
||||
pattern_len = bits_for(pattern.value)
|
||||
if pattern_len > len(self):
|
||||
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(orig_pattern, pattern_len, pattern.value, len(self)),
|
||||
SyntaxWarning, stacklevel=2)
|
||||
continue
|
||||
matches.append(self == pattern)
|
||||
if not matches:
|
||||
return Const(0)
|
||||
elif len(matches) == 1:
|
||||
|
|
@ -560,9 +568,6 @@ class Value(metaclass=ABCMeta):
|
|||
def _rhs_signals(self):
|
||||
pass # :nocov:
|
||||
|
||||
def _as_const(self):
|
||||
raise TypeError("Value {!r} cannot be evaluated as constant".format(self))
|
||||
|
||||
__hash__ = None
|
||||
|
||||
|
||||
|
|
@ -595,6 +600,28 @@ class Const(Value):
|
|||
value |= ~mask
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def cast(obj):
|
||||
"""Converts ``obj`` to an Amaranth constant.
|
||||
|
||||
First, ``obj`` is converted to a value using :meth:`Value.cast`. If it is a constant, it
|
||||
is returned. If it is a constant-castable expression, it is evaluated and returned.
|
||||
Otherwise, :exn:`TypeError` is raised.
|
||||
"""
|
||||
obj = Value.cast(obj)
|
||||
if type(obj) is Const:
|
||||
return obj
|
||||
elif type(obj) is Cat:
|
||||
value = 0
|
||||
width = 0
|
||||
for part in obj.parts:
|
||||
const = Const.cast(part)
|
||||
value |= const.value << width
|
||||
width += len(const)
|
||||
return Const(value, width)
|
||||
else:
|
||||
raise TypeError("Value {!r} cannot be converted to an Amaranth constant".format(obj))
|
||||
|
||||
def __init__(self, value, shape=None, *, src_loc_at=0):
|
||||
# We deliberately do not call Value.__init__ here.
|
||||
self.value = int(value)
|
||||
|
|
@ -617,9 +644,6 @@ class Const(Value):
|
|||
def _rhs_signals(self):
|
||||
return SignalSet()
|
||||
|
||||
def _as_const(self):
|
||||
return self.value
|
||||
|
||||
def __repr__(self):
|
||||
return "(const {}'{}d{})".format(self.width, "s" if self.signed else "", self.value)
|
||||
|
||||
|
|
@ -858,13 +882,6 @@ class Cat(Value):
|
|||
def _rhs_signals(self):
|
||||
return union((part._rhs_signals() for part in self.parts), start=SignalSet())
|
||||
|
||||
def _as_const(self):
|
||||
value = 0
|
||||
for part in reversed(self.parts):
|
||||
value <<= len(part)
|
||||
value |= part._as_const()
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
return "(cat {})".format(" ".join(map(repr, self.parts)))
|
||||
|
||||
|
|
|
|||
|
|
@ -305,11 +305,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
src_loc = tracer.get_src_loc(src_loc_at=1)
|
||||
switch_data = self._get_ctrl("Switch")
|
||||
new_patterns = ()
|
||||
# This code should accept exactly the same patterns as `v.matches(...)`.
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, (int, str, Enum)):
|
||||
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
|
||||
"not {!r}"
|
||||
.format(pattern))
|
||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
||||
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
|
||||
"bits, and may include whitespace"
|
||||
|
|
@ -319,20 +316,24 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
|
||||
"(which is {})"
|
||||
.format(pattern, len(switch_data["test"])))
|
||||
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
|
||||
warnings.warn("Case pattern '{:b}' is wider than switch value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(pattern, len(switch_data["test"])),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
if isinstance(pattern, Enum) and bits_for(pattern.value) > len(switch_data["test"]):
|
||||
warnings.warn("Case pattern '{:b}' ({}.{}) is wider than switch value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(pattern.value, pattern.__class__.__name__, pattern.name,
|
||||
len(switch_data["test"])),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
new_patterns = (*new_patterns, pattern)
|
||||
if isinstance(pattern, str):
|
||||
new_patterns = (*new_patterns, pattern)
|
||||
else:
|
||||
try:
|
||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
||||
except TypeError as e:
|
||||
raise SyntaxError("Case pattern must be a string or a constant-castable "
|
||||
"expression, not {!r}"
|
||||
.format(pattern)) from e
|
||||
pattern_len = bits_for(pattern.value)
|
||||
if pattern_len > len(switch_data["test"]):
|
||||
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(orig_pattern, pattern_len, pattern.value,
|
||||
len(switch_data["test"])),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
new_patterns = (*new_patterns, pattern.value)
|
||||
try:
|
||||
_outer_case, self._statements = self._statements, []
|
||||
self._ctrl_context = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue