diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index 9748b28..c1e6f03 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -30,34 +30,42 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): return super().__prepare__(name, bases, **kwargs) def __new__(metacls, name, bases, namespace, shape=None, **kwargs): - cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs) if shape is not None: - # Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that - # the values of every member can be cast to the provided shape without truncation. - cls._amaranth_shape_ = shape = Shape.cast(shape) - for member in cls: - try: - member_shape = Const.cast(member.value).shape() - except TypeError as e: - raise TypeError("Value of enumeration member {!r} must be " + shape = Shape.cast(shape) + for member_name, member_value in namespace.items(): + if py_enum._is_sunder(member_name) or py_enum._is_dunder(member_name): + continue + try: + member_shape = Const.cast(member_value).shape() + except TypeError as e: + if shape is not None: + raise TypeError("Value {!r} of enumeration member {!r} must be " "a constant-castable expression" - .format(member)) from e + .format(member_value, member_name)) from e + else: + continue + if shape is not None: if member_shape.signed and not shape.signed: warnings.warn( - message="Value of enumeration member {!r} is signed, but the enumeration " - "shape is {!r}" # the repr will be `unsigned(X)` - .format(member, shape), + message="Value {!r} of enumeration member {!r} is signed, but " + "the enumeration shape is {!r}" # the repr will be `unsigned(X)` + .format(member_value, member_name, shape), category=SyntaxWarning, stacklevel=2) elif (member_shape.width > shape.width or member_shape.width == shape.width and shape.signed and not member_shape.signed): warnings.warn( - message="Value of enumeration member {!r} will be truncated to " + message="Value {!r} of enumeration member {!r} will be truncated to " "the enumeration shape {!r}" - .format(member, shape), + .format(member_value, member_name, shape), category=SyntaxWarning, stacklevel=2) + cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs) + if shape is not None: + # Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that + # the values of every member can be cast to the provided shape without truncation. + cls._amaranth_shape_ = shape else: # Shape is not provided explicitly. Behave the same as a standard enumeration; # the lack of `_amaranth_shape_` attribute is used to emit a warning when such diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 60bdc5f..49d0a2b 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -12,8 +12,7 @@ class EnumTestCase(FHDLTestCase): def test_non_int_members_wrong(self): with self.assertRaisesRegex(TypeError, - r"^Value of enumeration member must be " - r"a constant-castable expression$"): + r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"): class EnumA(Enum, shape=unsigned(1)): A = "str" @@ -54,24 +53,24 @@ class EnumTestCase(FHDLTestCase): def test_shape_explicit_wrong_signed_mismatch(self): with self.assertWarnsRegex(SyntaxWarning, - r"^Value of enumeration member is signed, but the enumeration " + r"^Value -1 of enumeration member 'A' is signed, but the enumeration " r"shape is unsigned\(1\)$"): class EnumA(Enum, shape=unsigned(1)): A = -1 def test_shape_explicit_wrong_too_wide(self): with self.assertWarnsRegex(SyntaxWarning, - r"^Value of enumeration member will be truncated to the enumeration " + r"^Value 2 of enumeration member 'A' will be truncated to the enumeration " r"shape unsigned\(1\)$"): class EnumA(Enum, shape=unsigned(1)): A = 2 with self.assertWarnsRegex(SyntaxWarning, - r"^Value of enumeration member will be truncated to the enumeration " + r"^Value 1 of enumeration member 'A' will be truncated to the enumeration " r"shape signed\(1\)$"): class EnumB(Enum, shape=signed(1)): A = 1 with self.assertWarnsRegex(SyntaxWarning, - r"^Value of enumeration member will be truncated to the " + r"^Value -2 of enumeration member 'A' will be truncated to the " r"enumeration shape signed\(1\)$"): class EnumC(Enum, shape=signed(1)): A = -2