From 7e438180e08f72fd5d960d279bc4f0051bdb5e42 Mon Sep 17 00:00:00 2001 From: Charlotte Date: Sun, 2 Jul 2023 21:39:10 +1000 Subject: [PATCH] lib.enum: allow empty enums. --- amaranth/lib/enum.py | 13 ++++--------- tests/test_lib_enum.py | 9 ++++++--- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index 61e5f28..10c3666 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -112,16 +112,11 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): if hasattr(cls, "_amaranth_shape_"): # Shape was provided explicitly; return it. return cls._amaranth_shape_ - elif cls.__members__: - # Shape was not provided explicitly, but enumeration has members; treat it - # the same way `Shape.cast` treats standard library enumerations, so that - # `amaranth.lib.enum.Enum` can be a drop-in replacement for `enum.Enum`. - return Shape._cast_plain_enum(cls) else: - # Shape was not provided explicitly, and enumeration has no members. - # This is a base or mixin class that cannot be instantiated directly. - raise TypeError("Enumeration '{}.{}' does not have a defined shape" - .format(cls.__module__, cls.__qualname__)) + # Shape was not provided explicitly; treat it the same way `Shape.cast` treats + # standard library enumerations, so that `amaranth.lib.enum.Enum` can be a drop-in + # replacement for `enum.Enum`. + return Shape._cast_plain_enum(cls) def __call__(cls, value): # :class:`py_enum.Enum` uses ``__call__()`` for type casting: ``E(x)`` returns diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 2bf2d19..9d5a27e 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -1,3 +1,5 @@ +import enum as py_enum + from amaranth import * from amaranth.lib.enum import Enum @@ -21,9 +23,10 @@ class EnumTestCase(FHDLTestCase): def test_shape_no_members(self): class EnumA(Enum): pass - with self.assertRaisesRegex(TypeError, - r"^Enumeration '.+?\.EnumA' does not have a defined shape$"): - Shape.cast(EnumA) + class PyEnumA(py_enum.Enum): + pass + self.assertEqual(Shape.cast(EnumA), unsigned(0)) + self.assertEqual(Shape.cast(PyEnumA), unsigned(0)) def test_shape_explicit(self): class EnumA(Enum, shape=signed(2)):