Implement RFC 31: Enumeration type safety.

This commit is contained in:
Wanda 2023-11-22 07:35:55 +01:00 committed by Catherine
parent b0b193f1ad
commit ef5cfa72bc
4 changed files with 456 additions and 22 deletions

View file

@ -1,11 +1,12 @@
import enum as py_enum
import warnings
import operator
from ..hdl.ast import Value, Shape, ShapeCastable, Const
from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const
from ..hdl._repr import *
__all__ = py_enum.__all__
__all__ = py_enum.__all__ + ["EnumView", "FlagView"]
for _member in py_enum.__all__:
@ -23,14 +24,18 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
:class:`enum.EnumMeta` class; if the ``shape=`` argument is not specified and
:meth:`as_shape` is never called, it places no restrictions on the enumeration class
or the values of its members.
When a :ref:`value-castable <lang-valuecasting>` is cast to an enum type that is an instance
of this metaclass, it can be automatically wrapped in a view class. A custom view class
can be specified by passing the ``view_class=`` keyword argument when creating the enum class.
"""
# TODO: remove this shim once py3.8 support is dropped
@classmethod
def __prepare__(metacls, name, bases, shape=None, **kwargs):
def __prepare__(metacls, name, bases, shape=None, view_class=None, **kwargs):
return super().__prepare__(name, bases, **kwargs)
def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
def __new__(metacls, name, bases, namespace, shape=None, view_class=None, **kwargs):
if shape is not None:
shape = Shape.cast(shape)
# Prepare enumeration members for instantiation. This logic is unfortunately very
@ -89,6 +94,8 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
# the values of every member can be cast to the provided shape without truncation.
cls._amaranth_shape_ = shape
if view_class is not None:
cls._amaranth_view_class_ = view_class
else:
# Shape is not provided explicitly. Behave the same as a standard enumeration;
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such
@ -127,17 +134,32 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
return Shape._cast_plain_enum(cls)
def __call__(cls, value, *args, **kwargs):
# :class:`py_enum.Enum` uses ``__call__()`` for type casting: ``E(x)`` returns
# the enumeration member whose value equals ``x``. In this case, ``x`` must be a concrete
# value.
# Amaranth extends this to indefinite values, but conceptually the operation is the same:
# :class:`View` calls :meth:`Enum.__call__` to go from a :class:`Value` to something
# representing this enumeration with that value.
# At the moment however, for historical reasons, this is just the value itself. This works
# and is backwards-compatible but is limiting in that it does not allow us to e.g. catch
# comparisons with enum members of the wrong type.
if isinstance(value, Value):
return value
"""Cast the value to this enum type.
When given an integer constant, it returns the corresponding enum value, like a standard
Python enumeration.
When given a :ref:`value-castable <lang-valuecasting>`, it is cast to a value, then wrapped
in the ``view_class`` specified for this enum type (:class:`EnumView` for :class:`Enum`,
:class:`FlagView` for :class:`Flag`, or a custom user-defined class). If the type has no
``view_class`` (like :class:`IntEnum` or :class:`IntFlag`), a plain
:class:`Value` is returned.
Returns
-------
instance of itself
For integer values, or instances of itself.
:class:`EnumView` or its subclass
For value-castables, as defined by the ``view_class`` keyword argument.
:class:`Value`
For value-castables, when a view class is not specified for this enum.
"""
if isinstance(value, (Value, ValueCastable)):
value = Value.cast(value)
if cls._amaranth_view_class_ is None:
return value
else:
return cls._amaranth_view_class_(cls, value)
return super().__call__(value, *args, **kwargs)
def const(cls, init):
@ -149,7 +171,7 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
member = cls(0)
else:
member = cls(init)
return Const(member.value, cls.as_shape())
return cls(Const(member.value, cls.as_shape()))
def _value_repr(cls, value):
yield Repr(FormatEnum(cls), value)
@ -157,7 +179,7 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
class Enum(py_enum.Enum):
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as
its metaclass."""
its metaclass and :class:`EnumView` as its view class."""
class IntEnum(py_enum.IntEnum):
@ -167,16 +189,197 @@ class IntEnum(py_enum.IntEnum):
class Flag(py_enum.Flag):
"""Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as
its metaclass."""
its metaclass and :class:`FlagView` as its view class."""
class IntFlag(py_enum.IntFlag):
"""Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as
its metaclass."""
# Fix up the metaclass after the fact: the metaclass __new__ requires these classes
# to already be present, and also would not install itself on them due to lack of shape.
Enum.__class__ = EnumMeta
IntEnum.__class__ = EnumMeta
Flag.__class__ = EnumMeta
IntFlag.__class__ = EnumMeta
class EnumView(ValueCastable):
"""The view class used for :class:`Enum`.
Wraps a :class:`Value` and only allows type-safe operations. The only operators allowed are
equality comparisons (``==`` and ``!=``) with another :class:`EnumView` of the same enum type.
"""
def __init__(self, enum, target):
"""Constructs a view with the given enum type and target
(a :ref:`value-castable <lang-valuecasting>`).
"""
if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"):
raise TypeError(f"EnumView type must be an enum with shape, not {enum!r}")
try:
cast_target = Value.cast(target)
except TypeError as e:
raise TypeError("EnumView target must be a value-castable object, not {!r}"
.format(target)) from e
if cast_target.shape() != enum.as_shape():
raise TypeError("EnumView target must have the same shape as the enum")
self.enum = enum
self.target = cast_target
def shape(self):
"""Returns the underlying enum type."""
return self.enum
@ValueCastable.lowermethod
def as_value(self):
"""Returns the underlying value."""
return self.target
def eq(self, other):
"""Assign to the underlying value.
Returns
-------
:class:`Assign`
``self.as_value().eq(other)``
"""
return self.as_value().eq(other)
def __add__(self, other):
raise TypeError("cannot perform arithmetic operations on non-IntEnum enum")
__radd__ = __add__
__sub__ = __add__
__rsub__ = __add__
__mul__ = __add__
__rmul__ = __add__
__floordiv__ = __add__
__rfloordiv__ = __add__
__mod__ = __add__
__rmod__ = __add__
__lshift__ = __add__
__rlshift__ = __add__
__rshift__ = __add__
__rrshift__ = __add__
__lt__ = __add__
__le__ = __add__
__gt__ = __add__
__ge__ = __add__
def __and__(self, other):
raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum")
__rand__ = __and__
__or__ = __and__
__ror__ = __and__
__xor__ = __and__
__rxor__ = __and__
def __eq__(self, other):
"""Compares the underlying value for equality.
The other operand has to be either another :class:`EnumView` with the same enum type, or
a plain value of the underlying enum.
Returns
-------
:class:`Value`
The result of the equality comparison, as a single-bit value.
"""
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, EnumView) or other.enum is not self.enum:
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
return self.target == other.target
def __ne__(self, other):
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, EnumView) or other.enum is not self.enum:
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
return self.target != other.target
def __repr__(self):
return f"{type(self).__name__}({self.enum.__name__}, {self.target!r})"
class FlagView(EnumView):
"""The view class used for :class:`Flag`.
In addition to the operations allowed by :class:`EnumView`, it allows bitwise operations among
values of the same enum type."""
def __invert__(self):
"""Inverts all flags in this value and returns another :ref:`FlagView`.
Note that this is not equivalent to applying bitwise negation to the underlying value:
just like the Python :class:`enum.Flag` class, only bits corresponding to flags actually
defined in the enumeration are included in the result.
Returns
-------
:class:`FlagView`
"""
if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP):
return self.enum._amaranth_view_class_(self.enum, ~self.target)
else:
singles_mask = 0
for flag in self.enum:
if (flag.value & (flag.value - 1)) == 0:
singles_mask |= flag.value
return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask)
def __bitop(self, other, op):
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, FlagView) or other.enum is not self.enum:
raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type")
return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target))
def __and__(self, other):
"""Performs a bitwise AND and returns another :class:`FlagView`.
The other operand has to be either another :class:`FlagView` of the same enum type, or
a plain value of the underlying enum type.
Returns
-------
:class:`FlagView`
"""
return self.__bitop(other, operator.__and__)
def __or__(self, other):
"""Performs a bitwise OR and returns another :class:`FlagView`.
The other operand has to be either another :class:`FlagView` of the same enum type, or
a plain value of the underlying enum type.
Returns
-------
:class:`FlagView`
"""
return self.__bitop(other, operator.__or__)
def __xor__(self, other):
"""Performs a bitwise XOR and returns another :class:`FlagView`.
The other operand has to be either another :class:`FlagView` of the same enum type, or
a plain value of the underlying enum type.
Returns
-------
:class:`FlagView`
"""
return self.__bitop(other, operator.__xor__)
__rand__ = __and__
__ror__ = __or__
__rxor__ = __xor__
Enum._amaranth_view_class_ = EnumView
IntEnum._amaranth_view_class_ = None
Flag._amaranth_view_class_ = FlagView
IntFlag._amaranth_view_class_ = None