hdl: implement constant-castable expressions.

See #755 and amaranth-lang/rfcs#4.
This commit is contained in:
Catherine 2023-02-21 10:07:55 +00:00
parent bef2052c1e
commit 58721ee4fe
5 changed files with 203 additions and 106 deletions

View file

@ -93,13 +93,21 @@ class Shape:
bits_for(obj.stop - obj.step, signed)) bits_for(obj.stop - obj.step, signed))
return Shape(width, signed) return Shape(width, signed)
elif isinstance(obj, type) and issubclass(obj, Enum): elif isinstance(obj, type) and issubclass(obj, Enum):
min_value = min(member.value for member in obj) signed = False
max_value = max(member.value for member in obj) width = 0
if not isinstance(min_value, int) or not isinstance(max_value, int): for member in obj:
raise TypeError("Only enumerations with integer values can be used " try:
"as value shapes") member_shape = Const.cast(member.value).shape()
signed = min_value < 0 or max_value < 0 except TypeError as e:
width = max(bits_for(min_value, signed), bits_for(max_value, signed)) raise TypeError("Only enumerations whose members have constant-castable "
"values can be used in Amaranth code")
if not signed and member_shape.signed:
signed = True
width = max(width + 1, member_shape.width)
elif signed and not member_shape.signed:
width = max(width, member_shape.width + 1)
else:
width = max(width, member_shape.width)
return Shape(width, signed) return Shape(width, signed)
elif isinstance(obj, ShapeCastable): elif isinstance(obj, ShapeCastable):
new_obj = obj.as_shape() new_obj = obj.as_shape()
@ -402,11 +410,8 @@ class Value(metaclass=ABCMeta):
``1`` if any pattern matches the value, ``0`` otherwise. ``1`` if any pattern matches the value, ``0`` otherwise.
""" """
matches = [] matches = []
# This code should accept exactly the same patterns as `with m.Case(...):`.
for pattern in patterns: for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) " raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace" "bits, and may include whitespace"
@ -416,23 +421,26 @@ class Value(metaclass=ABCMeta):
raise SyntaxError("Match pattern '{}' must have the same width as match value " raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})" "(which is {})"
.format(pattern, len(self))) .format(pattern, len(self)))
if isinstance(pattern, int) and bits_for(pattern) > len(self):
warnings.warn("Match pattern '{:b}' is wider than match value "
"(which has width {}); comparison will never be true"
.format(pattern, len(self)),
SyntaxWarning, stacklevel=3)
continue
if isinstance(pattern, str): if isinstance(pattern, str):
pattern = "".join(pattern.split()) # remove whitespace pattern = "".join(pattern.split()) # remove whitespace
mask = int(pattern.replace("0", "1").replace("-", "0"), 2) mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2) pattern = int(pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern) matches.append((self & mask) == pattern)
elif isinstance(pattern, int):
matches.append(self == pattern)
elif isinstance(pattern, Enum):
matches.append(self == pattern.value)
else: else:
assert False try:
orig_pattern, pattern = pattern, Const.cast(pattern)
except TypeError as e:
raise SyntaxError("Match pattern must be a string or a constant-castable "
"expression, not {!r}"
.format(pattern)) from e
pattern_len = bits_for(pattern.value)
if pattern_len > len(self):
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
"(which has width {}); comparison will never be true"
.format(orig_pattern, pattern_len, pattern.value, len(self)),
SyntaxWarning, stacklevel=2)
continue
matches.append(self == pattern)
if not matches: if not matches:
return Const(0) return Const(0)
elif len(matches) == 1: elif len(matches) == 1:
@ -560,9 +568,6 @@ class Value(metaclass=ABCMeta):
def _rhs_signals(self): def _rhs_signals(self):
pass # :nocov: pass # :nocov:
def _as_const(self):
raise TypeError("Value {!r} cannot be evaluated as constant".format(self))
__hash__ = None __hash__ = None
@ -595,6 +600,28 @@ class Const(Value):
value |= ~mask value |= ~mask
return value return value
@staticmethod
def cast(obj):
"""Converts ``obj`` to an Amaranth constant.
First, ``obj`` is converted to a value using :meth:`Value.cast`. If it is a constant, it
is returned. If it is a constant-castable expression, it is evaluated and returned.
Otherwise, :exn:`TypeError` is raised.
"""
obj = Value.cast(obj)
if type(obj) is Const:
return obj
elif type(obj) is Cat:
value = 0
width = 0
for part in obj.parts:
const = Const.cast(part)
value |= const.value << width
width += len(const)
return Const(value, width)
else:
raise TypeError("Value {!r} cannot be converted to an Amaranth constant".format(obj))
def __init__(self, value, shape=None, *, src_loc_at=0): def __init__(self, value, shape=None, *, src_loc_at=0):
# We deliberately do not call Value.__init__ here. # We deliberately do not call Value.__init__ here.
self.value = int(value) self.value = int(value)
@ -617,9 +644,6 @@ class Const(Value):
def _rhs_signals(self): def _rhs_signals(self):
return SignalSet() return SignalSet()
def _as_const(self):
return self.value
def __repr__(self): def __repr__(self):
return "(const {}'{}d{})".format(self.width, "s" if self.signed else "", self.value) return "(const {}'{}d{})".format(self.width, "s" if self.signed else "", self.value)
@ -858,13 +882,6 @@ class Cat(Value):
def _rhs_signals(self): def _rhs_signals(self):
return union((part._rhs_signals() for part in self.parts), start=SignalSet()) return union((part._rhs_signals() for part in self.parts), start=SignalSet())
def _as_const(self):
value = 0
for part in reversed(self.parts):
value <<= len(part)
value |= part._as_const()
return value
def __repr__(self): def __repr__(self):
return "(cat {})".format(" ".join(map(repr, self.parts))) return "(cat {})".format(" ".join(map(repr, self.parts)))

