lib.enum: add Enum wrappers that allow specifying shape.

See #756 and amaranth-lang/rfcs#3.
This commit is contained in:
Catherine 2023-02-20 22:58:38 +00:00
parent ef2e9fa809
commit 57612f1dce
10 changed files with 343 additions and 40 deletions

View file

@ -78,11 +78,34 @@ class Shape:
self.width = width self.width = width
self.signed = signed self.signed = signed
# The algorithm for inferring shape for standard Python enumerations is factored out so that
# `Shape.cast()` and Amaranth's `EnumMeta.as_shape()` can both use it.
@staticmethod
def _cast_plain_enum(obj):
signed = False
width = 0
for member in obj:
try:
member_shape = Const.cast(member.value).shape()
except TypeError as e:
raise TypeError("Only enumerations whose members have constant-castable "
"values can be used in Amaranth code")
if not signed and member_shape.signed:
signed = True
width = max(width + 1, member_shape.width)
elif signed and not member_shape.signed:
width = max(width, member_shape.width + 1)
else:
width = max(width, member_shape.width)
return Shape(width, signed)
@staticmethod @staticmethod
def cast(obj, *, src_loc_at=0): def cast(obj, *, src_loc_at=0):
while True: while True:
if isinstance(obj, Shape): if isinstance(obj, Shape):
return obj return obj
elif isinstance(obj, ShapeCastable):
new_obj = obj.as_shape()
elif isinstance(obj, int): elif isinstance(obj, int):
return Shape(obj) return Shape(obj)
elif isinstance(obj, range): elif isinstance(obj, range):
@ -93,24 +116,9 @@ class Shape:
bits_for(obj.stop - obj.step, signed)) bits_for(obj.stop - obj.step, signed))
return Shape(width, signed) return Shape(width, signed)
elif isinstance(obj, type) and issubclass(obj, Enum): elif isinstance(obj, type) and issubclass(obj, Enum):
signed = False # For compatibility with third party enumerations, handle them as if they were
width = 0 # defined as subclasses of lib.enum.Enum with no explicitly specified shape.
for member in obj: return Shape._cast_plain_enum(obj)
try:
member_shape = Const.cast(member.value).shape()
except TypeError as e:
raise TypeError("Only enumerations whose members have constant-castable "
"values can be used in Amaranth code")
if not signed and member_shape.signed:
signed = True
width = max(width + 1, member_shape.width)
elif signed and not member_shape.signed:
width = max(width, member_shape.width + 1)
else:
width = max(width, member_shape.width)
return Shape(width, signed)
elif isinstance(obj, ShapeCastable):
new_obj = obj.as_shape()
else: else:
raise TypeError("Object {!r} cannot be converted to an Amaranth shape".format(obj)) raise TypeError("Object {!r} cannot be converted to an Amaranth shape".format(obj))
if new_obj is obj: if new_obj is obj:
@ -866,9 +874,17 @@ class Cat(Value):
super().__init__(src_loc_at=src_loc_at) super().__init__(src_loc_at=src_loc_at)
self.parts = [] self.parts = []
for index, arg in enumerate(flatten(args)): for index, arg in enumerate(flatten(args)):
if isinstance(arg, Enum) and (not isinstance(type(arg), ShapeCastable) or
not hasattr(arg, "_amaranth_shape_")):
warnings.warn("Argument #{} of Cat() is an enumerated value {!r} without "
"a defined shape used in bit vector context; define the enumeration "
"by inheriting from the class in amaranth.lib.enum and specifying "
"the 'shape=' keyword argument"
.format(index + 1, arg),
SyntaxWarning, stacklevel=2 + src_loc_at)
if isinstance(arg, int) and not isinstance(arg, Enum) and arg not in [0, 1]: if isinstance(arg, int) and not isinstance(arg, Enum) and arg not in [0, 1]:
warnings.warn("Argument #{} of Cat() is a bare integer {} used in bit vector " warnings.warn("Argument #{} of Cat() is a bare integer {} used in bit vector "
"context; consider specifying explicit width using C({}, {}) instead" "context; specify the width explicitly using C({}, {})"
.format(index + 1, arg, arg, bits_for(arg)), .format(index + 1, arg, arg, bits_for(arg)),
SyntaxWarning, stacklevel=2 + src_loc_at) SyntaxWarning, stacklevel=2 + src_loc_at)
self.parts.append(Value.cast(arg)) self.parts.append(Value.cast(arg))

108
amaranth/lib/enum.py Normal file
View file

