Implement RFC 9: Constant initialization for shape-castable objects.
See amaranth-lang/rfcs#9 and #771.
This commit is contained in:
		
							parent
							
								
									ea5a150155
								
							
						
					
					
						commit
						54d5c4c047
					
				|  | @ -45,6 +45,9 @@ class ShapeCastable: | |||
|         if not hasattr(cls, "as_shape"): | ||||
|             raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override " | ||||
|                             f"the `as_shape` method") | ||||
|         if not hasattr(cls, "const"): | ||||
|             raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override " | ||||
|                             f"the `const` method") | ||||
| 
 | ||||
| 
 | ||||
| class Shape: | ||||
|  | @ -988,7 +991,7 @@ class Signal(Value, DUID): | |||
|     decoder : function | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, shape=None, *, name=None, reset=0, reset_less=False, | ||||
|     def __init__(self, shape=None, *, name=None, reset=None, reset_less=False, | ||||
|                  attrs=None, decoder=None, src_loc_at=0): | ||||
|         super().__init__(src_loc_at=src_loc_at) | ||||
| 
 | ||||
|  | @ -1005,12 +1008,24 @@ class Signal(Value, DUID): | |||
|         self.signed = shape.signed | ||||
| 
 | ||||
|         orig_reset = reset | ||||
|         try: | ||||
|             reset = Const.cast(reset) | ||||
|         except TypeError: | ||||
|             raise TypeError("Reset value must be a constant-castable expression, not {!r}" | ||||
|                             .format(orig_reset)) | ||||
|         if orig_reset not in (0, -1): # Avoid false positives for all-zeroes and all-ones | ||||
|         if isinstance(orig_shape, ShapeCastable): | ||||
|             try: | ||||
|                 reset = Const.cast(orig_shape.const(reset)) | ||||
|             except Exception: | ||||
|                 raise TypeError("Reset value must be a constant initializer of {!r}" | ||||
|                                 .format(orig_shape)) | ||||
|             if reset.shape() != Shape.cast(orig_shape): | ||||
|                 raise ValueError("Constant returned by {!r}.const() must have the shape that " | ||||
|                                  "it casts to, {!r}, and not {!r}" | ||||
|                                  .format(orig_shape, Shape.cast(orig_shape), | ||||
|                                          reset.shape())) | ||||
|         else: | ||||
|             try: | ||||
|                 reset = Const.cast(reset or 0) | ||||
|             except TypeError: | ||||
|                 raise TypeError("Reset value must be a constant-castable expression, not {!r}" | ||||
|                                 .format(orig_reset)) | ||||
|         if orig_reset not in (None, 0, -1): # Avoid false positives for all-zeroes and all-ones | ||||
|             if reset.shape().signed and not self.signed: | ||||
|                 warnings.warn( | ||||
|                     message="Reset value {!r} is signed, but the signal shape is {!r}" | ||||
|  |  | |||
|  | @ -201,29 +201,44 @@ class Layout(ShapeCastable, metaclass=ABCMeta): | |||
|         """ | ||||
|         return View(self, target) | ||||
| 
 | ||||
|     def _convert_to_int(self, value): | ||||
|         """Convert ``value``, which may be a dict or an array of field values, to an integer using | ||||
|         the representation defined by this layout. | ||||
|     def const(self, init): | ||||
|         """Convert a constant initializer to a constant. | ||||
| 
 | ||||
|         This method is private because Amaranth does not currently have a concept of | ||||
|         a constant initializer; this requires an RFC. It will be renamed or removed | ||||
|         in a future version. | ||||
|         Converts ``init``, which may be a sequence or a mapping of field values, to a constant. | ||||
| 
 | ||||
|         Returns | ||||
|         ------- | ||||
|         :class:`Const` | ||||
|             A constant that has the same value as a view with this layout that was initialized with | ||||
|             an all-zero value and had every field assigned to the corresponding value in the order | ||||
|             in which they appear in ``init``. | ||||
|         """ | ||||
|         if isinstance(value, Mapping): | ||||
|             iterator = value.items() | ||||
|         elif isinstance(value, Sequence): | ||||
|             iterator = enumerate(value) | ||||
|         if init is None: | ||||
|             iterator = iter(()) | ||||
|         elif isinstance(init, Mapping): | ||||
|             iterator = init.items() | ||||
|         elif isinstance(init, Sequence): | ||||
|             iterator = enumerate(init) | ||||
|         else: | ||||
|             raise TypeError("Layout initializer must be a mapping or a sequence, not {!r}" | ||||
|                             .format(value)) | ||||
|             raise TypeError("Layout constant initializer must be a mapping or a sequence, not {!r}" | ||||
|                             .format(init)) | ||||
| 
 | ||||
|         int_value = 0 | ||||
|         for key, key_value in iterator: | ||||
|             field = self[key] | ||||
|             if isinstance(field.shape, Layout): | ||||
|                 key_value = field.shape._convert_to_int(key_value) | ||||
|             int_value |= Const(key_value, Shape.cast(field.shape)).value << field.offset | ||||
|         return int_value | ||||
|             cast_field_shape = Shape.cast(field.shape) | ||||
|             if isinstance(field.shape, ShapeCastable): | ||||
|                 key_value = Const.cast(field.shape.const(key_value)) | ||||
|                 if key_value.shape() != cast_field_shape: | ||||
|                     raise ValueError("Constant returned by {!r}.const() must have the shape that " | ||||
|                                      "it casts to, {!r}, and not {!r}" | ||||
|                                      .format(field.shape, cast_field_shape, | ||||
|                                              key_value.shape())) | ||||
|             else: | ||||
|                 key_value = Const(key_value, cast_field_shape) | ||||
|             int_value &= ~(((1 << cast_field_shape.width) - 1) << field.offset) | ||||
|             int_value |= key_value.value << field.offset | ||||
|         return Const(int_value, self.as_shape()) | ||||
| 
 | ||||
| 
 | ||||
| class StructLayout(Layout): | ||||
|  | @ -617,13 +632,9 @@ class View(ValueCastable): | |||
|                                  "the {} bit(s) wide view layout" | ||||
|                                  .format(len(cast_target), cast_layout.size)) | ||||
|         else: | ||||
|             if reset is None: | ||||
|                 reset = 0 | ||||
|             else: | ||||
|                 reset = cast_layout._convert_to_int(reset) | ||||
|             if reset_less is None: | ||||
|                 reset_less = False | ||||
|             cast_target = Signal(cast_layout, name=name, reset=reset, reset_less=reset_less, | ||||
|             cast_target = Signal(layout, name=name, reset=reset, reset_less=reset_less, | ||||
|                 attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1) | ||||
|         self.__orig_layout = layout | ||||
|         self.__layout = cast_layout | ||||
|  | @ -774,6 +785,9 @@ class _AggregateMeta(ShapeCastable, type): | |||
|                             .format(cls.__module__, cls.__qualname__)) | ||||
|         return cls.__layout | ||||
| 
 | ||||
