From 57612f1dce9f5e35a915235ab442d477b5f8db59 Mon Sep 17 00:00:00 2001 From: Catherine Date: Mon, 20 Feb 2023 22:58:38 +0000 Subject: [PATCH] lib.enum: add Enum wrappers that allow specifying shape. See #756 and amaranth-lang/rfcs#3. --- amaranth/hdl/ast.py | 54 +++++++++++++-------- amaranth/lib/enum.py | 108 +++++++++++++++++++++++++++++++++++++++++ docs/changes.rst | 26 ++++++++-- docs/conf.py | 8 +++ docs/lang.rst | 26 +++++++--- docs/stdlib.rst | 3 +- docs/stdlib/enum.rst | 43 ++++++++++++++++ tests/test_hdl_ast.py | 20 +++++--- tests/test_hdl_dsl.py | 4 +- tests/test_lib_enum.py | 91 ++++++++++++++++++++++++++++++++++ 10 files changed, 343 insertions(+), 40 deletions(-) create mode 100644 amaranth/lib/enum.py create mode 100644 docs/stdlib/enum.rst create mode 100644 tests/test_lib_enum.py diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index a795acf..18695c0 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -78,11 +78,34 @@ class Shape: self.width = width self.signed = signed + # The algorithm for inferring shape for standard Python enumerations is factored out so that + # `Shape.cast()` and Amaranth's `EnumMeta.as_shape()` can both use it. + @staticmethod + def _cast_plain_enum(obj): + signed = False + width = 0 + for member in obj: + try: + member_shape = Const.cast(member.value).shape() + except TypeError as e: + raise TypeError("Only enumerations whose members have constant-castable " + "values can be used in Amaranth code") + if not signed and member_shape.signed: + signed = True + width = max(width + 1, member_shape.width) + elif signed and not member_shape.signed: + width = max(width, member_shape.width + 1) + else: + width = max(width, member_shape.width) + return Shape(width, signed) + @staticmethod def cast(obj, *, src_loc_at=0): while True: if isinstance(obj, Shape): return obj + elif isinstance(obj, ShapeCastable): + new_obj = obj.as_shape() elif isinstance(obj, int): return Shape(obj) elif isinstance(obj, range): @@ -93,24 +116,9 @@ class Shape: bits_for(obj.stop - obj.step, signed)) return Shape(width, signed) elif isinstance(obj, type) and issubclass(obj, Enum): - signed = False - width = 0 - for member in obj: - try: - member_shape = Const.cast(member.value).shape() - except TypeError as e: - raise TypeError("Only enumerations whose members have constant-castable " - "values can be used in Amaranth code") - if not signed and member_shape.signed: - signed = True - width = max(width + 1, member_shape.width) - elif signed and not member_shape.signed: - width = max(width, member_shape.width + 1) - else: - width = max(width, member_shape.width) - return Shape(width, signed) - elif isinstance(obj, ShapeCastable): - new_obj = obj.as_shape() + # For compatibility with third party enumerations, handle them as if they were + # defined as subclasses of lib.enum.Enum with no explicitly specified shape. + return Shape._cast_plain_enum(obj) else: raise TypeError("Object {!r} cannot be converted to an Amaranth shape".format(obj)) if new_obj is obj: @@ -866,9 +874,17 @@ class Cat(Value): super().__init__(src_loc_at=src_loc_at) self.parts = [] for index, arg in enumerate(flatten(args)): + if isinstance(arg, Enum) and (not isinstance(type(arg), ShapeCastable) or + not hasattr(arg, "_amaranth_shape_")): + warnings.warn("Argument #{} of Cat() is an enumerated value {!r} without " + "a defined shape used in bit vector context; define the enumeration " + "by inheriting from the class in amaranth.lib.enum and specifying " + "the 'shape=' keyword argument" + .format(index + 1, arg), + SyntaxWarning, stacklevel=2 + src_loc_at) if isinstance(arg, int) and not isinstance(arg, Enum) and arg not in [0, 1]: warnings.warn("Argument #{} of Cat() is a bare integer {} used in bit vector " - "context; consider specifying explicit width using C({}, {}) instead" + "context; specify the width explicitly using C({}, {})" .format(index + 1, arg, arg, bits_for(arg)), SyntaxWarning, stacklevel=2 + src_loc_at) self.parts.append(Value.cast(arg)) diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py new file mode 100644 index 0000000..0542e35 --- /dev/null +++ b/amaranth/lib/enum.py @@ -0,0 +1,108 @@ +import enum as py_enum +import warnings + +from ..hdl.ast import Shape, ShapeCastable, Const +from .._utils import bits_for + + +__all__ = py_enum.__all__ + + +for member in py_enum.__all__: + globals()[member] = getattr(py_enum, member) +del member + + +class EnumMeta(ShapeCastable, py_enum.EnumMeta): + """Subclass of the standard :class:`enum.EnumMeta` that implements the :class:`ShapeCastable` + protocol. + + This metaclass provides the :meth:`as_shape` method, making its instances + :ref:`shape-castable `, and accepts a ``shape=`` keyword argument + to specify a shape explicitly. Other than this, it acts the same as the standard + :class:`enum.EnumMeta` class; if the ``shape=`` argument is not specified and + :meth:`as_shape` is never called, it places no restrictions on the enumeration class + or the values of its members. + """ + def __new__(metacls, name, bases, namespace, shape=None, **kwargs): + cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs) + if shape is not None: + # Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that + # the values of every member can be cast to the provided shape without truncation. + cls._amaranth_shape_ = shape = Shape.cast(shape) + for member in cls: + try: + Const.cast(member.value) + except TypeError as e: + raise TypeError("Value of enumeration member {!r} must be " + "a constant-castable expression" + .format(member)) from e + width = bits_for(member.value, shape.signed) + if member.value < 0 and not shape.signed: + warnings.warn( + message="Value of enumeration member {!r} is signed, but enumeration " + "shape is {!r}" # the repr will be `unsigned(X)` + .format(member, shape), + category=RuntimeWarning, + stacklevel=2) + elif width > shape.width: + warnings.warn( + message="Value of enumeration member {!r} will be truncated to " + "enumeration shape {!r}" + .format(member, shape), + category=RuntimeWarning, + stacklevel=2) + else: + # Shape is not provided explicitly. Behave the same as a standard enumeration; + # the lack of `_amaranth_shape_` attribute is used to emit a warning when such + # an enumeration is used in a concatenation. + pass + return cls + + def as_shape(cls): + """Cast this enumeration to a shape. + + Returns + ------- + :class:`Shape` + Explicitly provided shape. If not provided, returns the result of shape-casting + this class :ref:`as a standard Python enumeration `. + + Raises + ------ + TypeError + If the enumeration has neither an explicitly provided shape nor any members. + """ + if hasattr(cls, "_amaranth_shape_"): + # Shape was provided explicitly; return it. + return cls._amaranth_shape_ + elif cls.__members__: + # Shape was not provided explicitly, but enumeration has members; treat it + # the same way `Shape.cast` treats standard library enumerations, so that + # `amaranth.lib.enum.Enum` can be a drop-in replacement for `enum.Enum`. + return Shape._cast_plain_enum(cls) + else: + # Shape was not provided explicitly, and enumeration has no members. + # This is a base or mixin class that cannot be instantiated directly. + raise TypeError("Enumeration '{}.{}' does not have a defined shape" + .format(cls.__module__, cls.__qualname__)) + + +class Enum(py_enum.Enum, metaclass=EnumMeta): + """Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as + its metaclass.""" + + +class IntEnum(py_enum.IntEnum, metaclass=EnumMeta): + """Subclass of the standard :class:`enum.IntEnum` that has :class:`EnumMeta` as + its metaclass.""" + + +class Flag(py_enum.Flag, metaclass=EnumMeta): + """Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as + its metaclass.""" + + +class IntFlag(py_enum.IntFlag, metaclass=EnumMeta): + """Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as + its metaclass.""" diff --git a/docs/changes.rst b/docs/changes.rst index aca5daa..c67c834 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -22,6 +22,17 @@ Apply the following changes to code written against Amaranth 0.3 to migrate it t While code that uses the features listed as deprecated below will work in Amaranth 0.4, they will be removed in the next version. +Implemented RFCs +---------------- + +.. _RFC 3: https://amaranth-lang.org/rfcs/0003-enumeration-shapes.html +.. _RFC 4: https://amaranth-lang.org/rfcs/0004-const-castable-exprs.html +.. _RFC 5: https://amaranth-lang.org/rfcs/0005-remove-const-normalize.html + +* `RFC 3`_: Enumeration shapes +* `RFC 4`_: Constant-castable expressions +* `RFC 5`_: Remove Const.normalize + Language changes ---------------- @@ -30,19 +41,24 @@ Language changes * Added: :class:`ShapeCastable`, similar to :class:`ValueCastable`. * Added: :meth:`Value.as_signed` and :meth:`Value.as_unsigned` can be used on left-hand side of assignment (with no difference in behavior). -* Added: :meth:`Const.cast`, evaluating constant-castable values and returning a :class:`Const`. (`RFC 4`_) +* Added: :meth:`Const.cast`. (`RFC 4`_) * Added: :meth:`Value.matches` and ``with m.Case():`` accept any constant-castable objects. (`RFC 4`_) * Changed: :meth:`Value.cast` casts :class:`ValueCastable` objects recursively. * Changed: :meth:`Value.cast` treats instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) as enumerations rather than integers. -* Changed: ``Value.matches()`` with an empty list of patterns returns ``Const(1)`` rather than ``Const(0)``, to match ``with m.Case():``. -* Changed: :class:`Cat` accepts instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) without warning. +* Changed: :meth:`Value.matches` with an empty list of patterns returns ``Const(1)`` rather than ``Const(0)``, to match the behavior of ``with m.Case():``. +* Changed: :class:`Cat` warns if an enumeration without an explicitly specified shape is used. * Deprecated: :meth:`Const.normalize`. (`RFC 5`_) * Removed: (deprecated in 0.1) casting of :class:`Shape` to and from a ``(width, signed)`` tuple. * Removed: (deprecated in 0.3) :class:`ast.UserValue`. * Removed: (deprecated in 0.3) support for ``# nmigen:`` linter instructions at the beginning of file. -.. _RFC 4: https://amaranth-lang.org/rfcs/0004-const-castable-exprs.html -.. _RFC 5: https://amaranth-lang.org/rfcs/0005-remove-const-normalize.html + +Standard library changes +------------------------ + +.. currentmodule:: amaranth.lib + +* Added: :mod:`amaranth.lib.enum`. Toolchain changes diff --git a/docs/conf.py b/docs/conf.py index da51d3d..c92197f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,9 +27,17 @@ intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} todo_include_todos = True +autodoc_member_order = "bysource" +autodoc_default_options = { + "members": True +} +autodoc_preserve_defaults = True + napoleon_google_docstring = False napoleon_numpy_docstring = True napoleon_use_ivar = True +napoleon_include_init_with_doc = True +napoleon_include_special_with_doc = True napoleon_custom_sections = ["Platform overrides"] html_theme = "sphinx_rtd_theme" diff --git a/docs/lang.rst b/docs/lang.rst index 0da841c..25d193e 100644 --- a/docs/lang.rst +++ b/docs/lang.rst @@ -183,11 +183,7 @@ Specifying a shape with a range is convenient for counters, indexes, and all oth Shapes from enumerations ------------------------ -Casting a shape from an :class:`enum.Enum` subclass ``E``: - - * fails if any of the enumeration members have non-integer values, - * has a width large enough to represent both ``min(m.value for m in E)`` and ``max(m.value for m in E)``, and - * is signed if either ``min(m.value for m in E)`` or ``max(m.value for m in E)`` are negative, unsigned otherwise. +Casting a shape from an :class:`enum.Enum` subclass requires all of the enumeration members to have :ref:`constant-castable ` values. The shape has a width large enough to represent the value of every member, and is signed only if there is a member with a negative value. Specifying a shape with an enumeration is convenient for finite state machines, multiplexers, complex control signals, and all other values whose width is derived from a few distinct choices they must be able to fit: @@ -208,9 +204,27 @@ Specifying a shape with an enumeration is convenient for finite state machines, >>> Shape.cast(Direction) unsigned(2) +The :mod:`amaranth.lib.enum` module extends the standard enumerations such that their shape can be specified explicitly when they are defined: + +.. testsetup:: + + import amaranth.lib.enum + +.. testcode:: + + class Funct4(amaranth.lib.enum.Enum, shape=unsigned(4)): + ADD = 0 + SUB = 1 + MUL = 2 + +.. doctest:: + + >>> Shape.cast(Funct4) + unsigned(4) + .. note:: - The enumeration does not have to subclass :class:`enum.IntEnum`; it only needs to have integers as values of every member. Using enumerations based on :class:`enum.Enum` rather than :class:`enum.IntEnum` prevents unwanted implicit conversion of enum members to integers. + The enumeration does not have to subclass :class:`enum.IntEnum` or have :class:`int` as one of its base classes; it only needs to have integers as values of every member. Using enumerations based on :class:`enum.Enum` rather than :class:`enum.IntEnum` prevents unwanted implicit conversion of enum members to integers. .. _lang-valuecasting: diff --git a/docs/stdlib.rst b/docs/stdlib.rst index 92b0465..38a3cea 100644 --- a/docs/stdlib.rst +++ b/docs/stdlib.rst @@ -8,6 +8,7 @@ Standard library .. toctree:: :maxdepth: 2 + stdlib/enum stdlib/coding stdlib/cdc - stdlib/fifo \ No newline at end of file + stdlib/fifo diff --git a/docs/stdlib/enum.rst b/docs/stdlib/enum.rst new file mode 100644 index 0000000..39c8575 --- /dev/null +++ b/docs/stdlib/enum.rst @@ -0,0 +1,43 @@ +Enumerations +############ + +.. py:module:: amaranth.lib.enum + +The :mod:`amaranth.lib.enum` module is a drop-in replacement for the standard :mod:`enum` module that provides extended :class:`Enum`, :class:`IntEnum`, :class:`Flag`, and :class:`IntFlag` classes with the ability to specify a shape explicitly. + +A shape can be specified for an enumeration with the ``shape=`` keyword argument: + +.. testsetup:: + + from amaranth import * + +.. testcode:: + + from amaranth.lib import enum + + class Funct4(enum.Enum, shape=4): + ADD = 0 + SUB = 1 + MUL = 2 + +.. doctest:: + + >>> Shape.cast(Funct4) + unsigned(4) + +This module is a drop-in replacement for the standard :mod:`enum` module, and re-exports all of its members (not just the ones described below). In an Amaranth project, all ``import enum`` statements may be replaced with ``from amaranth.lib import enum``. + + +Metaclass +========= + +.. autoclass:: EnumMeta() + + +Base classes +============ + +.. autoclass:: Enum() +.. autoclass:: IntEnum() +.. autoclass:: Flag() +.. autoclass:: IntFlag() diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 65a9b39..9aa99f8 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -798,28 +798,34 @@ class CatTestCase(FHDLTestCase): warnings.filterwarnings(action="error", category=SyntaxWarning) Cat(0, 1, 1, 0) - def test_enum(self): + def test_enum_wrong(self): class Color(Enum): RED = 1 BLUE = 2 - with warnings.catch_warnings(): - warnings.filterwarnings(action="error", category=SyntaxWarning) + with self.assertWarnsRegex(SyntaxWarning, + r"^Argument #1 of Cat\(\) is an enumerated value without " + r"a defined shape used in bit vector context; define the enumeration by " + r"inheriting from the class in amaranth\.lib\.enum and specifying " + r"the 'shape=' keyword argument$"): c = Cat(Color.RED, Color.BLUE) self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))") - def test_intenum(self): + def test_intenum_wrong(self): class Color(int, Enum): RED = 1 BLUE = 2 - with warnings.catch_warnings(): - warnings.filterwarnings(action="error", category=SyntaxWarning) + with self.assertWarnsRegex(SyntaxWarning, + r"^Argument #1 of Cat\(\) is an enumerated value without " + r"a defined shape used in bit vector context; define the enumeration by " + r"inheriting from the class in amaranth\.lib\.enum and specifying " + r"the 'shape=' keyword argument$"): c = Cat(Color.RED, Color.BLUE) self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))") def test_int_wrong(self): with self.assertWarnsRegex(SyntaxWarning, r"^Argument #1 of Cat\(\) is a bare integer 2 used in bit vector context; " - r"consider specifying explicit width using C\(2, 2\) instead$"): + r"specify the width explicitly using C\(2, 2\)$"): Cat(2) diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index 552f120..719f8bc 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -1,11 +1,11 @@ # amaranth: UnusedElaboratable=no from collections import OrderedDict -from enum import Enum from amaranth.hdl.ast import * from amaranth.hdl.cd import * from amaranth.hdl.dsl import * +from amaranth.lib.enum import Enum from .utils import * @@ -447,7 +447,7 @@ class DSLTestCase(FHDLTestCase): """) def test_Switch_const_castable(self): - class Color(Enum): + class Color(Enum, shape=1): RED = 0 BLUE = 1 m = Module() diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py new file mode 100644 index 0000000..0040fee --- /dev/null +++ b/tests/test_lib_enum.py @@ -0,0 +1,91 @@ +from amaranth import * +from amaranth.lib.enum import Enum + +from .utils import * + + +class EnumTestCase(FHDLTestCase): + def test_non_int_members(self): + # Mustn't raise to be a drop-in replacement for Enum. + class EnumA(Enum): + A = "str" + + def test_non_int_members_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Value of enumeration member must be " + r"a constant-castable expression$"): + class EnumA(Enum, shape=unsigned(1)): + A = "str" + + def test_shape_no_members(self): + class EnumA(Enum): + pass + with self.assertRaisesRegex(TypeError, + r"^Enumeration '.+?\.EnumA' does not have a defined shape$"): + Shape.cast(EnumA) + + def test_shape_explicit(self): + class EnumA(Enum, shape=signed(2)): + pass + self.assertEqual(Shape.cast(EnumA), signed(2)) + + def test_shape_explicit_cast(self): + class EnumA(Enum, shape=range(10)): + pass + self.assertEqual(Shape.cast(EnumA), unsigned(4)) + + def test_shape_implicit(self): + class EnumA(Enum): + A = 0 + B = 1 + self.assertEqual(Shape.cast(EnumA), unsigned(1)) + class EnumB(Enum): + A = 0 + B = 5 + self.assertEqual(Shape.cast(EnumB), unsigned(3)) + class EnumC(Enum): + A = 0 + B = -1 + self.assertEqual(Shape.cast(EnumC), signed(2)) + class EnumD(Enum): + A = 3 + B = -5 + self.assertEqual(Shape.cast(EnumD), signed(4)) + + def test_shape_explicit_wrong_signed_mismatch(self): + with self.assertWarnsRegex(RuntimeWarning, + r"^Value of enumeration member is signed, but enumeration " + r"shape is unsigned\(1\)$"): + class EnumA(Enum, shape=unsigned(1)): + A = -1 + + def test_shape_explicit_wrong_too_wide(self): + with self.assertWarnsRegex(RuntimeWarning, + r"^Value of enumeration member will be truncated to enumeration " + r"shape unsigned\(1\)$"): + class EnumA(Enum, shape=unsigned(1)): + A = 2 + with self.assertWarnsRegex(RuntimeWarning, + r"^Value of enumeration member will be truncated to enumeration " + r"shape signed\(1\)$"): + class EnumB(Enum, shape=signed(1)): + A = 1 + with self.assertWarnsRegex(RuntimeWarning, + r"^Value of enumeration member will be truncated to enumeration " + r"shape signed\(1\)$"): + class EnumC(Enum, shape=signed(1)): + A = -2 + + def test_value_shape_from_enum_member(self): + class EnumA(Enum, shape=unsigned(10)): + A = 1 + self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)") + + def test_shape_implicit_wrong_in_concat(self): + class EnumA(Enum): + A = 0 + with self.assertWarnsRegex(SyntaxWarning, + r"^Argument #1 of Cat\(\) is an enumerated value without a defined " + r"shape used in bit vector context; define the enumeration by inheriting from " + r"the class in amaranth\.lib\.enum and specifying the 'shape=' keyword argument$"): + Cat(EnumA.A)