diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index 6dc70e7..7800375 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -83,8 +83,8 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): category=SyntaxWarning, stacklevel=2) # Actually instantiate the enumeration class. - cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs) if shape is not None: + cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs) # Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that # the values of every member can be cast to the provided shape without truncation. cls._amaranth_shape_ = shape @@ -92,7 +92,14 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): # Shape is not provided explicitly. Behave the same as a standard enumeration; # the lack of `_amaranth_shape_` attribute is used to emit a warning when such # an enumeration is used in a concatenation. - pass + bases = tuple( + py_enum.Enum if base is Enum else + py_enum.IntEnum if base is IntEnum else + py_enum.Flag if base is Flag else + py_enum.IntFlag if base is IntFlag else base + for base in bases + ) + cls = py_enum.EnumMeta.__new__(py_enum.EnumMeta, name, bases, namespace, **kwargs) return cls def as_shape(cls): @@ -144,21 +151,28 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): return Const(member.value, cls.as_shape()) -class Enum(py_enum.Enum, metaclass=EnumMeta): +class Enum(py_enum.Enum): """Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as its metaclass.""" -class IntEnum(py_enum.IntEnum, metaclass=EnumMeta): +class IntEnum(py_enum.IntEnum): """Subclass of the standard :class:`enum.IntEnum` that has :class:`EnumMeta` as its metaclass.""" -class Flag(py_enum.Flag, metaclass=EnumMeta): +class Flag(py_enum.Flag): """Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as its metaclass.""" -class IntFlag(py_enum.IntFlag, metaclass=EnumMeta): +class IntFlag(py_enum.IntFlag): """Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as its metaclass.""" + +# Fix up the metaclass after the fact: the metaclass __new__ requires these classes +# to already be present, and also would not install itself on them due to lack of shape. +Enum.__class__ = EnumMeta +IntEnum.__class__ = EnumMeta +Flag.__class__ = EnumMeta +IntFlag.__class__ = EnumMeta diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index c1ad320..2252a38 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -1077,7 +1077,7 @@ class SignalTestCase(FHDLTestCase): Signal(CastableFromHex(), reset="01") def test_reset_shape_castable_enum_wrong(self): - class EnumA(AmaranthEnum): + class EnumA(AmaranthEnum, shape=1): X = 1 with self.assertRaisesRegex(TypeError, r"^Reset value must be a constant initializer of $"): diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 5c86d60..99425ff 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -1,7 +1,7 @@ import enum as py_enum from amaranth import * -from amaranth.lib.enum import Enum +from amaranth.lib.enum import Enum, EnumMeta from .utils import * @@ -91,14 +91,13 @@ class EnumTestCase(FHDLTestCase): A = 1 self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)") - def test_const_no_shape(self): + def test_no_shape(self): class EnumA(Enum): Z = 0 A = 10 B = 20 - self.assertRepr(EnumA.const(None), "(const 5'd0)") - self.assertRepr(EnumA.const(10), "(const 5'd10)") - self.assertRepr(EnumA.const(EnumA.A), "(const 5'd10)") + self.assertNotIsInstance(EnumA, EnumMeta) + self.assertIsInstance(EnumA, py_enum.EnumMeta) def test_const_shape(self): class EnumA(Enum, shape=8):