diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index 4dcf4c0..9dc7cff 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -168,12 +168,12 @@ class Value(metaclass=ABCMeta): while True: if isinstance(obj, Value): return obj - elif isinstance(obj, int): - return Const(obj) - elif isinstance(obj, Enum): - return Const(obj.value, Shape.cast(type(obj))) elif isinstance(obj, ValueCastable): new_obj = obj.as_value() + elif isinstance(obj, Enum): + return Const(obj.value, Shape.cast(type(obj))) + elif isinstance(obj, int): + return Const(obj) else: raise TypeError("Object {!r} cannot be converted to an Amaranth value".format(obj)) if new_obj is obj: diff --git a/docs/lang.rst b/docs/lang.rst index 0c2a47e..06c0c46 100644 --- a/docs/lang.rst +++ b/docs/lang.rst @@ -235,6 +235,9 @@ Casting a value from an integer ``i`` is a shorthand for ``Const(i)``: >>> Value.cast(5) (const 3'd5) +.. note:: + + If a value subclasses :class:`enum.IntEnum` or its class otherwise inherits from both :class:`int` and :class:`Enum`, it is treated as an enumeration. Values from enumeration members ------------------------------- @@ -247,6 +250,10 @@ Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.va (const 2'd1) +.. note:: + + If a value subclasses :class:`enum.IntEnum` or its class otherwise inherits from both :class:`int` and :class:`Enum`, it is treated as an enumeration. + .. _lang-signals: Signals diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 55e5a88..71453de 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -23,6 +23,12 @@ class StringEnum(Enum): BAR = "b" +class TypedEnum(int, Enum): + FOO = 1 + BAR = 2 + BAZ = 3 + + class ShapeTestCase(FHDLTestCase): def test_make(self): s1 = Shape() @@ -199,6 +205,11 @@ class ValueTestCase(FHDLTestCase): self.assertIsInstance(e2, Const) self.assertEqual(e2.shape(), signed(2)) + def test_cast_typedenum(self): + e1 = Value.cast(TypedEnum.FOO) + self.assertIsInstance(e1, Const) + self.assertEqual(e1.shape(), unsigned(2)) + def test_cast_enum_wrong(self): with self.assertRaisesRegex(TypeError, r"^Only enumerations with integer values can be used as value shapes$"): @@ -781,6 +792,15 @@ class CatTestCase(FHDLTestCase): warnings.filterwarnings(action="error", category=SyntaxWarning) Cat(0, 1, 1, 0) + def test_enum(self): + class Color(Enum): + RED = 1 + BLUE = 2 + with warnings.catch_warnings(): + warnings.filterwarnings(action="error", category=SyntaxWarning) + c = Cat(Color.RED, Color.BLUE) + self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))") + def test_int_wrong(self): with self.assertWarnsRegex(SyntaxWarning, r"^Argument #1 of Cat\(\) is a bare integer 2 used in bit vector context; "