View file

@ -305,11 +305,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
src_loc = tracer.get_src_loc(src_loc_at=1) src_loc = tracer.get_src_loc(src_loc_at=1)
switch_data = self._get_ctrl("Switch") switch_data = self._get_ctrl("Switch")
new_patterns = () new_patterns = ()
# This code should accept exactly the same patterns as `v.matches(...)`.
for pattern in patterns: for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) " raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
"bits, and may include whitespace" "bits, and may include whitespace"
@ -319,20 +316,24 @@ class Module(_ModuleBuilderRoot, Elaboratable):
raise SyntaxError("Case pattern '{}' must have the same width as switch value " raise SyntaxError("Case pattern '{}' must have the same width as switch value "
"(which is {})" "(which is {})"
.format(pattern, len(switch_data["test"]))) .format(pattern, len(switch_data["test"])))
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]): if isinstance(pattern, str):
warnings.warn("Case pattern '{:b}' is wider than switch value " new_patterns = (*new_patterns, pattern)
"(which has width {}); comparison will never be true" else:
.format(pattern, len(switch_data["test"])), try:
SyntaxWarning, stacklevel=3) orig_pattern, pattern = pattern, Const.cast(pattern)
continue except TypeError as e:
if isinstance(pattern, Enum) and bits_for(pattern.value) > len(switch_data["test"]): raise SyntaxError("Case pattern must be a string or a constant-castable "
warnings.warn("Case pattern '{:b}' ({}.{}) is wider than switch value " "expression, not {!r}"
"(which has width {}); comparison will never be true" .format(pattern)) from e
.format(pattern.value, pattern.__class__.__name__, pattern.name, pattern_len = bits_for(pattern.value)
len(switch_data["test"])), if pattern_len > len(switch_data["test"]):
SyntaxWarning, stacklevel=3) warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
continue "(which has width {}); comparison will never be true"
new_patterns = (*new_patterns, pattern) .format(orig_pattern, pattern_len, pattern.value,
len(switch_data["test"])),
SyntaxWarning, stacklevel=3)
continue
new_patterns = (*new_patterns, pattern.value)
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, []
self._ctrl_context = None self._ctrl_context = None

View file

