diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index c1e6f03..8784d8c 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -1,7 +1,7 @@ import enum as py_enum import warnings -from ..hdl.ast import Shape, ShapeCastable, Const +from ..hdl.ast import Value, Shape, ShapeCastable, Const __all__ = py_enum.__all__ @@ -32,11 +32,18 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): def __new__(metacls, name, bases, namespace, shape=None, **kwargs): if shape is not None: shape = Shape.cast(shape) + # Prepare enumeration members for instantiation. This logic is unfortunately very + # convoluted because it supports two very different code paths that need to share + # the emitted warnings. for member_name, member_value in namespace.items(): if py_enum._is_sunder(member_name) or py_enum._is_dunder(member_name): continue + # If a shape is specified ("Amaranth mode" of amaranth.lib.enum.Enum), then every + # member value must be a constant-castable expression. Otherwise ("Python mode" of + # amaranth.lib.enum.Enum) any value goes, since all enumerations accepted by + # the built-in Enum class must be also accepted by amaranth.lib.enum.Enum. try: - member_shape = Const.cast(member_value).shape() + member_const = Const.cast(member_value) except TypeError as e: if shape is not None: raise TypeError("Value {!r} of enumeration member {!r} must be " @@ -44,7 +51,21 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): .format(member_value, member_name)) from e else: continue + if isinstance(member_value, Value): + # The member value is an Amaranth value that is also constant-castable. + # It cannot be used in an enumeration as-is (since it doesn't return a boolean + # from comparison operators, and this is required by py_enum). + # Replace the member value with the integer value of the constant, per RFC 4. + # Note that we do this even if no shape is provided (and this class is emulating + # a Python enumeration); this is OK because we only need to accept everything that + # the built-in class accepts to be a drop-in replacement, but the built-in class + # does not accept Amaranth values. + # We use dict.__setitem__ since namespace is a py_enum._EnumDict that overrides + # __setitem__ to check if the name has been already used. + dict.__setitem__(namespace, member_name, member_const.value) + # If a shape was specified, check whether the member value is compatible with it. if shape is not None: + member_shape = member_const.shape() if member_shape.signed and not shape.signed: warnings.warn( message="Value {!r} of enumeration member {!r} is signed, but " @@ -61,6 +82,7 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): .format(member_value, member_name, shape), category=SyntaxWarning, stacklevel=2) + # Actually instantiate the enumeration class. cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs) if shape is not None: # Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that diff --git a/docs/lang.rst b/docs/lang.rst index 74915ba..a9ffd2a 100644 --- a/docs/lang.rst +++ b/docs/lang.rst @@ -282,30 +282,25 @@ Constant-castable objects are accepted anywhere a constant integer is accepted. .. doctest:: - >>> Const.cast(Cat(Direction.TOP, Direction.LEFT)) - (const 4'd4) + >>> Const.cast(Cat(C(10, 4), C(1, 2))) + (const 6'd26) -.. TODO: uncomment when this actually works +They may be used in enumeration members, provided the enumeration inherits from :class:`amaranth.lib.enum.Enum`: -.. comment:: +.. testcode:: - They may be used in enumeration members: + class Funct(amaranth.lib.enum.Enum, shape=4): + ADD = 0 + ... - .. testcode:: - - class Funct(enum.Enum): - ADD = 0 - ... - - class Op(enum.Enum): - REG = 0 - IMM = 1 - - class Instr(enum.Enum): - ADD = Cat(Funct.ADD, Op.REG) - ADDI = Cat(Funct.ADD, Op.IMM) - ... + class Op(amaranth.lib.enum.Enum, shape=1): + REG = 0 + IMM = 1 + class Instr(amaranth.lib.enum.Enum, shape=5): + ADD = Cat(Funct.ADD, Op.REG) + ADDI = Cat(Funct.ADD, Op.IMM) + ... .. note:: diff --git a/docs/stdlib/enum.rst b/docs/stdlib/enum.rst index 39c8575..e1f9da0 100644 --- a/docs/stdlib/enum.rst +++ b/docs/stdlib/enum.rst @@ -15,16 +15,36 @@ A shape can be specified for an enumeration with the ``shape=`` keyword argument from amaranth.lib import enum - class Funct4(enum.Enum, shape=4): + class Funct(enum.Enum, shape=4): ADD = 0 SUB = 1 MUL = 2 .. doctest:: - >>> Shape.cast(Funct4) + >>> Shape.cast(Funct) unsigned(4) +Any :ref:`constant-castable ` expression can be used as the value of a member: + +.. testcode:: + + class Op(enum.Enum, shape=1): + REG = 0 + IMM = 1 + + class Instr(enum.Enum, shape=5): + ADD = Cat(Funct.ADD, Op.REG) + ADDI = Cat(Funct.ADD, Op.IMM) + SUB = Cat(Funct.SUB, Op.REG) + SUBI = Cat(Funct.SUB, Op.IMM) + ... + +.. doctest:: + + >>> Instr.SUBI + + This module is a drop-in replacement for the standard :mod:`enum` module, and re-exports all of its members (not just the ones described below). In an Amaranth project, all ``import enum`` statements may be replaced with ``from amaranth.lib import enum``. diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 49d0a2b..5f86945 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -10,12 +10,20 @@ class EnumTestCase(FHDLTestCase): class EnumA(Enum): A = "str" - def test_non_int_members_wrong(self): + def test_non_const_non_int_members_wrong(self): with self.assertRaisesRegex(TypeError, r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"): class EnumA(Enum, shape=unsigned(1)): A = "str" + def test_const_non_int_members(self): + class EnumA(Enum): + A = C(0) + B = C(1) + self.assertIs(EnumA.A.value, 0) + self.assertIs(EnumA.B.value, 1) + self.assertEqual(Shape.cast(EnumA), unsigned(1)) + def test_shape_no_members(self): class EnumA(Enum): pass