@ -0,0 +1,108 @@
import enum as py_enum
import warnings
from ..hdl.ast import Shape, ShapeCastable, Const
from .._utils import bits_for
__all__ = py_enum.__all__
for member in py_enum.__all__:
globals()[member] = getattr(py_enum, member)
del member
class EnumMeta(ShapeCastable, py_enum.EnumMeta):
"""Subclass of the standard :class:`enum.EnumMeta` that implements the :class:`ShapeCastable`
protocol.
This metaclass provides the :meth:`as_shape` method, making its instances
:ref:`shape-castable <lang-shapecasting>`, and accepts a ``shape=`` keyword argument
to specify a shape explicitly. Other than this, it acts the same as the standard
:class:`enum.EnumMeta` class; if the ``shape=`` argument is not specified and
:meth:`as_shape` is never called, it places no restrictions on the enumeration class
or the values of its members.
"""
def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs)
if shape is not None:
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
# the values of every member can be cast to the provided shape without truncation.
cls._amaranth_shape_ = shape = Shape.cast(shape)
for member in cls:
try:
Const.cast(member.value)
except TypeError as e:
raise TypeError("Value of enumeration member {!r} must be "
"a constant-castable expression"
.format(member)) from e
width = bits_for(member.value, shape.signed)
if member.value < 0 and not shape.signed:
warnings.warn(
message="Value of enumeration member {!r} is signed, but enumeration "
"shape is {!r}" # the repr will be `unsigned(X)`
.format(member, shape),
category=RuntimeWarning,
stacklevel=2)
elif width > shape.width:
warnings.warn(
message="Value of enumeration member {!r} will be truncated to "
"enumeration shape {!r}"
.format(member, shape),
category=RuntimeWarning,
stacklevel=2)
else:
# Shape is not provided explicitly. Behave the same as a standard enumeration;
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such
# an enumeration is used in a concatenation.
pass
return cls
def as_shape(cls):
"""Cast this enumeration to a shape.
Returns
-------
:class:`Shape`
Explicitly provided shape. If not provided, returns the result of shape-casting
this class :ref:`as a standard Python enumeration <lang-shapeenum>`.
Raises
------
TypeError
If the enumeration has neither an explicitly provided shape nor any members.
"""
if hasattr(cls, "_amaranth_shape_"):
# Shape was provided explicitly; return it.
return cls._amaranth_shape_
elif cls.__members__:
# Shape was not provided explicitly, but enumeration has members; treat it
# the same way `Shape.cast` treats standard library enumerations, so that
# `amaranth.lib.enum.Enum` can be a drop-in replacement for `enum.Enum`.
return Shape._cast_plain_enum(cls)
else:
# Shape was not provided explicitly, and enumeration has no members.
# This is a base or mixin class that cannot be instantiated directly.
raise TypeError("Enumeration '{}.{}' does not have a defined shape"
.format(cls.__module__, cls.__qualname__))
class Enum(py_enum.Enum, metaclass=EnumMeta):
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as
its metaclass."""
class IntEnum(py_enum.IntEnum, metaclass=EnumMeta):
"""Subclass of the standard :class:`enum.IntEnum` that has :class:`EnumMeta` as
its metaclass."""
class Flag(py_enum.Flag, metaclass=EnumMeta):
"""Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as
its metaclass."""
class IntFlag(py_enum.IntFlag, metaclass=EnumMeta):
"""Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as
its metaclass."""

View file

@ -22,6 +22,17 @@ Apply the following changes to code written against Amaranth 0.3 to migrate it t
While code that uses the features listed as deprecated below will work in Amaranth 0.4, they will be removed in the next version. While code that uses the features listed as deprecated below will work in Amaranth 0.4, they will be removed in the next version.
Implemented RFCs
----------------
.. _RFC 3: https://amaranth-lang.org/rfcs/0003-enumeration-shapes.html
.. _RFC 4: https://amaranth-lang.org/rfcs/0004-const-castable-exprs.html
.. _RFC 5: https://amaranth-lang.org/rfcs/0005-remove-const-normalize.html
* `RFC 3`_: Enumeration shapes
* `RFC 4`_: Constant-castable expressions
* `RFC 5`_: Remove Const.normalize
Language changes Language changes
---------------- ----------------
@ -30,19 +41,24 @@ Language changes
* Added: :class:`ShapeCastable`, similar to :class:`ValueCastable`. * Added: :class:`ShapeCastable`, similar to :class:`ValueCastable`.
* Added: :meth:`Value.as_signed` and :meth:`Value.as_unsigned` can be used on left-hand side of assignment (with no difference in behavior). * Added: :meth:`Value.as_signed` and :meth:`Value.as_unsigned` can be used on left-hand side of assignment (with no difference in behavior).
* Added: :meth:`Const.cast`, evaluating constant-castable values and returning a :class:`Const`. (`RFC 4`_) * Added: :meth:`Const.cast`. (`RFC 4`_)
* Added: :meth:`Value.matches` and ``with m.Case():`` accept any constant-castable objects. (`RFC 4`_) * Added: :meth:`Value.matches` and ``with m.Case():`` accept any constant-castable objects. (`RFC 4`_)
* Changed: :meth:`Value.cast` casts :class:`ValueCastable` objects recursively. * Changed: :meth:`Value.cast` casts :class:`ValueCastable` objects recursively.
* Changed: :meth:`Value.cast` treats instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) as enumerations rather than integers. * Changed: :meth:`Value.cast` treats instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) as enumerations rather than integers.
* Changed: ``Value.matches()`` with an empty list of patterns returns ``Const(1)`` rather than ``Const(0)``, to match ``with m.Case():``. * Changed: :meth:`Value.matches` with an empty list of patterns returns ``Const(1)`` rather than ``Const(0)``, to match the behavior of ``with m.Case():``.
* Changed: :class:`Cat` accepts instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) without warning. * Changed: :class:`Cat` warns if an enumeration without an explicitly specified shape is used.
* Deprecated: :meth:`Const.normalize`. (`RFC 5`_) * Deprecated: :meth:`Const.normalize`. (`RFC 5`_)
* Removed: (deprecated in 0.1) casting of :class:`Shape` to and from a ``(width, signed)`` tuple. * Removed: (deprecated in 0.1) casting of :class:`Shape` to and from a ``(width, signed)`` tuple.
* Removed: (deprecated in 0.3) :class:`ast.UserValue`. * Removed: (deprecated in 0.3) :class:`ast.UserValue`.
* Removed: (deprecated in 0.3) support for ``# nmigen:`` linter instructions at the beginning of file. * Removed: (deprecated in 0.3) support for ``# nmigen:`` linter instructions at the beginning of file.
.. _RFC 4: https://amaranth-lang.org/rfcs/0004-const-castable-exprs.html
.. _RFC 5: https://amaranth-lang.org/rfcs/0005-remove-const-normalize.html Standard library changes
------------------------
.. currentmodule:: amaranth.lib
* Added: :mod:`amaranth.lib.enum`.
Toolchain changes Toolchain changes

View file

@ -27,9 +27,17 @@ intersphinx_mapping = {"python": ("https://docs.python.org/3", None)}
todo_include_todos = True todo_include_todos = True
autodoc_member_order = "bysource"
autodoc_default_options = {
"members": True
}
autodoc_preserve_defaults = True
napoleon_google_docstring = False napoleon_google_docstring = False
napoleon_numpy_docstring = True napoleon_numpy_docstring = True
napoleon_use_ivar = True napoleon_use_ivar = True
napoleon_include_init_with_doc = True
napoleon_include_special_with_doc = True
napoleon_custom_sections = ["Platform overrides"] napoleon_custom_sections = ["Platform overrides"]
html_theme = "sphinx_rtd_theme" html_theme = "sphinx_rtd_theme"

View file

@ -183,11 +183,7 @@ Specifying a shape with a range is convenient for counters, indexes, and all oth
Shapes from enumerations Shapes from enumerations
------------------------ ------------------------
Casting a shape from an :class:`enum.Enum` subclass ``E``: Casting a shape from an :class:`enum.Enum` subclass requires all of the enumeration members to have :ref:`constant-castable <lang-constcasting>` values. The shape has a width large enough to represent the value of every member, and is signed only if there is a member with a negative value.
* fails if any of the enumeration members have non-integer values,
* has a width large enough to represent both ``min(m.value for m in E)`` and ``max(m.value for m in E)``, and
* is signed if either ``min(m.value for m in E)`` or ``max(m.value for m in E)`` are negative, unsigned otherwise.
Specifying a shape with an enumeration is convenient for finite state machines, multiplexers, complex control signals, and all other values whose width is derived from a few distinct choices they must be able to fit: Specifying a shape with an enumeration is convenient for finite state machines, multiplexers, complex control signals, and all other values whose width is derived from a few distinct choices they must be able to fit:
@ -208,9 +204,27 @@ Specifying a shape with an enumeration is convenient for finite state machines,
>>> Shape.cast(Direction) >>> Shape.cast(Direction)
unsigned(2) unsigned(2)
The :mod:`amaranth.lib.enum` module extends the standard enumerations such that their shape can be specified explicitly when they are defined:
.. testsetup::
import amaranth.lib.enum
.. testcode::
class Funct4(amaranth.lib.enum.Enum, shape=unsigned(4)):
ADD = 0
SUB = 1
MUL = 2
.. doctest::
>>> Shape.cast(Funct4)
unsigned(4)
.. note:: .. note::
The enumeration does not have to subclass :class:`enum.IntEnum`; it only needs to have integers as values of every member. Using enumerations based on :class:`enum.Enum` rather than :class:`enum.IntEnum` prevents unwanted implicit conversion of enum members to integers. The enumeration does not have to subclass :class:`enum.IntEnum` or have :class:`int` as one of its base classes; it only needs to have integers as values of every member. Using enumerations based on :class:`enum.Enum` rather than :class:`enum.IntEnum` prevents unwanted implicit conversion of enum members to integers.
.. _lang-valuecasting: .. _lang-valuecasting:

View file

@ -8,6 +8,7 @@ Standard library
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
stdlib/enum
stdlib/coding stdlib/coding
stdlib/cdc stdlib/cdc
stdlib/fifo stdlib/fifo

43
docs/stdlib/enum.rst Normal file
View file

@ -0,0 +1,43 @@
Enumerations
############
.. py:module:: amaranth.lib.enum
The :mod:`amaranth.lib.enum` module is a drop-in replacement for the standard :mod:`enum` module that provides extended :class:`Enum`, :class:`IntEnum`, :class:`Flag`, and :class:`IntFlag` classes with the ability to specify a shape explicitly.
A shape can be specified for an enumeration with the ``shape=`` keyword argument:
.. testsetup::
from amaranth import *
.. testcode::
from amaranth.lib import enum
class Funct4(enum.Enum, shape=4):
ADD = 0
SUB = 1
MUL = 2
.. doctest::
>>> Shape.cast(Funct4)
unsigned(4)
This module is a drop-in replacement for the standard :mod:`enum` module, and re-exports all of its members (not just the ones described below). In an Amaranth project, all ``import enum`` statements may be replaced with ``from amaranth.lib import enum``.
Metaclass
=========
.. autoclass:: EnumMeta()
Base classes
============
.. autoclass:: Enum()
.. autoclass:: IntEnum()
.. autoclass:: Flag()
.. autoclass:: IntFlag()

View file

@ -798,28 +798,34 @@ class CatTestCase(FHDLTestCase):
warnings.filterwarnings(action="error", category=SyntaxWarning) warnings.filterwarnings(action="error", category=SyntaxWarning)
Cat(0, 1, 1, 0) Cat(0, 1, 1, 0)
def test_enum(self): def test_enum_wrong(self):
class Color(Enum): class Color(Enum):
RED = 1 RED = 1
BLUE = 2 BLUE = 2
with warnings.catch_warnings(): with self.assertWarnsRegex(SyntaxWarning,
warnings.filterwarnings(action="error", category=SyntaxWarning) r"^Argument #1 of Cat\(\) is an enumerated value <Color\.RED: 1> without "
r"a defined shape used in bit vector context; define the enumeration by "
r"inheriting from the class in amaranth\.lib\.enum and specifying "
r"the 'shape=' keyword argument$"):
c = Cat(Color.RED, Color.BLUE) c = Cat(Color.RED, Color.BLUE)
self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))") self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))")
def test_intenum(self): def test_intenum_wrong(self):
class Color(int, Enum): class Color(int, Enum):
RED = 1 RED = 1
BLUE = 2 BLUE = 2
with warnings.catch_warnings(): with self.assertWarnsRegex(SyntaxWarning,
warnings.filterwarnings(action="error", category=SyntaxWarning) r"^Argument #1 of Cat\(\) is an enumerated value <Color\.RED: 1> without "
r"a defined shape used in bit vector context; define the enumeration by "
r"inheriting from the class in amaranth\.lib\.enum and specifying "
r"the 'shape=' keyword argument$"):
c = Cat(Color.RED, Color.BLUE) c = Cat(Color.RED, Color.BLUE)
self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))") self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))")
def test_int_wrong(self): def test_int_wrong(self):
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Argument #1 of Cat\(\) is a bare integer 2 used in bit vector context; " r"^Argument #1 of Cat\(\) is a bare integer 2 used in bit vector context; "
r"consider specifying explicit width using C\(2, 2\) instead$"): r"specify the width explicitly using C\(2, 2\)$"):
Cat(2) Cat(2)

View file

@ -1,11 +1,11 @@
# amaranth: UnusedElaboratable=no # amaranth: UnusedElaboratable=no
from collections import OrderedDict from collections import OrderedDict
from enum import Enum
from amaranth.hdl.ast import * from amaranth.hdl.ast import *
from amaranth.hdl.cd import * from amaranth.hdl.cd import *
from amaranth.hdl.dsl import * from amaranth.hdl.dsl import *
from amaranth.lib.enum import Enum
from .utils import * from .utils import *
@ -447,7 +447,7 @@ class DSLTestCase(FHDLTestCase):
""") """)
def test_Switch_const_castable(self): def test_Switch_const_castable(self):
class Color(Enum): class Color(Enum, shape=1):
RED = 0 RED = 0
BLUE = 1 BLUE = 1
m = Module() m = Module()

91
tests/test_lib_enum.py Normal file
View file

@ -0,0 +1,91 @@
from amaranth import *
from amaranth.lib.enum import Enum
from .utils import *
class EnumTestCase(FHDLTestCase):
def test_non_int_members(self):
# Mustn't raise to be a drop-in replacement for Enum.
class EnumA(Enum):
A = "str"
def test_non_int_members_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Value of enumeration member <EnumA\.A: 'str'> must be "
r"a constant-castable expression$"):
class EnumA(Enum, shape=unsigned(1)):
A = "str"
def test_shape_no_members(self):
class EnumA(Enum):
pass
with self.assertRaisesRegex(TypeError,
r"^Enumeration '.+?\.EnumA' does not have a defined shape$"):
Shape.cast(EnumA)
def test_shape_explicit(self):
class EnumA(Enum, shape=signed(2)):
pass
self.assertEqual(Shape.cast(EnumA), signed(2))
def test_shape_explicit_cast(self):
class EnumA(Enum, shape=range(10)):
pass
self.assertEqual(Shape.cast(EnumA), unsigned(4))
def test_shape_implicit(self):
class EnumA(Enum):
A = 0
B = 1
self.assertEqual(Shape.cast(EnumA), unsigned(1))
class EnumB(Enum):
A = 0
B = 5
self.assertEqual(Shape.cast(EnumB), unsigned(3))
class EnumC(Enum):
A = 0
B = -1
self.assertEqual(Shape.cast(EnumC), signed(2))
class EnumD(Enum):
A = 3
B = -5
self.assertEqual(Shape.cast(EnumD), signed(4))
def test_shape_explicit_wrong_signed_mismatch(self):
with self.assertWarnsRegex(RuntimeWarning,
r"^Value of enumeration member <EnumA\.A: -1> is signed, but enumeration "
r"shape is unsigned\(1\)$"):
class EnumA(Enum, shape=unsigned(1)):
A = -1
def test_shape_explicit_wrong_too_wide(self):
with self.assertWarnsRegex(RuntimeWarning,
r"^Value of enumeration member <EnumA\.A: 2> will be truncated to enumeration "
r"shape unsigned\(1\)$"):
class EnumA(Enum, shape=unsigned(1)):
A = 2
with self.assertWarnsRegex(RuntimeWarning,
r"^Value of enumeration member <EnumB\.A: 1> will be truncated to enumeration "
r"shape signed\(1\)$"):
class EnumB(Enum, shape=signed(1)):
A = 1
with self.assertWarnsRegex(RuntimeWarning,
r"^Value of enumeration member <EnumC\.A: -2> will be truncated to enumeration "
r"shape signed\(1\)$"):
class EnumC(Enum, shape=signed(1)):
A = -2
def test_value_shape_from_enum_member(self):
class EnumA(Enum, shape=unsigned(10)):
A = 1
self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)")
def test_shape_implicit_wrong_in_concat(self):
class EnumA(Enum):
A = 0
with self.assertWarnsRegex(SyntaxWarning,
r"^Argument #1 of Cat\(\) is an enumerated value <EnumA\.A: 0> without a defined "
r"shape used in bit vector context; define the enumeration by inheriting from "
r"the class in amaranth\.lib\.enum and specifying the 'shape=' keyword argument$"):
Cat(EnumA.A)