lib.enum: check member value shapes before subclassing. NFCI

This commit is a preparation for accepting const-castable expressions
as enum member values.

See #755.
This commit is contained in:
Catherine 2023-05-12 14:20:45 +00:00
parent 5f6b36e91f
commit bf8bbb0f63
2 changed files with 28 additions and 21 deletions

View file

@ -30,34 +30,42 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
return super().__prepare__(name, bases, **kwargs) return super().__prepare__(name, bases, **kwargs)
def __new__(metacls, name, bases, namespace, shape=None, **kwargs): def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs)
if shape is not None: if shape is not None:
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that shape = Shape.cast(shape)
# the values of every member can be cast to the provided shape without truncation. for member_name, member_value in namespace.items():
cls._amaranth_shape_ = shape = Shape.cast(shape) if py_enum._is_sunder(member_name) or py_enum._is_dunder(member_name):
for member in cls: continue
try: try:
member_shape = Const.cast(member.value).shape() member_shape = Const.cast(member_value).shape()
except TypeError as e: except TypeError as e:
raise TypeError("Value of enumeration member {!r} must be " if shape is not None:
raise TypeError("Value {!r} of enumeration member {!r} must be "
"a constant-castable expression" "a constant-castable expression"
.format(member)) from e .format(member_value, member_name)) from e
else:
continue
if shape is not None:
if member_shape.signed and not shape.signed: if member_shape.signed and not shape.signed:
warnings.warn( warnings.warn(
message="Value of enumeration member {!r} is signed, but the enumeration " message="Value {!r} of enumeration member {!r} is signed, but "
"shape is {!r}" # the repr will be `unsigned(X)` "the enumeration shape is {!r}" # the repr will be `unsigned(X)`
.format(member, shape), .format(member_value, member_name, shape),
category=SyntaxWarning, category=SyntaxWarning,
stacklevel=2) stacklevel=2)
elif (member_shape.width > shape.width or elif (member_shape.width > shape.width or
member_shape.width == shape.width and member_shape.width == shape.width and
shape.signed and not member_shape.signed): shape.signed and not member_shape.signed):
warnings.warn( warnings.warn(
message="Value of enumeration member {!r} will be truncated to " message="Value {!r} of enumeration member {!r} will be truncated to "
"the enumeration shape {!r}" "the enumeration shape {!r}"
.format(member, shape), .format(member_value, member_name, shape),
category=SyntaxWarning, category=SyntaxWarning,
stacklevel=2) stacklevel=2)
cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs)
if shape is not None:
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
# the values of every member can be cast to the provided shape without truncation.
cls._amaranth_shape_ = shape
else: else:
# Shape is not provided explicitly. Behave the same as a standard enumeration; # Shape is not provided explicitly. Behave the same as a standard enumeration;
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such # the lack of `_amaranth_shape_` attribute is used to emit a warning when such

View file

@ -12,8 +12,7 @@ class EnumTestCase(FHDLTestCase):
def test_non_int_members_wrong(self): def test_non_int_members_wrong(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Value of enumeration member <EnumA\.A: 'str'> must be " r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"):
r"a constant-castable expression$"):
class EnumA(Enum, shape=unsigned(1)): class EnumA(Enum, shape=unsigned(1)):
A = "str" A = "str"
@ -54,24 +53,24 @@ class EnumTestCase(FHDLTestCase):
def test_shape_explicit_wrong_signed_mismatch(self): def test_shape_explicit_wrong_signed_mismatch(self):
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Value of enumeration member <EnumA\.A: -1> is signed, but the enumeration " r"^Value -1 of enumeration member 'A' is signed, but the enumeration "
r"shape is unsigned\(1\)$"): r"shape is unsigned\(1\)$"):
class EnumA(Enum, shape=unsigned(1)): class EnumA(Enum, shape=unsigned(1)):
A = -1 A = -1
def test_shape_explicit_wrong_too_wide(self): def test_shape_explicit_wrong_too_wide(self):
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Value of enumeration member <EnumA\.A: 2> will be truncated to the enumeration " r"^Value 2 of enumeration member 'A' will be truncated to the enumeration "
r"shape unsigned\(1\)$"): r"shape unsigned\(1\)$"):
class EnumA(Enum, shape=unsigned(1)): class EnumA(Enum, shape=unsigned(1)):
A = 2 A = 2
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Value of enumeration member <EnumB\.A: 1> will be truncated to the enumeration " r"^Value 1 of enumeration member 'A' will be truncated to the enumeration "
r"shape signed\(1\)$"): r"shape signed\(1\)$"):
class EnumB(Enum, shape=signed(1)): class EnumB(Enum, shape=signed(1)):
A = 1 A = 1
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Value of enumeration member <EnumC\.A: -2> will be truncated to the " r"^Value -2 of enumeration member 'A' will be truncated to the "
r"enumeration shape signed\(1\)$"): r"enumeration shape signed\(1\)$"):
class EnumC(Enum, shape=signed(1)): class EnumC(Enum, shape=signed(1)):
A = -2 A = -2