@ -34,6 +34,43 @@ All of the examples below assume that a glob import is used.
from amaranth import * from amaranth import *
.. _lang-shapes:
Shapes
======
A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly:
.. doctest::
>>> Shape(width=5, signed=False)
unsigned(5)
>>> Shape(width=12, signed=True)
signed(12)
However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient:
.. doctest::
>>> unsigned(5) == Shape(width=5, signed=False)
True
>>> signed(12) == Shape(width=12, signed=True)
True
Shapes of values
----------------
All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``.
.. doctest::
>>> Const(5).shape()
unsigned(3)
>>> len(Const(5))
3
.. _lang-values: .. _lang-values:
Values Values
@ -79,43 +116,6 @@ The shape of the constant can be specified explicitly, in which case the number'
0 0
.. _lang-shapes:
Shapes
======
A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly:
.. doctest::
>>> Shape(width=5, signed=False)
unsigned(5)
>>> Shape(width=12, signed=True)
signed(12)
However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient:
.. doctest::
>>> unsigned(5) == Shape(width=5, signed=False)
True
>>> signed(12) == Shape(width=12, signed=True)
True
Shapes of values
----------------
All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``.
.. doctest::
>>> Const(5).shape()
unsigned(3)
>>> len(Const(5))
3
.. _lang-shapecasting: .. _lang-shapecasting:
Shape casting Shape casting
@ -218,7 +218,7 @@ Specifying a shape with an enumeration is convenient for finite state machines,
Value casting Value casting
============= =============
Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly. Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting to values allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly.
.. TODO: link to ValueCastable .. TODO: link to ValueCastable
@ -228,7 +228,7 @@ Casting to a value can be done explicitly with ``Value.cast``, but is usually im
Values from integers Values from integers
-------------------- --------------------
Casting a value from an integer ``i`` is a shorthand for ``Const(i)``: Casting a value from an integer ``i`` is equivalent to ``Const(i)``:
.. doctest:: .. doctest::
@ -242,7 +242,7 @@ Casting a value from an integer ``i`` is a shorthand for ``Const(i)``:
Values from enumeration members Values from enumeration members
------------------------------- -------------------------------
Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.value, type(m))``: Casting a value from an enumeration member ``m`` is equivalent to ``Const(m.value, type(m))``:
.. doctest:: .. doctest::
@ -254,6 +254,55 @@ Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.va
If a value subclasses :class:`enum.IntEnum` or its class otherwise inherits from both :class:`int` and :class:`Enum`, it is treated as an enumeration. If a value subclasses :class:`enum.IntEnum` or its class otherwise inherits from both :class:`int` and :class:`Enum`, it is treated as an enumeration.
.. _lang-constcasting:
Constant casting
================
A subset of :ref:`values <lang-values>` are *constant-castable*. If a value is constant-castable and all of its operands are also constant-castable, it can be converted to a :class:`Const`, the numeric value of which can then be read by Python code. This provides a way to perform computation on Amaranth values while constructing the design.
.. TODO: link to m.Case and v.matches() below
Constant-castable objects are accepted anywhere a constant integer is accepted. Casting to a constant can also be done explicitly with :meth:`Const.cast`:
.. doctest::
>>> Const.cast(Cat(Direction.TOP, Direction.LEFT))
(const 4'd4)
.. TODO: uncomment when this actually works
.. comment::
They may be used in enumeration members:
.. testcode::
class Funct(enum.Enum):
ADD = 0
...
class Op(enum.Enum):
REG = 0
IMM = 1
class Instr(enum.Enum):
ADD = Cat(Funct.ADD, Op.REG)
ADDI = Cat(Funct.ADD, Op.IMM)
...
.. note::
At the moment, only the following expressions are constant-castable:
* :class:`Const`
* :class:`Cat`
This list will be expanded in the future.
.. _lang-signals: .. _lang-signals:
Signals Signals

View file

@ -133,7 +133,8 @@ class ShapeTestCase(FHDLTestCase):
def test_cast_enum_bad(self): def test_cast_enum_bad(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Only enumerations with integer values can be used as value shapes$"): r"^Only enumerations whose members have constant-castable values can be used "
r"in Amaranth code$"):
Shape.cast(StringEnum) Shape.cast(StringEnum)
def test_cast_bad(self): def test_cast_bad(self):
@ -203,7 +204,8 @@ class ValueTestCase(FHDLTestCase):
def test_cast_enum_wrong(self): def test_cast_enum_wrong(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Only enumerations with integer values can be used as value shapes$"): r"^Only enumerations whose members have constant-castable values can be used "
r"in Amaranth code$"):
Value.cast(StringEnum.FOO) Value.cast(StringEnum.FOO)
def test_bool(self): def test_bool(self):
@ -614,7 +616,13 @@ class OperatorTestCase(FHDLTestCase):
def test_matches_enum(self): def test_matches_enum(self):
s = Signal(SignedEnum) s = Signal(SignedEnum)
self.assertRepr(s.matches(SignedEnum.FOO), """ self.assertRepr(s.matches(SignedEnum.FOO), """
(== (sig s) (const 1'sd-1)) (== (sig s) (const 2'sd-1))
""")
def test_matches_const_castable(self):
s = Signal(4)
self.assertRepr(s.matches(Cat(C(0b10, 2), C(0b11, 2))), """
(== (sig s) (const 4'd14))
""") """)
def test_matches_width_wrong(self): def test_matches_width_wrong(self):
@ -623,21 +631,25 @@ class OperatorTestCase(FHDLTestCase):
r"^Match pattern '--' must have the same width as match value \(which is 4\)$"): r"^Match pattern '--' must have the same width as match value \(which is 4\)$"):
s.matches("--") s.matches("--")
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
(r"^Match pattern '10110' is wider than match value \(which has width 4\); " r"^Match pattern '22' \(5'10110\) is wider than match value \(which has "
r"comparison will never be true$")): r"width 4\); comparison will never be true$"):
s.matches(0b10110) s.matches(0b10110)
with self.assertWarnsRegex(SyntaxWarning,
r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider "
r"than match value \(which has width 4\); comparison will never be true$"):
s.matches(Cat(0, C(0b1011, 4)))
def test_matches_bits_wrong(self): def test_matches_bits_wrong(self):
s = Signal(4) s = Signal(4)
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
(r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
r"and may include whitespace$")): r"and may include whitespace$"):
s.matches("abc") s.matches("abc")
def test_matches_pattern_wrong(self): def test_matches_pattern_wrong(self):
s = Signal(4) s = Signal(4)
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Match pattern must be an integer, a string, or an enumeration, not 1\.0$"): r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"):
s.matches(1.0) s.matches(1.0)
def test_hash(self): def test_hash(self):

