diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index a97b90f..f5a249c 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -93,13 +93,21 @@ class Shape: bits_for(obj.stop - obj.step, signed)) return Shape(width, signed) elif isinstance(obj, type) and issubclass(obj, Enum): - min_value = min(member.value for member in obj) - max_value = max(member.value for member in obj) - if not isinstance(min_value, int) or not isinstance(max_value, int): - raise TypeError("Only enumerations with integer values can be used " - "as value shapes") - signed = min_value < 0 or max_value < 0 - width = max(bits_for(min_value, signed), bits_for(max_value, signed)) + signed = False + width = 0 + for member in obj: + try: + member_shape = Const.cast(member.value).shape() + except TypeError as e: + raise TypeError("Only enumerations whose members have constant-castable " + "values can be used in Amaranth code") + if not signed and member_shape.signed: + signed = True + width = max(width + 1, member_shape.width) + elif signed and not member_shape.signed: + width = max(width, member_shape.width + 1) + else: + width = max(width, member_shape.width) return Shape(width, signed) elif isinstance(obj, ShapeCastable): new_obj = obj.as_shape() @@ -402,11 +410,8 @@ class Value(metaclass=ABCMeta): ``1`` if any pattern matches the value, ``0`` otherwise. """ matches = [] + # This code should accept exactly the same patterns as `with m.Case(...):`. for pattern in patterns: - if not isinstance(pattern, (int, str, Enum)): - raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, " - "not {!r}" - .format(pattern)) if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) " "bits, and may include whitespace" @@ -416,23 +421,26 @@ class Value(metaclass=ABCMeta): raise SyntaxError("Match pattern '{}' must have the same width as match value " "(which is {})" .format(pattern, len(self))) - if isinstance(pattern, int) and bits_for(pattern) > len(self): - warnings.warn("Match pattern '{:b}' is wider than match value " - "(which has width {}); comparison will never be true" - .format(pattern, len(self)), - SyntaxWarning, stacklevel=3) - continue if isinstance(pattern, str): pattern = "".join(pattern.split()) # remove whitespace mask = int(pattern.replace("0", "1").replace("-", "0"), 2) pattern = int(pattern.replace("-", "0"), 2) matches.append((self & mask) == pattern) - elif isinstance(pattern, int): - matches.append(self == pattern) - elif isinstance(pattern, Enum): - matches.append(self == pattern.value) else: - assert False + try: + orig_pattern, pattern = pattern, Const.cast(pattern) + except TypeError as e: + raise SyntaxError("Match pattern must be a string or a constant-castable " + "expression, not {!r}" + .format(pattern)) from e + pattern_len = bits_for(pattern.value) + if pattern_len > len(self): + warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value " + "(which has width {}); comparison will never be true" + .format(orig_pattern, pattern_len, pattern.value, len(self)), + SyntaxWarning, stacklevel=2) + continue + matches.append(self == pattern) if not matches: return Const(0) elif len(matches) == 1: @@ -560,9 +568,6 @@ class Value(metaclass=ABCMeta): def _rhs_signals(self): pass # :nocov: - def _as_const(self): - raise TypeError("Value {!r} cannot be evaluated as constant".format(self)) - __hash__ = None @@ -595,6 +600,28 @@ class Const(Value): value |= ~mask return value + @staticmethod + def cast(obj): + """Converts ``obj`` to an Amaranth constant. + + First, ``obj`` is converted to a value using :meth:`Value.cast`. If it is a constant, it + is returned. If it is a constant-castable expression, it is evaluated and returned. + Otherwise, :exn:`TypeError` is raised. + """ + obj = Value.cast(obj) + if type(obj) is Const: + return obj + elif type(obj) is Cat: + value = 0 + width = 0 + for part in obj.parts: + const = Const.cast(part) + value |= const.value << width + width += len(const) + return Const(value, width) + else: + raise TypeError("Value {!r} cannot be converted to an Amaranth constant".format(obj)) + def __init__(self, value, shape=None, *, src_loc_at=0): # We deliberately do not call Value.__init__ here. self.value = int(value) @@ -617,9 +644,6 @@ class Const(Value): def _rhs_signals(self): return SignalSet() - def _as_const(self): - return self.value - def __repr__(self): return "(const {}'{}d{})".format(self.width, "s" if self.signed else "", self.value) @@ -858,13 +882,6 @@ class Cat(Value): def _rhs_signals(self): return union((part._rhs_signals() for part in self.parts), start=SignalSet()) - def _as_const(self): - value = 0 - for part in reversed(self.parts): - value <<= len(part) - value |= part._as_const() - return value - def __repr__(self): return "(cat {})".format(" ".join(map(repr, self.parts))) diff --git a/amaranth/hdl/dsl.py b/amaranth/hdl/dsl.py index 7ad0e3e..feabef1 100644 --- a/amaranth/hdl/dsl.py +++ b/amaranth/hdl/dsl.py @@ -305,11 +305,8 @@ class Module(_ModuleBuilderRoot, Elaboratable): src_loc = tracer.get_src_loc(src_loc_at=1) switch_data = self._get_ctrl("Switch") new_patterns = () + # This code should accept exactly the same patterns as `v.matches(...)`. for pattern in patterns: - if not isinstance(pattern, (int, str, Enum)): - raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, " - "not {!r}" - .format(pattern)) if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern): raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) " "bits, and may include whitespace" @@ -319,20 +316,24 @@ class Module(_ModuleBuilderRoot, Elaboratable): raise SyntaxError("Case pattern '{}' must have the same width as switch value " "(which is {})" .format(pattern, len(switch_data["test"]))) - if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]): - warnings.warn("Case pattern '{:b}' is wider than switch value " - "(which has width {}); comparison will never be true" - .format(pattern, len(switch_data["test"])), - SyntaxWarning, stacklevel=3) - continue - if isinstance(pattern, Enum) and bits_for(pattern.value) > len(switch_data["test"]): - warnings.warn("Case pattern '{:b}' ({}.{}) is wider than switch value " - "(which has width {}); comparison will never be true" - .format(pattern.value, pattern.__class__.__name__, pattern.name, - len(switch_data["test"])), - SyntaxWarning, stacklevel=3) - continue - new_patterns = (*new_patterns, pattern) + if isinstance(pattern, str): + new_patterns = (*new_patterns, pattern) + else: + try: + orig_pattern, pattern = pattern, Const.cast(pattern) + except TypeError as e: + raise SyntaxError("Case pattern must be a string or a constant-castable " + "expression, not {!r}" + .format(pattern)) from e + pattern_len = bits_for(pattern.value) + if pattern_len > len(switch_data["test"]): + warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value " + "(which has width {}); comparison will never be true" + .format(orig_pattern, pattern_len, pattern.value, + len(switch_data["test"])), + SyntaxWarning, stacklevel=3) + continue + new_patterns = (*new_patterns, pattern.value) try: _outer_case, self._statements = self._statements, [] self._ctrl_context = None diff --git a/docs/lang.rst b/docs/lang.rst index 1e0459b..0da841c 100644 --- a/docs/lang.rst +++ b/docs/lang.rst @@ -34,6 +34,43 @@ All of the examples below assume that a glob import is used. from amaranth import * +.. _lang-shapes: + +Shapes +====== + +A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly: + +.. doctest:: + + >>> Shape(width=5, signed=False) + unsigned(5) + >>> Shape(width=12, signed=True) + signed(12) + +However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient: + +.. doctest:: + + >>> unsigned(5) == Shape(width=5, signed=False) + True + >>> signed(12) == Shape(width=12, signed=True) + True + + +Shapes of values +---------------- + +All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``. + +.. doctest:: + + >>> Const(5).shape() + unsigned(3) + >>> len(Const(5)) + 3 + + .. _lang-values: Values @@ -79,43 +116,6 @@ The shape of the constant can be specified explicitly, in which case the number' 0 -.. _lang-shapes: - -Shapes -====== - -A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly: - -.. doctest:: - - >>> Shape(width=5, signed=False) - unsigned(5) - >>> Shape(width=12, signed=True) - signed(12) - -However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient: - -.. doctest:: - - >>> unsigned(5) == Shape(width=5, signed=False) - True - >>> signed(12) == Shape(width=12, signed=True) - True - - -Shapes of values ----------------- - -All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``. - -.. doctest:: - - >>> Const(5).shape() - unsigned(3) - >>> len(Const(5)) - 3 - - .. _lang-shapecasting: Shape casting @@ -218,7 +218,7 @@ Specifying a shape with an enumeration is convenient for finite state machines, Value casting ============= -Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly. +Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting to values allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly. .. TODO: link to ValueCastable @@ -228,7 +228,7 @@ Casting to a value can be done explicitly with ``Value.cast``, but is usually im Values from integers -------------------- -Casting a value from an integer ``i`` is a shorthand for ``Const(i)``: +Casting a value from an integer ``i`` is equivalent to ``Const(i)``: .. doctest:: @@ -242,7 +242,7 @@ Casting a value from an integer ``i`` is a shorthand for ``Const(i)``: Values from enumeration members ------------------------------- -Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.value, type(m))``: +Casting a value from an enumeration member ``m`` is equivalent to ``Const(m.value, type(m))``: .. doctest:: @@ -254,6 +254,55 @@ Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.va If a value subclasses :class:`enum.IntEnum` or its class otherwise inherits from both :class:`int` and :class:`Enum`, it is treated as an enumeration. + +.. _lang-constcasting: + +Constant casting +================ + +A subset of :ref:`values ` are *constant-castable*. If a value is constant-castable and all of its operands are also constant-castable, it can be converted to a :class:`Const`, the numeric value of which can then be read by Python code. This provides a way to perform computation on Amaranth values while constructing the design. + +.. TODO: link to m.Case and v.matches() below + +Constant-castable objects are accepted anywhere a constant integer is accepted. Casting to a constant can also be done explicitly with :meth:`Const.cast`: + +.. doctest:: + + >>> Const.cast(Cat(Direction.TOP, Direction.LEFT)) + (const 4'd4) + +.. TODO: uncomment when this actually works + +.. comment:: + + They may be used in enumeration members: + + .. testcode:: + + class Funct(enum.Enum): + ADD = 0 + ... + + class Op(enum.Enum): + REG = 0 + IMM = 1 + + class Instr(enum.Enum): + ADD = Cat(Funct.ADD, Op.REG) + ADDI = Cat(Funct.ADD, Op.IMM) + ... + + +.. note:: + + At the moment, only the following expressions are constant-castable: + + * :class:`Const` + * :class:`Cat` + + This list will be expanded in the future. + + .. _lang-signals: Signals diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 6bc7f49..0df07ae 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -133,7 +133,8 @@ class ShapeTestCase(FHDLTestCase): def test_cast_enum_bad(self): with self.assertRaisesRegex(TypeError, - r"^Only enumerations with integer values can be used as value shapes$"): + r"^Only enumerations whose members have constant-castable values can be used " + r"in Amaranth code$"): Shape.cast(StringEnum) def test_cast_bad(self): @@ -203,7 +204,8 @@ class ValueTestCase(FHDLTestCase): def test_cast_enum_wrong(self): with self.assertRaisesRegex(TypeError, - r"^Only enumerations with integer values can be used as value shapes$"): + r"^Only enumerations whose members have constant-castable values can be used " + r"in Amaranth code$"): Value.cast(StringEnum.FOO) def test_bool(self): @@ -614,7 +616,13 @@ class OperatorTestCase(FHDLTestCase): def test_matches_enum(self): s = Signal(SignedEnum) self.assertRepr(s.matches(SignedEnum.FOO), """ - (== (sig s) (const 1'sd-1)) + (== (sig s) (const 2'sd-1)) + """) + + def test_matches_const_castable(self): + s = Signal(4) + self.assertRepr(s.matches(Cat(C(0b10, 2), C(0b11, 2))), """ + (== (sig s) (const 4'd14)) """) def test_matches_width_wrong(self): @@ -623,21 +631,25 @@ class OperatorTestCase(FHDLTestCase): r"^Match pattern '--' must have the same width as match value \(which is 4\)$"): s.matches("--") with self.assertWarnsRegex(SyntaxWarning, - (r"^Match pattern '10110' is wider than match value \(which has width 4\); " - r"comparison will never be true$")): + r"^Match pattern '22' \(5'10110\) is wider than match value \(which has " + r"width 4\); comparison will never be true$"): s.matches(0b10110) + with self.assertWarnsRegex(SyntaxWarning, + r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider " + r"than match value \(which has width 4\); comparison will never be true$"): + s.matches(Cat(0, C(0b1011, 4))) def test_matches_bits_wrong(self): s = Signal(4) with self.assertRaisesRegex(SyntaxError, - (r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " - r"and may include whitespace$")): + r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, " + r"and may include whitespace$"): s.matches("abc") def test_matches_pattern_wrong(self): s = Signal(4) with self.assertRaisesRegex(SyntaxError, - r"^Match pattern must be an integer, a string, or an enumeration, not 1\.0$"): + r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"): s.matches(1.0) def test_hash(self): diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index c9f116d..552f120 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -446,6 +446,23 @@ class DSLTestCase(FHDLTestCase): ) """) + def test_Switch_const_castable(self): + class Color(Enum): + RED = 0 + BLUE = 1 + m = Module() + se = Signal(2) + with m.Switch(se): + with m.Case(Cat(Color.RED, Color.BLUE)): + m.d.comb += self.c1.eq(1) + self.assertRepr(m._statements, """ + ( + (switch (sig se) + (case 10 (eq (sig c1) (const 1'd1))) + ) + ) + """) + def test_Case_width_wrong(self): class Color(Enum): RED = 0b10101010 @@ -456,13 +473,13 @@ class DSLTestCase(FHDLTestCase): with m.Case("--"): pass with self.assertWarnsRegex(SyntaxWarning, - (r"^Case pattern '10110' is wider than switch value \(which has width 4\); " - r"comparison will never be true$")): + r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has " + r"width 4\); comparison will never be true$"): with m.Case(0b10110): pass with self.assertWarnsRegex(SyntaxWarning, - (r"^Case pattern '10101010' \(Color\.RED\) is wider than switch value " - r"\(which has width 4\); comparison will never be true$")): + r"^Case pattern '' \(8'10101010\) is wider than switch value " + r"\(which has width 4\); comparison will never be true$"): with m.Case(Color.RED): pass self.assertRepr(m._statements, """ @@ -484,7 +501,8 @@ class DSLTestCase(FHDLTestCase): m = Module() with m.Switch(self.w1): with self.assertRaisesRegex(SyntaxError, - r"^Case pattern must be an integer, a string, or an enumeration, not 1\.0$"): + r"^Case pattern must be a string or a constant-castable expression, " + r"not 1\.0$"): with m.Case(1.0): pass