hdl: implement constant-castable expressions.
See #755 and amaranth-lang/rfcs#4.
This commit is contained in:
parent
bef2052c1e
commit
58721ee4fe
|
@ -93,13 +93,21 @@ class Shape:
|
|||
bits_for(obj.stop - obj.step, signed))
|
||||
return Shape(width, signed)
|
||||
elif isinstance(obj, type) and issubclass(obj, Enum):
|
||||
min_value = min(member.value for member in obj)
|
||||
max_value = max(member.value for member in obj)
|
||||
if not isinstance(min_value, int) or not isinstance(max_value, int):
|
||||
raise TypeError("Only enumerations with integer values can be used "
|
||||
"as value shapes")
|
||||
signed = min_value < 0 or max_value < 0
|
||||
width = max(bits_for(min_value, signed), bits_for(max_value, signed))
|
||||
signed = False
|
||||
width = 0
|
||||
for member in obj:
|
||||
try:
|
||||
member_shape = Const.cast(member.value).shape()
|
||||
except TypeError as e:
|
||||
raise TypeError("Only enumerations whose members have constant-castable "
|
||||
"values can be used in Amaranth code")
|
||||
if not signed and member_shape.signed:
|
||||
signed = True
|
||||
width = max(width + 1, member_shape.width)
|
||||
elif signed and not member_shape.signed:
|
||||
width = max(width, member_shape.width + 1)
|
||||
else:
|
||||
width = max(width, member_shape.width)
|
||||
return Shape(width, signed)
|
||||
elif isinstance(obj, ShapeCastable):
|
||||
new_obj = obj.as_shape()
|
||||
|
@ -402,11 +410,8 @@ class Value(metaclass=ABCMeta):
|
|||
``1`` if any pattern matches the value, ``0`` otherwise.
|
||||
"""
|
||||
matches = []
|
||||
# This code should accept exactly the same patterns as `with m.Case(...):`.
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, (int, str, Enum)):
|
||||
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
|
||||
"not {!r}"
|
||||
.format(pattern))
|
||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
||||
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
|
||||
"bits, and may include whitespace"
|
||||
|
@ -416,23 +421,26 @@ class Value(metaclass=ABCMeta):
|
|||
raise SyntaxError("Match pattern '{}' must have the same width as match value "
|
||||
"(which is {})"
|
||||
.format(pattern, len(self)))
|
||||
if isinstance(pattern, int) and bits_for(pattern) > len(self):
|
||||
warnings.warn("Match pattern '{:b}' is wider than match value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(pattern, len(self)),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
if isinstance(pattern, str):
|
||||
pattern = "".join(pattern.split()) # remove whitespace
|
||||
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
|
||||
pattern = int(pattern.replace("-", "0"), 2)
|
||||
matches.append((self & mask) == pattern)
|
||||
elif isinstance(pattern, int):
|
||||
matches.append(self == pattern)
|
||||
elif isinstance(pattern, Enum):
|
||||
matches.append(self == pattern.value)
|
||||
else:
|
||||
assert False
|
||||
try:
|
||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
||||
except TypeError as e:
|
||||
raise SyntaxError("Match pattern must be a string or a constant-castable "
|
||||
"expression, not {!r}"
|
||||
.format(pattern)) from e
|
||||
pattern_len = bits_for(pattern.value)
|
||||
if pattern_len > len(self):
|
||||
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(orig_pattern, pattern_len, pattern.value, len(self)),
|
||||
SyntaxWarning, stacklevel=2)
|
||||
continue
|
||||
matches.append(self == pattern)
|
||||
if not matches:
|
||||
return Const(0)
|
||||
elif len(matches) == 1:
|
||||
|
@ -560,9 +568,6 @@ class Value(metaclass=ABCMeta):
|
|||
def _rhs_signals(self):
|
||||
pass # :nocov:
|
||||
|
||||
def _as_const(self):
|
||||
raise TypeError("Value {!r} cannot be evaluated as constant".format(self))
|
||||
|
||||
__hash__ = None
|
||||
|
||||
|
||||
|
@ -595,6 +600,28 @@ class Const(Value):
|
|||
value |= ~mask
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def cast(obj):
|
||||
"""Converts ``obj`` to an Amaranth constant.
|
||||
|
||||
First, ``obj`` is converted to a value using :meth:`Value.cast`. If it is a constant, it
|
||||
is returned. If it is a constant-castable expression, it is evaluated and returned.
|
||||
Otherwise, :exn:`TypeError` is raised.
|
||||
"""
|
||||
obj = Value.cast(obj)
|
||||
if type(obj) is Const:
|
||||
return obj
|
||||
elif type(obj) is Cat:
|
||||
value = 0
|
||||
width = 0
|
||||
for part in obj.parts:
|
||||
const = Const.cast(part)
|
||||
value |= const.value << width
|
||||
width += len(const)
|
||||
return Const(value, width)
|
||||
else:
|
||||
raise TypeError("Value {!r} cannot be converted to an Amaranth constant".format(obj))
|
||||
|
||||
def __init__(self, value, shape=None, *, src_loc_at=0):
|
||||
# We deliberately do not call Value.__init__ here.
|
||||
self.value = int(value)
|
||||
|
@ -617,9 +644,6 @@ class Const(Value):
|
|||
def _rhs_signals(self):
|
||||
return SignalSet()
|
||||
|
||||
def _as_const(self):
|
||||
return self.value
|
||||
|
||||
def __repr__(self):
|
||||
return "(const {}'{}d{})".format(self.width, "s" if self.signed else "", self.value)
|
||||
|
||||
|
@ -858,13 +882,6 @@ class Cat(Value):
|
|||
def _rhs_signals(self):
|
||||
return union((part._rhs_signals() for part in self.parts), start=SignalSet())
|
||||
|
||||
def _as_const(self):
|
||||
value = 0
|
||||
for part in reversed(self.parts):
|
||||
value <<= len(part)
|
||||
value |= part._as_const()
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
return "(cat {})".format(" ".join(map(repr, self.parts)))
|
||||
|
||||
|
|
|
@ -305,11 +305,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
src_loc = tracer.get_src_loc(src_loc_at=1)
|
||||
switch_data = self._get_ctrl("Switch")
|
||||
new_patterns = ()
|
||||
# This code should accept exactly the same patterns as `v.matches(...)`.
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, (int, str, Enum)):
|
||||
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
|
||||
"not {!r}"
|
||||
.format(pattern))
|
||||
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
|
||||
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
|
||||
"bits, and may include whitespace"
|
||||
|
@ -319,20 +316,24 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
|||
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
|
||||
"(which is {})"
|
||||
.format(pattern, len(switch_data["test"])))
|
||||
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
|
||||
warnings.warn("Case pattern '{:b}' is wider than switch value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(pattern, len(switch_data["test"])),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
if isinstance(pattern, Enum) and bits_for(pattern.value) > len(switch_data["test"]):
|
||||
warnings.warn("Case pattern '{:b}' ({}.{}) is wider than switch value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(pattern.value, pattern.__class__.__name__, pattern.name,
|
||||
len(switch_data["test"])),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
new_patterns = (*new_patterns, pattern)
|
||||
if isinstance(pattern, str):
|
||||
new_patterns = (*new_patterns, pattern)
|
||||
else:
|
||||
try:
|
||||
orig_pattern, pattern = pattern, Const.cast(pattern)
|
||||
except TypeError as e:
|
||||
raise SyntaxError("Case pattern must be a string or a constant-castable "
|
||||
"expression, not {!r}"
|
||||
.format(pattern)) from e
|
||||
pattern_len = bits_for(pattern.value)
|
||||
if pattern_len > len(switch_data["test"]):
|
||||
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
|
||||
"(which has width {}); comparison will never be true"
|
||||
.format(orig_pattern, pattern_len, pattern.value,
|
||||
len(switch_data["test"])),
|
||||
SyntaxWarning, stacklevel=3)
|
||||
continue
|
||||
new_patterns = (*new_patterns, pattern.value)
|
||||
try:
|
||||
_outer_case, self._statements = self._statements, []
|
||||
self._ctrl_context = None
|
||||
|
|
129
docs/lang.rst
129
docs/lang.rst
|
@ -34,6 +34,43 @@ All of the examples below assume that a glob import is used.
|
|||
from amaranth import *
|
||||
|
||||
|
||||
.. _lang-shapes:
|
||||
|
||||
Shapes
|
||||
======
|
||||
|
||||
A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> Shape(width=5, signed=False)
|
||||
unsigned(5)
|
||||
>>> Shape(width=12, signed=True)
|
||||
signed(12)
|
||||
|
||||
However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> unsigned(5) == Shape(width=5, signed=False)
|
||||
True
|
||||
>>> signed(12) == Shape(width=12, signed=True)
|
||||
True
|
||||
|
||||
|
||||
Shapes of values
|
||||
----------------
|
||||
|
||||
All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> Const(5).shape()
|
||||
unsigned(3)
|
||||
>>> len(Const(5))
|
||||
3
|
||||
|
||||
|
||||
.. _lang-values:
|
||||
|
||||
Values
|
||||
|
@ -79,43 +116,6 @@ The shape of the constant can be specified explicitly, in which case the number'
|
|||
0
|
||||
|
||||
|
||||
.. _lang-shapes:
|
||||
|
||||
Shapes
|
||||
======
|
||||
|
||||
A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> Shape(width=5, signed=False)
|
||||
unsigned(5)
|
||||
>>> Shape(width=12, signed=True)
|
||||
signed(12)
|
||||
|
||||
However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> unsigned(5) == Shape(width=5, signed=False)
|
||||
True
|
||||
>>> signed(12) == Shape(width=12, signed=True)
|
||||
True
|
||||
|
||||
|
||||
Shapes of values
|
||||
----------------
|
||||
|
||||
All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> Const(5).shape()
|
||||
unsigned(3)
|
||||
>>> len(Const(5))
|
||||
3
|
||||
|
||||
|
||||
.. _lang-shapecasting:
|
||||
|
||||
Shape casting
|
||||
|
@ -218,7 +218,7 @@ Specifying a shape with an enumeration is convenient for finite state machines,
|
|||
Value casting
|
||||
=============
|
||||
|
||||
Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly.
|
||||
Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting to values allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly.
|
||||
|
||||
.. TODO: link to ValueCastable
|
||||
|
||||
|
@ -228,7 +228,7 @@ Casting to a value can be done explicitly with ``Value.cast``, but is usually im
|
|||
Values from integers
|
||||
--------------------
|
||||
|
||||
Casting a value from an integer ``i`` is a shorthand for ``Const(i)``:
|
||||
Casting a value from an integer ``i`` is equivalent to ``Const(i)``:
|
||||
|
||||
.. doctest::
|
||||
|
||||
|
@ -242,7 +242,7 @@ Casting a value from an integer ``i`` is a shorthand for ``Const(i)``:
|
|||
Values from enumeration members
|
||||
-------------------------------
|
||||
|
||||
Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.value, type(m))``:
|
||||
Casting a value from an enumeration member ``m`` is equivalent to ``Const(m.value, type(m))``:
|
||||
|
||||
.. doctest::
|
||||
|
||||
|
@ -254,6 +254,55 @@ Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.va
|
|||
|
||||
If a value subclasses :class:`enum.IntEnum` or its class otherwise inherits from both :class:`int` and :class:`Enum`, it is treated as an enumeration.
|
||||
|
||||
|
||||
.. _lang-constcasting:
|
||||
|
||||
Constant casting
|
||||
================
|
||||
|
||||
A subset of :ref:`values <lang-values>` are *constant-castable*. If a value is constant-castable and all of its operands are also constant-castable, it can be converted to a :class:`Const`, the numeric value of which can then be read by Python code. This provides a way to perform computation on Amaranth values while constructing the design.
|
||||
|
||||
.. TODO: link to m.Case and v.matches() below
|
||||
|
||||
Constant-castable objects are accepted anywhere a constant integer is accepted. Casting to a constant can also be done explicitly with :meth:`Const.cast`:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> Const.cast(Cat(Direction.TOP, Direction.LEFT))
|
||||
(const 4'd4)
|
||||
|
||||
.. TODO: uncomment when this actually works
|
||||
|
||||
.. comment::
|
||||
|
||||
They may be used in enumeration members:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class Funct(enum.Enum):
|
||||
ADD = 0
|
||||
...
|
||||
|
||||
class Op(enum.Enum):
|
||||
REG = 0
|
||||
IMM = 1
|
||||
|
||||
class Instr(enum.Enum):
|
||||
ADD = Cat(Funct.ADD, Op.REG)
|
||||
ADDI = Cat(Funct.ADD, Op.IMM)
|
||||
...
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
At the moment, only the following expressions are constant-castable:
|
||||
|
||||
* :class:`Const`
|
||||
* :class:`Cat`
|
||||
|
||||
This list will be expanded in the future.
|
||||
|
||||
|
||||
.. _lang-signals:
|
||||
|
||||
Signals
|
||||
|
|
|
@ -133,7 +133,8 @@ class ShapeTestCase(FHDLTestCase):
|
|||
|
||||
def test_cast_enum_bad(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
r"^Only enumerations with integer values can be used as value shapes$"):
|
||||
r"^Only enumerations whose members have constant-castable values can be used "
|
||||
r"in Amaranth code$"):
|
||||
Shape.cast(StringEnum)
|
||||
|
||||
def test_cast_bad(self):
|
||||
|
@ -203,7 +204,8 @@ class ValueTestCase(FHDLTestCase):
|
|||
|
||||
def test_cast_enum_wrong(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
r"^Only enumerations with integer values can be used as value shapes$"):
|
||||
r"^Only enumerations whose members have constant-castable values can be used "
|
||||
r"in Amaranth code$"):
|
||||
Value.cast(StringEnum.FOO)
|
||||
|
||||
def test_bool(self):
|
||||
|
@ -614,7 +616,13 @@ class OperatorTestCase(FHDLTestCase):
|
|||
def test_matches_enum(self):
|
||||
s = Signal(SignedEnum)
|
||||
self.assertRepr(s.matches(SignedEnum.FOO), """
|
||||
(== (sig s) (const 1'sd-1))
|
||||
(== (sig s) (const 2'sd-1))
|
||||
""")
|
||||
|
||||
def test_matches_const_castable(self):
|
||||
s = Signal(4)
|
||||
self.assertRepr(s.matches(Cat(C(0b10, 2), C(0b11, 2))), """
|
||||
(== (sig s) (const 4'd14))
|
||||
""")
|
||||
|
||||
def test_matches_width_wrong(self):
|
||||
|
@ -623,21 +631,25 @@ class OperatorTestCase(FHDLTestCase):
|
|||
r"^Match pattern '--' must have the same width as match value \(which is 4\)$"):
|
||||
s.matches("--")
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
(r"^Match pattern '10110' is wider than match value \(which has width 4\); "
|
||||
r"comparison will never be true$")):
|
||||
r"^Match pattern '22' \(5'10110\) is wider than match value \(which has "
|
||||
r"width 4\); comparison will never be true$"):
|
||||
s.matches(0b10110)
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider "
|
||||
r"than match value \(which has width 4\); comparison will never be true$"):
|
||||
s.matches(Cat(0, C(0b1011, 4)))
|
||||
|
||||
def test_matches_bits_wrong(self):
|
||||
s = Signal(4)
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
(r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||
r"and may include whitespace$")):
|
||||
r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
|
||||
r"and may include whitespace$"):
|
||||
s.matches("abc")
|
||||
|
||||
def test_matches_pattern_wrong(self):
|
||||
s = Signal(4)
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
r"^Match pattern must be an integer, a string, or an enumeration, not 1\.0$"):
|
||||
r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"):
|
||||
s.matches(1.0)
|
||||
|
||||
def test_hash(self):
|
||||
|
|
|
@ -446,6 +446,23 @@ class DSLTestCase(FHDLTestCase):
|
|||
)
|
||||
""")
|
||||
|
||||
def test_Switch_const_castable(self):
|
||||
class Color(Enum):
|
||||
RED = 0
|
||||
BLUE = 1
|
||||
m = Module()
|
||||
se = Signal(2)
|
||||
with m.Switch(se):
|
||||
with m.Case(Cat(Color.RED, Color.BLUE)):
|
||||
m.d.comb += self.c1.eq(1)
|
||||
self.assertRepr(m._statements, """
|
||||
(
|
||||
(switch (sig se)
|
||||
(case 10 (eq (sig c1) (const 1'd1)))
|
||||
)
|
||||
)
|
||||
""")
|
||||
|
||||
def test_Case_width_wrong(self):
|
||||
class Color(Enum):
|
||||
RED = 0b10101010
|
||||
|
@ -456,13 +473,13 @@ class DSLTestCase(FHDLTestCase):
|
|||
with m.Case("--"):
|
||||
pass
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
(r"^Case pattern '10110' is wider than switch value \(which has width 4\); "
|
||||
r"comparison will never be true$")):
|
||||
r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has "
|
||||
r"width 4\); comparison will never be true$"):
|
||||
with m.Case(0b10110):
|
||||
pass
|
||||
with self.assertWarnsRegex(SyntaxWarning,
|
||||
(r"^Case pattern '10101010' \(Color\.RED\) is wider than switch value "
|
||||
r"\(which has width 4\); comparison will never be true$")):
|
||||
r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value "
|
||||
r"\(which has width 4\); comparison will never be true$"):
|
||||
with m.Case(Color.RED):
|
||||
pass
|
||||
self.assertRepr(m._statements, """
|
||||
|
@ -484,7 +501,8 @@ class DSLTestCase(FHDLTestCase):
|
|||
m = Module()
|
||||
with m.Switch(self.w1):
|
||||
with self.assertRaisesRegex(SyntaxError,
|
||||
r"^Case pattern must be an integer, a string, or an enumeration, not 1\.0$"):
|
||||
r"^Case pattern must be a string or a constant-castable expression, "
|
||||
r"not 1\.0$"):
|
||||
with m.Case(1.0):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in a new issue