diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index 02e9dfe..d8b7416 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -1,11 +1,12 @@ import enum as py_enum import warnings +import operator -from ..hdl.ast import Value, Shape, ShapeCastable, Const +from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const from ..hdl._repr import * -__all__ = py_enum.__all__ +__all__ = py_enum.__all__ + ["EnumView", "FlagView"] for _member in py_enum.__all__: @@ -23,14 +24,18 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): :class:`enum.EnumMeta` class; if the ``shape=`` argument is not specified and :meth:`as_shape` is never called, it places no restrictions on the enumeration class or the values of its members. + + When a :ref:`value-castable ` is cast to an enum type that is an instance + of this metaclass, it can be automatically wrapped in a view class. A custom view class + can be specified by passing the ``view_class=`` keyword argument when creating the enum class. """ # TODO: remove this shim once py3.8 support is dropped @classmethod - def __prepare__(metacls, name, bases, shape=None, **kwargs): + def __prepare__(metacls, name, bases, shape=None, view_class=None, **kwargs): return super().__prepare__(name, bases, **kwargs) - def __new__(metacls, name, bases, namespace, shape=None, **kwargs): + def __new__(metacls, name, bases, namespace, shape=None, view_class=None, **kwargs): if shape is not None: shape = Shape.cast(shape) # Prepare enumeration members for instantiation. This logic is unfortunately very @@ -89,6 +94,8 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): # Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that # the values of every member can be cast to the provided shape without truncation. cls._amaranth_shape_ = shape + if view_class is not None: + cls._amaranth_view_class_ = view_class else: # Shape is not provided explicitly. Behave the same as a standard enumeration; # the lack of `_amaranth_shape_` attribute is used to emit a warning when such @@ -127,17 +134,32 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): return Shape._cast_plain_enum(cls) def __call__(cls, value, *args, **kwargs): - # :class:`py_enum.Enum` uses ``__call__()`` for type casting: ``E(x)`` returns - # the enumeration member whose value equals ``x``. In this case, ``x`` must be a concrete - # value. - # Amaranth extends this to indefinite values, but conceptually the operation is the same: - # :class:`View` calls :meth:`Enum.__call__` to go from a :class:`Value` to something - # representing this enumeration with that value. - # At the moment however, for historical reasons, this is just the value itself. This works - # and is backwards-compatible but is limiting in that it does not allow us to e.g. catch - # comparisons with enum members of the wrong type. - if isinstance(value, Value): - return value + """Cast the value to this enum type. + + When given an integer constant, it returns the corresponding enum value, like a standard + Python enumeration. + + When given a :ref:`value-castable `, it is cast to a value, then wrapped + in the ``view_class`` specified for this enum type (:class:`EnumView` for :class:`Enum`, + :class:`FlagView` for :class:`Flag`, or a custom user-defined class). If the type has no + ``view_class`` (like :class:`IntEnum` or :class:`IntFlag`), a plain + :class:`Value` is returned. + + Returns + ------- + instance of itself + For integer values, or instances of itself. + :class:`EnumView` or its subclass + For value-castables, as defined by the ``view_class`` keyword argument. + :class:`Value` + For value-castables, when a view class is not specified for this enum. + """ + if isinstance(value, (Value, ValueCastable)): + value = Value.cast(value) + if cls._amaranth_view_class_ is None: + return value + else: + return cls._amaranth_view_class_(cls, value) return super().__call__(value, *args, **kwargs) def const(cls, init): @@ -149,7 +171,7 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): member = cls(0) else: member = cls(init) - return Const(member.value, cls.as_shape()) + return cls(Const(member.value, cls.as_shape())) def _value_repr(cls, value): yield Repr(FormatEnum(cls), value) @@ -157,7 +179,7 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): class Enum(py_enum.Enum): """Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as - its metaclass.""" + its metaclass and :class:`EnumView` as its view class.""" class IntEnum(py_enum.IntEnum): @@ -167,16 +189,197 @@ class IntEnum(py_enum.IntEnum): class Flag(py_enum.Flag): """Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as - its metaclass.""" + its metaclass and :class:`FlagView` as its view class.""" class IntFlag(py_enum.IntFlag): """Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as its metaclass.""" + # Fix up the metaclass after the fact: the metaclass __new__ requires these classes # to already be present, and also would not install itself on them due to lack of shape. Enum.__class__ = EnumMeta IntEnum.__class__ = EnumMeta Flag.__class__ = EnumMeta IntFlag.__class__ = EnumMeta + + +class EnumView(ValueCastable): + """The view class used for :class:`Enum`. + + Wraps a :class:`Value` and only allows type-safe operations. The only operators allowed are + equality comparisons (``==`` and ``!=``) with another :class:`EnumView` of the same enum type. + """ + + def __init__(self, enum, target): + """Constructs a view with the given enum type and target + (a :ref:`value-castable `). + """ + if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"): + raise TypeError(f"EnumView type must be an enum with shape, not {enum!r}") + try: + cast_target = Value.cast(target) + except TypeError as e: + raise TypeError("EnumView target must be a value-castable object, not {!r}" + .format(target)) from e + if cast_target.shape() != enum.as_shape(): + raise TypeError("EnumView target must have the same shape as the enum") + self.enum = enum + self.target = cast_target + + def shape(self): + """Returns the underlying enum type.""" + return self.enum + + @ValueCastable.lowermethod + def as_value(self): + """Returns the underlying value.""" + return self.target + + def eq(self, other): + """Assign to the underlying value. + + Returns + ------- + :class:`Assign` + ``self.as_value().eq(other)`` + """ + return self.as_value().eq(other) + + def __add__(self, other): + raise TypeError("cannot perform arithmetic operations on non-IntEnum enum") + + __radd__ = __add__ + __sub__ = __add__ + __rsub__ = __add__ + __mul__ = __add__ + __rmul__ = __add__ + __floordiv__ = __add__ + __rfloordiv__ = __add__ + __mod__ = __add__ + __rmod__ = __add__ + __lshift__ = __add__ + __rlshift__ = __add__ + __rshift__ = __add__ + __rrshift__ = __add__ + __lt__ = __add__ + __le__ = __add__ + __gt__ = __add__ + __ge__ = __add__ + + def __and__(self, other): + raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum") + + __rand__ = __and__ + __or__ = __and__ + __ror__ = __and__ + __xor__ = __and__ + __rxor__ = __and__ + + def __eq__(self, other): + """Compares the underlying value for equality. + + The other operand has to be either another :class:`EnumView` with the same enum type, or + a plain value of the underlying enum. + + Returns + ------- + :class:`Value` + The result of the equality comparison, as a single-bit value. + """ + if isinstance(other, self.enum): + other = self.enum(Value.cast(other)) + if not isinstance(other, EnumView) or other.enum is not self.enum: + raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type") + return self.target == other.target + + def __ne__(self, other): + if isinstance(other, self.enum): + other = self.enum(Value.cast(other)) + if not isinstance(other, EnumView) or other.enum is not self.enum: + raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type") + return self.target != other.target + + def __repr__(self): + return f"{type(self).__name__}({self.enum.__name__}, {self.target!r})" + + +class FlagView(EnumView): + """The view class used for :class:`Flag`. + + In addition to the operations allowed by :class:`EnumView`, it allows bitwise operations among + values of the same enum type.""" + + def __invert__(self): + """Inverts all flags in this value and returns another :ref:`FlagView`. + + Note that this is not equivalent to applying bitwise negation to the underlying value: + just like the Python :class:`enum.Flag` class, only bits corresponding to flags actually + defined in the enumeration are included in the result. + + Returns + ------- + :class:`FlagView` + """ + if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP): + return self.enum._amaranth_view_class_(self.enum, ~self.target) + else: + singles_mask = 0 + for flag in self.enum: + if (flag.value & (flag.value - 1)) == 0: + singles_mask |= flag.value + return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask) + + def __bitop(self, other, op): + if isinstance(other, self.enum): + other = self.enum(Value.cast(other)) + if not isinstance(other, FlagView) or other.enum is not self.enum: + raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type") + return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target)) + + def __and__(self, other): + """Performs a bitwise AND and returns another :class:`FlagView`. + + The other operand has to be either another :class:`FlagView` of the same enum type, or + a plain value of the underlying enum type. + + Returns + ------- + :class:`FlagView` + """ + return self.__bitop(other, operator.__and__) + + def __or__(self, other): + """Performs a bitwise OR and returns another :class:`FlagView`. + + The other operand has to be either another :class:`FlagView` of the same enum type, or + a plain value of the underlying enum type. + + Returns + ------- + :class:`FlagView` + """ + return self.__bitop(other, operator.__or__) + + def __xor__(self, other): + """Performs a bitwise XOR and returns another :class:`FlagView`. + + The other operand has to be either another :class:`FlagView` of the same enum type, or + a plain value of the underlying enum type. + + Returns + ------- + :class:`FlagView` + """ + return self.__bitop(other, operator.__xor__) + + __rand__ = __and__ + __ror__ = __or__ + __rxor__ = __xor__ + + +Enum._amaranth_view_class_ = EnumView +IntEnum._amaranth_view_class_ = None +Flag._amaranth_view_class_ = FlagView +IntFlag._amaranth_view_class_ = None diff --git a/docs/changes.rst b/docs/changes.rst index 019fbf0..7e1c698 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -60,6 +60,7 @@ Implemented RFCs .. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html .. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html .. _RFC 28: https://amaranth-lang.org/rfcs/0028-override-value-operators.html +.. _RFC 31: https://amaranth-lang.org/rfcs/0031-enumeration-type-safety.html * `RFC 1`_: Aggregate data structure library @@ -77,6 +78,7 @@ Implemented RFCs * `RFC 20`_: Deprecate non-FWFT FIFOs * `RFC 22`_: Define ``ValueCastable.shape()`` * `RFC 28`_: Allow overriding ``Value`` operators +* `RFC 31`_: Enumeration type safety Language changes diff --git a/docs/stdlib/enum.rst b/docs/stdlib/enum.rst index 6f6897b..0fa87d9 100644 --- a/docs/stdlib/enum.rst +++ b/docs/stdlib/enum.rst @@ -24,6 +24,8 @@ A shape can be specified for an enumeration with the ``shape=`` keyword argument >>> Shape.cast(Funct) unsigned(4) + >>> Value.cast(Funct.ADD) + (const 4'd0) Any :ref:`constant-castable ` expression can be used as the value of a member: @@ -57,6 +59,57 @@ The ``shape=`` argument is optional. If not specified, classes from this module In this way, this module is a drop-in replacement for the standard :mod:`enum` module, and in an Amaranth project, all ``import enum`` statements may be replaced with ``from amaranth.lib import enum``. +Signals with :class:`Enum` or :class:`Flag` based shape are automatically wrapped in the :class:`EnumView` or :class:`FlagView` value-castable wrappers, which ensure type safety. Any :ref:`value-castable ` can also be explicitly wrapped in a view class by casting it to the enum type: + +.. doctest:: + + >>> a = Signal(Funct) + >>> b = Signal(Op) + >>> type(a) + + >>> a == b + Traceback (most recent call last): + File "", line 1, in + TypeError: an EnumView can only be compared to value or other EnumView of the same enum type + >>> c = Signal(4) + >>> type(Funct(c)) + + +Like the standard Python :class:`enum.IntEnum` and :class:`enum.IntFlag` classes, the Amaranth :class:`IntEnum` and :class:`IntFlag` classes are loosely typed and will not be subject to wrapping in view classes: + +.. testcode:: + + class TransparentEnum(enum.IntEnum, shape=unsigned(4)): + FOO = 0 + BAR = 1 + +.. doctest:: + + >>> a = Signal(TransparentEnum) + >>> type(a) + + +It is also possible to define a custom view class for a given enum: + +.. testcode:: + + class InstrView(enum.EnumView): + def has_immediate(self): + return (self == Instr.ADDI) | (self == Instr.SUBI) + + class Instr(enum.Enum, shape=5, view_class=InstrView): + ADD = Cat(Funct.ADD, Op.REG) + ADDI = Cat(Funct.ADD, Op.IMM) + SUB = Cat(Funct.SUB, Op.REG) + SUBI = Cat(Funct.SUB, Op.IMM) + +.. doctest:: + + >>> a = Signal(Instr) + >>> type(a) + + >>> a.has_immediate() + (| (== (sig a) (const 5'd16)) (== (sig a) (const 5'd17))) Metaclass ========= @@ -71,3 +124,9 @@ Base classes .. autoclass:: IntEnum() .. autoclass:: Flag() .. autoclass:: IntFlag() + +View classes +============ + +.. autoclass:: EnumView() +.. autoclass:: FlagView() \ No newline at end of file diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 99425ff..83dd038 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -1,7 +1,9 @@ import enum as py_enum +import operator +import sys from amaranth import * -from amaranth.lib.enum import Enum, EnumMeta +from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView from .utils import * @@ -103,9 +105,9 @@ class EnumTestCase(FHDLTestCase): class EnumA(Enum, shape=8): Z = 0 A = 10 - self.assertRepr(EnumA.const(None), "(const 8'd0)") - self.assertRepr(EnumA.const(10), "(const 8'd10)") - self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)") + self.assertRepr(EnumA.const(None), "EnumView(EnumA, (const 8'd0))") + self.assertRepr(EnumA.const(10), "EnumView(EnumA, (const 8'd10))") + self.assertRepr(EnumA.const(EnumA.A), "EnumView(EnumA, (const 8'd10))") def test_shape_implicit_wrong_in_concat(self): class EnumA(Enum): @@ -118,3 +120,171 @@ class EnumTestCase(FHDLTestCase): def test_functional(self): Enum("FOO", ["BAR", "BAZ"]) + + def test_int_enum(self): + class EnumA(IntEnum, shape=signed(4)): + A = 0 + B = -3 + a = Signal(EnumA) + self.assertRepr(a, "(sig a)") + + def test_enum_view(self): + class EnumA(Enum, shape=signed(4)): + A = 0 + B = -3 + class EnumB(Enum, shape=signed(4)): + C = 0 + D = 5 + a = Signal(EnumA) + b = Signal(EnumB) + c = Signal(EnumA) + d = Signal(4) + self.assertIsInstance(a, EnumView) + self.assertIs(a.shape(), EnumA) + self.assertRepr(a, "EnumView(EnumA, (sig a))") + self.assertRepr(a.as_value(), "(sig a)") + self.assertRepr(a.eq(c), "(eq (sig a) (sig c))") + for op in [ + operator.__add__, + operator.__sub__, + operator.__mul__, + operator.__floordiv__, + operator.__mod__, + operator.__lshift__, + operator.__rshift__, + operator.__and__, + operator.__or__, + operator.__xor__, + operator.__lt__, + operator.__le__, + operator.__gt__, + operator.__ge__, + ]: + with self.assertRaises(TypeError): + op(a, a) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, EnumA.A) + for op in [ + operator.__eq__, + operator.__ne__, + ]: + with self.assertRaises(TypeError): + op(a, b) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, EnumB.C) + self.assertRepr(a == c, "(== (sig a) (sig c))") + self.assertRepr(a != c, "(!= (sig a) (sig c))") + self.assertRepr(a == EnumA.B, "(== (sig a) (const 4'sd-3))") + self.assertRepr(EnumA.B == a, "(== (sig a) (const 4'sd-3))") + self.assertRepr(a != EnumA.B, "(!= (sig a) (const 4'sd-3))") + + def test_flag_view(self): + class FlagA(Flag, shape=unsigned(4)): + A = 1 + B = 4 + class FlagB(Flag, shape=unsigned(4)): + C = 1 + D = 2 + a = Signal(FlagA) + b = Signal(FlagB) + c = Signal(FlagA) + d = Signal(4) + self.assertIsInstance(a, FlagView) + self.assertRepr(a, "FlagView(FlagA, (sig a))") + for op in [ + operator.__add__, + operator.__sub__, + operator.__mul__, + operator.__floordiv__, + operator.__mod__, + operator.__lshift__, + operator.__rshift__, + operator.__lt__, + operator.__le__, + operator.__gt__, + operator.__ge__, + ]: + with self.assertRaises(TypeError): + op(a, a) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, FlagA.A) + for op in [ + operator.__eq__, + operator.__ne__, + operator.__and__, + operator.__or__, + operator.__xor__, + ]: + with self.assertRaises(TypeError): + op(a, b) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, FlagB.C) + self.assertRepr(a == c, "(== (sig a) (sig c))") + self.assertRepr(a != c, "(!= (sig a) (sig c))") + self.assertRepr(a == FlagA.B, "(== (sig a) (const 4'd4))") + self.assertRepr(FlagA.B == a, "(== (sig a) (const 4'd4))") + self.assertRepr(a != FlagA.B, "(!= (sig a) (const 4'd4))") + self.assertRepr(a | c, "FlagView(FlagA, (| (sig a) (sig c)))") + self.assertRepr(a & c, "FlagView(FlagA, (& (sig a) (sig c)))") + self.assertRepr(a ^ c, "FlagView(FlagA, (^ (sig a) (sig c)))") + self.assertRepr(~a, "FlagView(FlagA, (& (~ (sig a)) (const 3'd5)))") + self.assertRepr(a | FlagA.B, "FlagView(FlagA, (| (sig a) (const 4'd4)))") + if sys.version_info >= (3, 11): + class FlagC(Flag, shape=unsigned(4), boundary=py_enum.KEEP): + A = 1 + B = 4 + e = Signal(FlagC) + self.assertRepr(~e, "FlagView(FlagC, (~ (sig e)))") + + def test_enum_view_wrong(self): + class EnumA(Enum, shape=signed(4)): + A = 0 + B = -3 + + a = Signal(2) + with self.assertRaisesRegex(TypeError, + r'^EnumView target must have the same shape as the enum$'): + EnumA(a) + with self.assertRaisesRegex(TypeError, + r'^EnumView target must be a value-castable object, not .*$'): + EnumView(EnumA, "a") + + class EnumB(Enum): + C = 0 + D = 1 + with self.assertRaisesRegex(TypeError, + r'^EnumView type must be an enum with shape, not .*$'): + EnumView(EnumB, 3) + + def test_enum_view_custom(self): + class CustomView(EnumView): + pass + class EnumA(Enum, view_class=CustomView, shape=unsigned(2)): + A = 0 + B = 1 + a = Signal(EnumA) + assert isinstance(a, CustomView)