View file

@ -446,6 +446,23 @@ class DSLTestCase(FHDLTestCase):
) )
""") """)
def test_Switch_const_castable(self):
class Color(Enum):
RED = 0
BLUE = 1
m = Module()
se = Signal(2)
with m.Switch(se):
with m.Case(Cat(Color.RED, Color.BLUE)):
m.d.comb += self.c1.eq(1)
self.assertRepr(m._statements, """
(
(switch (sig se)
(case 10 (eq (sig c1) (const 1'd1)))
)
)
""")
def test_Case_width_wrong(self): def test_Case_width_wrong(self):
class Color(Enum): class Color(Enum):
RED = 0b10101010 RED = 0b10101010
@ -456,13 +473,13 @@ class DSLTestCase(FHDLTestCase):
with m.Case("--"): with m.Case("--"):
pass pass
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
(r"^Case pattern '10110' is wider than switch value \(which has width 4\); " r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has "
r"comparison will never be true$")): r"width 4\); comparison will never be true$"):
with m.Case(0b10110): with m.Case(0b10110):
pass pass
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
(r"^Case pattern '10101010' \(Color\.RED\) is wider than switch value " r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value "
r"\(which has width 4\); comparison will never be true$")): r"\(which has width 4\); comparison will never be true$"):
with m.Case(Color.RED): with m.Case(Color.RED):
pass pass
self.assertRepr(m._statements, """ self.assertRepr(m._statements, """
@ -484,7 +501,8 @@ class DSLTestCase(FHDLTestCase):
m = Module() m = Module()
with m.Switch(self.w1): with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Case pattern must be an integer, a string, or an enumeration, not 1\.0$"): r"^Case pattern must be a string or a constant-castable expression, "
r"not 1\.0$"):
with m.Case(1.0): with m.Case(1.0):
pass pass