|     def const(cls, init): | ||||
|         return cls.as_shape().const(init) | ||||
| 
 | ||||
| 
 | ||||
| class _Aggregate(View, metaclass=_AggregateMeta): | ||||
|     def __init__(self, target=None, *, name=None, reset=None, reset_less=None, | ||||
|  |  | |||
|  | @ -137,6 +137,17 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): | |||
|             return value | ||||
|         return super().__call__(value) | ||||
| 
 | ||||
|     def const(cls, init): | ||||
|         # Same considerations apply as above. | ||||
|         if init is None: | ||||
|             # Signal with unspecified reset value passes ``None`` to :meth:`const`. | ||||
|             # Before RFC 9 was implemented, the unspecified reset value was 0, so this keeps | ||||
|             # the old behavior intact. | ||||
|             member = cls(0) | ||||
|         else: | ||||
|             member = cls(init) | ||||
|         return Const(member.value, cls.as_shape()) | ||||
| 
 | ||||
| 
 | ||||
| class Enum(py_enum.Enum, metaclass=EnumMeta): | ||||
|     """Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ import warnings | |||
| from enum import Enum | ||||
| 
 | ||||
| from amaranth.hdl.ast import * | ||||
| from amaranth.lib.enum import Enum as AmaranthEnum | ||||
| 
 | ||||
| from .utils import * | ||||
| from amaranth._utils import _ignore_deprecated | ||||
|  | @ -144,6 +145,9 @@ class MockShapeCastable(ShapeCastable): | |||
|     def as_shape(self): | ||||
|         return self.dest | ||||
| 
 | ||||
|     def const(self, obj): | ||||
|         return Const(obj, self.dest) | ||||
| 
 | ||||
| 
 | ||||
| class ShapeCastableTestCase(FHDLTestCase): | ||||
|     def test_no_override(self): | ||||
|  | @ -995,6 +999,29 @@ class SignalTestCase(FHDLTestCase): | |||
|                 r"not <StringEnum\.FOO: 'a'>$"): | ||||
|             Signal(1, reset=StringEnum.FOO) | ||||
| 
 | ||||
|     def test_reset_shape_castable_const(self): | ||||
|         class CastableFromHex(ShapeCastable): | ||||
|             def as_shape(self): | ||||
|                 return unsigned(8) | ||||
| 
 | ||||
|             def const(self, init): | ||||
|                 return int(init, 16) | ||||
| 
 | ||||
|         s1 = Signal(CastableFromHex(), reset="aa") | ||||
|         self.assertEqual(s1.reset, 0xaa) | ||||
| 
 | ||||
|         with self.assertRaisesRegex(ValueError, | ||||
|                 r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape " | ||||
|                 r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"): | ||||
|             Signal(CastableFromHex(), reset="01") | ||||
| 
 | ||||
|     def test_reset_shape_castable_enum_wrong(self): | ||||
|         class EnumA(AmaranthEnum): | ||||
|             X = 1 | ||||
|         with self.assertRaisesRegex(TypeError, | ||||
|                 r"^Reset value must be a constant initializer of <enum 'EnumA'>$"): | ||||
|             Signal(EnumA) # implied reset=0 | ||||
| 
 | ||||
|     def test_reset_signed_mismatch(self): | ||||
|         with self.assertWarnsRegex(SyntaxWarning, | ||||
|                 r"^Reset value -2 is signed, but the signal shape is unsigned\(2\)$"): | ||||
|  |  | |||
|  | @ -436,7 +436,7 @@ class DSLTestCase(FHDLTestCase): | |||
|             RED  = 1 | ||||
|             BLUE = 2 | ||||
|         m = Module() | ||||
|         se = Signal(Color) | ||||
|         se = Signal(Color, reset=Color.RED) | ||||
|         with m.Switch(se): | ||||
|             with m.Case(Color.RED): | ||||
|                 m.d.comb += self.c1.eq(1) | ||||
|  |  | |||
|  | @ -16,6 +16,9 @@ class MockShapeCastable(ShapeCastable): | |||
|     def as_shape(self): | ||||
|         return self.shape | ||||
| 
 | ||||
|     def const(self, init): | ||||
|         return Const(init, self.shape) | ||||
| 
 | ||||
| 
 | ||||
| class FieldTestCase(TestCase): | ||||
|     def test_construct(self): | ||||
|  | @ -332,7 +335,7 @@ class FlexibleLayoutTestCase(TestCase): | |||
|             il[object()] | ||||
| 
 | ||||
| 
 | ||||
| class LayoutTestCase(TestCase): | ||||
| class LayoutTestCase(FHDLTestCase): | ||||
|     def test_cast(self): | ||||
|         sl = StructLayout({}) | ||||
|         self.assertIs(Layout.cast(sl), sl) | ||||
|  | @ -371,6 +374,53 @@ class LayoutTestCase(TestCase): | |||
|         self.assertIs(Layout.of(v), sl) | ||||
|         self.assertIs(v.as_value(), s) | ||||
| 
 | ||||
|     def test_const(self): | ||||
|         sl = StructLayout({ | ||||
|             "a": unsigned(1), | ||||
|             "b": unsigned(2) | ||||
|         }) | ||||
|         self.assertRepr(sl.const(None), "(const 3'd0)") | ||||
|         self.assertRepr(sl.const({"a": 0b1, "b": 0b10}), "(const 3'd5)") | ||||
| 
 | ||||
|         ul = UnionLayout({ | ||||
|             "a": unsigned(1), | ||||
|             "b": unsigned(2) | ||||
|         }) | ||||
|         self.assertRepr(ul.const({"a": 0b11}), "(const 2'd1)") | ||||
|         self.assertRepr(ul.const({"b": 0b10}), "(const 2'd2)") | ||||
|         self.assertRepr(ul.const({"a": 0b1, "b": 0b10}), "(const 2'd2)") | ||||
| 
 | ||||
|     def test_const_wrong(self): | ||||
|         sl = StructLayout({"f": unsigned(1)}) | ||||
|         with self.assertRaisesRegex(TypeError, | ||||
|                 r"^Layout constant initializer must be a mapping or a sequence, not " | ||||
|                 r"<.+?object.+?>$"): | ||||
|             sl.const(object()) | ||||
| 
 | ||||
|     def test_const_field_shape_castable(self): | ||||
|         class CastableFromHex(ShapeCastable): | ||||
|             def as_shape(self): | ||||
|                 return unsigned(8) | ||||
| 
 | ||||
|             def const(self, init): | ||||
|                 return int(init, 16) | ||||
| 
 | ||||
|         sl = StructLayout({"f": CastableFromHex()}) | ||||
|         self.assertRepr(sl.const({"f": "aa"}), "(const 8'd170)") | ||||
| 
 | ||||
|         with self.assertRaisesRegex(ValueError, | ||||
|                 r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape " | ||||
|                 r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"): | ||||
|             sl.const({"f": "01"}) | ||||
| 
 | ||||
|     def test_signal_reset(self): | ||||
|         sl = StructLayout({ | ||||
|             "a": unsigned(1), | ||||
|             "b": unsigned(2) | ||||
|         }) | ||||
|         self.assertEqual(Signal(sl).reset, 0) | ||||
|         self.assertEqual(Signal(sl, reset={"a": 0b1, "b": 0b10}).reset, 5) | ||||
| 
 | ||||
| 
 | ||||
| class ViewTestCase(FHDLTestCase): | ||||
|     def test_construct(self): | ||||
|  | @ -434,7 +484,7 @@ class ViewTestCase(FHDLTestCase): | |||
| 
 | ||||
|     def test_signal_reset_wrong(self): | ||||
|         with self.assertRaisesRegex(TypeError, | ||||
|                 r"^Layout initializer must be a mapping or a sequence, not 1$"): | ||||
|                 r"^Reset value must be a constant initializer of StructLayout\({}\)$"): | ||||
|             View(StructLayout({}), reset=0b1) | ||||
| 
 | ||||
|     def test_target_signal_wrong(self): | ||||
|  | @ -483,6 +533,9 @@ class ViewTestCase(FHDLTestCase): | |||
|             def __call__(self, value): | ||||
|                 return value[::-1] | ||||
| 
 | ||||
|             def const(self, init): | ||||
|                 return Const(init, 2) | ||||
| 
 | ||||
|         v = View(StructLayout({ | ||||
|             "f": Reverser() | ||||
|         })) | ||||
|  | @ -497,13 +550,15 @@ class ViewTestCase(FHDLTestCase): | |||
|             def __call__(self, value): | ||||
|                 pass | ||||
| 
 | ||||
|             def const(self, init): | ||||
|                 return Const(init, 2) | ||||
| 
 | ||||
|         v = View(StructLayout({ | ||||
|             "f": WrongCastable() | ||||
|         })) | ||||
|         with self.assertRaisesRegex(TypeError, | ||||
|                 r"^<tests\.test_lib_data\.ViewTestCase\.test_getitem_custom_call_wrong\.<locals>" | ||||
|                 r"\.WrongCastable object at 0x.+?>\.__call__\(\) must return a value or " | ||||
|                 r"a value-castable object, not None$"): | ||||
|                 r"^<.+?\.WrongCastable.+?>\.__call__\(\) must return a value or a value-castable " | ||||
|                 r"object, not None$"): | ||||
|             v.f | ||||
| 
 | ||||
|     def test_index_wrong_missing(self): | ||||
|  |  | |||
|  | @ -5,18 +5,12 @@ from .utils import * | |||
| 
 | ||||
| 
 | ||||
| class EnumTestCase(FHDLTestCase): | ||||
|     def test_non_int_members(self): | ||||
|     def test_members_non_int(self): | ||||
|         # Mustn't raise to be a drop-in replacement for Enum. | ||||
|         class EnumA(Enum): | ||||
|             A = "str" | ||||
| 
 | ||||
|     def test_non_const_non_int_members_wrong(self): | ||||
|         with self.assertRaisesRegex(TypeError, | ||||
|                 r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"): | ||||
|             class EnumA(Enum, shape=unsigned(1)): | ||||
|                 A = "str" | ||||
| 
 | ||||
|     def test_const_non_int_members(self): | ||||
|     def test_members_const_non_int(self): | ||||
|         class EnumA(Enum): | ||||
|             A = C(0) | ||||
|             B = C(1) | ||||
|  | @ -59,6 +53,12 @@ class EnumTestCase(FHDLTestCase): | |||
|             B = -5 | ||||
|         self.assertEqual(Shape.cast(EnumD), signed(4)) | ||||
| 
 | ||||
|     def test_shape_members_non_const_non_int_wrong(self): | ||||
|         with self.assertRaisesRegex(TypeError, | ||||
|                 r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"): | ||||
|             class EnumA(Enum, shape=unsigned(1)): | ||||
|                 A = "str" | ||||
| 
 | ||||
|     def test_shape_explicit_wrong_signed_mismatch(self): | ||||
|         with self.assertWarnsRegex(SyntaxWarning, | ||||
|                 r"^Value -1 of enumeration member 'A' is signed, but the enumeration " | ||||
|  | @ -88,6 +88,23 @@ class EnumTestCase(FHDLTestCase): | |||
|             A = 1 | ||||
|         self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)") | ||||
| 
 | ||||
|     def test_const_no_shape(self): | ||||
|         class EnumA(Enum): | ||||
|             Z = 0 | ||||
|             A = 10 | ||||
|             B = 20 | ||||
|         self.assertRepr(EnumA.const(None), "(const 5'd0)") | ||||
|         self.assertRepr(EnumA.const(10), "(const 5'd10)") | ||||
|         self.assertRepr(EnumA.const(EnumA.A), "(const 5'd10)") | ||||
| 
 | ||||
|     def test_const_shape(self): | ||||
|         class EnumA(Enum, shape=8): | ||||
|             Z = 0 | ||||
|             A = 10 | ||||
|         self.assertRepr(EnumA.const(None), "(const 8'd0)") | ||||
|         self.assertRepr(EnumA.const(10), "(const 8'd10)") | ||||
|         self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)") | ||||
| 
 | ||||
|     def test_shape_implicit_wrong_in_concat(self): | ||||
|         class EnumA(Enum): | ||||
|             A = 0 | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue
	
	 Catherine
						Catherine