Implement RFC 9: Constant initialization for shape-castable objects.
See amaranth-lang/rfcs#9 and #771.
This commit is contained in:
parent
ea5a150155
commit
54d5c4c047
|
@ -45,6 +45,9 @@ class ShapeCastable:
|
||||||
if not hasattr(cls, "as_shape"):
|
if not hasattr(cls, "as_shape"):
|
||||||
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
|
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
|
||||||
f"the `as_shape` method")
|
f"the `as_shape` method")
|
||||||
|
if not hasattr(cls, "const"):
|
||||||
|
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
|
||||||
|
f"the `const` method")
|
||||||
|
|
||||||
|
|
||||||
class Shape:
|
class Shape:
|
||||||
|
@ -988,7 +991,7 @@ class Signal(Value, DUID):
|
||||||
decoder : function
|
decoder : function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shape=None, *, name=None, reset=0, reset_less=False,
|
def __init__(self, shape=None, *, name=None, reset=None, reset_less=False,
|
||||||
attrs=None, decoder=None, src_loc_at=0):
|
attrs=None, decoder=None, src_loc_at=0):
|
||||||
super().__init__(src_loc_at=src_loc_at)
|
super().__init__(src_loc_at=src_loc_at)
|
||||||
|
|
||||||
|
@ -1005,12 +1008,24 @@ class Signal(Value, DUID):
|
||||||
self.signed = shape.signed
|
self.signed = shape.signed
|
||||||
|
|
||||||
orig_reset = reset
|
orig_reset = reset
|
||||||
try:
|
if isinstance(orig_shape, ShapeCastable):
|
||||||
reset = Const.cast(reset)
|
try:
|
||||||
except TypeError:
|
reset = Const.cast(orig_shape.const(reset))
|
||||||
raise TypeError("Reset value must be a constant-castable expression, not {!r}"
|
except Exception:
|
||||||
.format(orig_reset))
|
raise TypeError("Reset value must be a constant initializer of {!r}"
|
||||||
if orig_reset not in (0, -1): # Avoid false positives for all-zeroes and all-ones
|
.format(orig_shape))
|
||||||
|
if reset.shape() != Shape.cast(orig_shape):
|
||||||
|
raise ValueError("Constant returned by {!r}.const() must have the shape that "
|
||||||
|
"it casts to, {!r}, and not {!r}"
|
||||||
|
.format(orig_shape, Shape.cast(orig_shape),
|
||||||
|
reset.shape()))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
reset = Const.cast(reset or 0)
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError("Reset value must be a constant-castable expression, not {!r}"
|
||||||
|
.format(orig_reset))
|
||||||
|
if orig_reset not in (None, 0, -1): # Avoid false positives for all-zeroes and all-ones
|
||||||
if reset.shape().signed and not self.signed:
|
if reset.shape().signed and not self.signed:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
message="Reset value {!r} is signed, but the signal shape is {!r}"
|
message="Reset value {!r} is signed, but the signal shape is {!r}"
|
||||||
|
|
|
@ -201,29 +201,44 @@ class Layout(ShapeCastable, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
return View(self, target)
|
return View(self, target)
|
||||||
|
|
||||||
def _convert_to_int(self, value):
|
def const(self, init):
|
||||||
"""Convert ``value``, which may be a dict or an array of field values, to an integer using
|
"""Convert a constant initializer to a constant.
|
||||||
the representation defined by this layout.
|
|
||||||
|
|
||||||
This method is private because Amaranth does not currently have a concept of
|
Converts ``init``, which may be a sequence or a mapping of field values, to a constant.
|
||||||
a constant initializer; this requires an RFC. It will be renamed or removed
|
|
||||||
in a future version.
|
Returns
|
||||||
|
-------
|
||||||
|
:class:`Const`
|
||||||
|
A constant that has the same value as a view with this layout that was initialized with
|
||||||
|
an all-zero value and had every field assigned to the corresponding value in the order
|
||||||
|
in which they appear in ``init``.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, Mapping):
|
if init is None:
|
||||||
iterator = value.items()
|
iterator = iter(())
|
||||||
elif isinstance(value, Sequence):
|
elif isinstance(init, Mapping):
|
||||||
iterator = enumerate(value)
|
iterator = init.items()
|
||||||
|
elif isinstance(init, Sequence):
|
||||||
|
iterator = enumerate(init)
|
||||||
else:
|
else:
|
||||||
raise TypeError("Layout initializer must be a mapping or a sequence, not {!r}"
|
raise TypeError("Layout constant initializer must be a mapping or a sequence, not {!r}"
|
||||||
.format(value))
|
.format(init))
|
||||||
|
|
||||||
int_value = 0
|
int_value = 0
|
||||||
for key, key_value in iterator:
|
for key, key_value in iterator:
|
||||||
field = self[key]
|
field = self[key]
|
||||||
if isinstance(field.shape, Layout):
|
cast_field_shape = Shape.cast(field.shape)
|
||||||
key_value = field.shape._convert_to_int(key_value)
|
if isinstance(field.shape, ShapeCastable):
|
||||||
int_value |= Const(key_value, Shape.cast(field.shape)).value << field.offset
|
key_value = Const.cast(field.shape.const(key_value))
|
||||||
return int_value
|
if key_value.shape() != cast_field_shape:
|
||||||
|
raise ValueError("Constant returned by {!r}.const() must have the shape that "
|
||||||
|
"it casts to, {!r}, and not {!r}"
|
||||||
|
.format(field.shape, cast_field_shape,
|
||||||
|
key_value.shape()))
|
||||||
|
else:
|
||||||
|
key_value = Const(key_value, cast_field_shape)
|
||||||
|
int_value &= ~(((1 << cast_field_shape.width) - 1) << field.offset)
|
||||||
|
int_value |= key_value.value << field.offset
|
||||||
|
return Const(int_value, self.as_shape())
|
||||||
|
|
||||||
|
|
||||||
class StructLayout(Layout):
|
class StructLayout(Layout):
|
||||||
|
@ -617,13 +632,9 @@ class View(ValueCastable):
|
||||||
"the {} bit(s) wide view layout"
|
"the {} bit(s) wide view layout"
|
||||||
.format(len(cast_target), cast_layout.size))
|
.format(len(cast_target), cast_layout.size))
|
||||||
else:
|
else:
|
||||||
if reset is None:
|
|
||||||
reset = 0
|
|
||||||
else:
|
|
||||||
reset = cast_layout._convert_to_int(reset)
|
|
||||||
if reset_less is None:
|
if reset_less is None:
|
||||||
reset_less = False
|
reset_less = False
|
||||||
cast_target = Signal(cast_layout, name=name, reset=reset, reset_less=reset_less,
|
cast_target = Signal(layout, name=name, reset=reset, reset_less=reset_less,
|
||||||
attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1)
|
attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1)
|
||||||
self.__orig_layout = layout
|
self.__orig_layout = layout
|
||||||
self.__layout = cast_layout
|
self.__layout = cast_layout
|
||||||
|
@ -774,6 +785,9 @@ class _AggregateMeta(ShapeCastable, type):
|
||||||
.format(cls.__module__, cls.__qualname__))
|
.format(cls.__module__, cls.__qualname__))
|
||||||
return cls.__layout
|
return cls.__layout
|
||||||
|
|
||||||
|
def const(cls, init):
|
||||||
|
return cls.as_shape().const(init)
|
||||||
|
|
||||||
|
|
||||||
class _Aggregate(View, metaclass=_AggregateMeta):
|
class _Aggregate(View, metaclass=_AggregateMeta):
|
||||||
def __init__(self, target=None, *, name=None, reset=None, reset_less=None,
|
def __init__(self, target=None, *, name=None, reset=None, reset_less=None,
|
||||||
|
|
|
@ -137,6 +137,17 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
|
||||||
return value
|
return value
|
||||||
return super().__call__(value)
|
return super().__call__(value)
|
||||||
|
|
||||||
|
def const(cls, init):
|
||||||
|
# Same considerations apply as above.
|
||||||
|
if init is None:
|
||||||
|
# Signal with unspecified reset value passes ``None`` to :meth:`const`.
|
||||||
|
# Before RFC 9 was implemented, the unspecified reset value was 0, so this keeps
|
||||||
|
# the old behavior intact.
|
||||||
|
member = cls(0)
|
||||||
|
else:
|
||||||
|
member = cls(init)
|
||||||
|
return Const(member.value, cls.as_shape())
|
||||||
|
|
||||||
|
|
||||||
class Enum(py_enum.Enum, metaclass=EnumMeta):
|
class Enum(py_enum.Enum, metaclass=EnumMeta):
|
||||||
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as
|
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as
|
||||||
|
|
|
@ -2,6 +2,7 @@ import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from amaranth.hdl.ast import *
|
from amaranth.hdl.ast import *
|
||||||
|
from amaranth.lib.enum import Enum as AmaranthEnum
|
||||||
|
|
||||||
from .utils import *
|
from .utils import *
|
||||||
from amaranth._utils import _ignore_deprecated
|
from amaranth._utils import _ignore_deprecated
|
||||||
|
@ -144,6 +145,9 @@ class MockShapeCastable(ShapeCastable):
|
||||||
def as_shape(self):
|
def as_shape(self):
|
||||||
return self.dest
|
return self.dest
|
||||||
|
|
||||||
|
def const(self, obj):
|
||||||
|
return Const(obj, self.dest)
|
||||||
|
|
||||||
|
|
||||||
class ShapeCastableTestCase(FHDLTestCase):
|
class ShapeCastableTestCase(FHDLTestCase):
|
||||||
def test_no_override(self):
|
def test_no_override(self):
|
||||||
|
@ -995,6 +999,29 @@ class SignalTestCase(FHDLTestCase):
|
||||||
r"not <StringEnum\.FOO: 'a'>$"):
|
r"not <StringEnum\.FOO: 'a'>$"):
|
||||||
Signal(1, reset=StringEnum.FOO)
|
Signal(1, reset=StringEnum.FOO)
|
||||||
|
|
||||||
|
def test_reset_shape_castable_const(self):
|
||||||
|
class CastableFromHex(ShapeCastable):
|
||||||
|
def as_shape(self):
|
||||||
|
return unsigned(8)
|
||||||
|
|
||||||
|
def const(self, init):
|
||||||
|
return int(init, 16)
|
||||||
|
|
||||||
|
s1 = Signal(CastableFromHex(), reset="aa")
|
||||||
|
self.assertEqual(s1.reset, 0xaa)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape "
|
||||||
|
r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"):
|
||||||
|
Signal(CastableFromHex(), reset="01")
|
||||||
|
|
||||||
|
def test_reset_shape_castable_enum_wrong(self):
|
||||||
|
class EnumA(AmaranthEnum):
|
||||||
|
X = 1
|
||||||
|
with self.assertRaisesRegex(TypeError,
|
||||||
|
r"^Reset value must be a constant initializer of <enum 'EnumA'>$"):
|
||||||
|
Signal(EnumA) # implied reset=0
|
||||||
|
|
||||||
def test_reset_signed_mismatch(self):
|
def test_reset_signed_mismatch(self):
|
||||||
with self.assertWarnsRegex(SyntaxWarning,
|
with self.assertWarnsRegex(SyntaxWarning,
|
||||||
r"^Reset value -2 is signed, but the signal shape is unsigned\(2\)$"):
|
r"^Reset value -2 is signed, but the signal shape is unsigned\(2\)$"):
|
||||||
|
|
|
@ -436,7 +436,7 @@ class DSLTestCase(FHDLTestCase):
|
||||||
RED = 1
|
RED = 1
|
||||||
BLUE = 2
|
BLUE = 2
|
||||||
m = Module()
|
m = Module()
|
||||||
se = Signal(Color)
|
se = Signal(Color, reset=Color.RED)
|
||||||
with m.Switch(se):
|
with m.Switch(se):
|
||||||
with m.Case(Color.RED):
|
with m.Case(Color.RED):
|
||||||
m.d.comb += self.c1.eq(1)
|
m.d.comb += self.c1.eq(1)
|
||||||
|
|
|
@ -16,6 +16,9 @@ class MockShapeCastable(ShapeCastable):
|
||||||
def as_shape(self):
|
def as_shape(self):
|
||||||
return self.shape
|
return self.shape
|
||||||
|
|
||||||
|
def const(self, init):
|
||||||
|
return Const(init, self.shape)
|
||||||
|
|
||||||
|
|
||||||
class FieldTestCase(TestCase):
|
class FieldTestCase(TestCase):
|
||||||
def test_construct(self):
|
def test_construct(self):
|
||||||
|
@ -332,7 +335,7 @@ class FlexibleLayoutTestCase(TestCase):
|
||||||
il[object()]
|
il[object()]
|
||||||
|
|
||||||
|
|
||||||
class LayoutTestCase(TestCase):
|
class LayoutTestCase(FHDLTestCase):
|
||||||
def test_cast(self):
|
def test_cast(self):
|
||||||
sl = StructLayout({})
|
sl = StructLayout({})
|
||||||
self.assertIs(Layout.cast(sl), sl)
|
self.assertIs(Layout.cast(sl), sl)
|
||||||
|
@ -371,6 +374,53 @@ class LayoutTestCase(TestCase):
|
||||||
self.assertIs(Layout.of(v), sl)
|
self.assertIs(Layout.of(v), sl)
|
||||||
self.assertIs(v.as_value(), s)
|
self.assertIs(v.as_value(), s)
|
||||||
|
|
||||||
|
def test_const(self):
|
||||||
|
sl = StructLayout({
|
||||||
|
"a": unsigned(1),
|
||||||
|
"b": unsigned(2)
|
||||||
|
})
|
||||||
|
self.assertRepr(sl.const(None), "(const 3'd0)")
|
||||||
|
self.assertRepr(sl.const({"a": 0b1, "b": 0b10}), "(const 3'd5)")
|
||||||
|
|
||||||
|
ul = UnionLayout({
|
||||||
|
"a": unsigned(1),
|
||||||
|
"b": unsigned(2)
|
||||||
|
})
|
||||||
|
self.assertRepr(ul.const({"a": 0b11}), "(const 2'd1)")
|
||||||
|
self.assertRepr(ul.const({"b": 0b10}), "(const 2'd2)")
|
||||||
|
self.assertRepr(ul.const({"a": 0b1, "b": 0b10}), "(const 2'd2)")
|
||||||
|
|
||||||
|
def test_const_wrong(self):
|
||||||
|
sl = StructLayout({"f": unsigned(1)})
|
||||||
|
with self.assertRaisesRegex(TypeError,
|
||||||
|
r"^Layout constant initializer must be a mapping or a sequence, not "
|
||||||
|
r"<.+?object.+?>$"):
|
||||||
|
sl.const(object())
|
||||||
|
|
||||||
|
def test_const_field_shape_castable(self):
|
||||||
|
class CastableFromHex(ShapeCastable):
|
||||||
|
def as_shape(self):
|
||||||
|
return unsigned(8)
|
||||||
|
|
||||||
|
def const(self, init):
|
||||||
|
return int(init, 16)
|
||||||
|
|
||||||
|
sl = StructLayout({"f": CastableFromHex()})
|
||||||
|
self.assertRepr(sl.const({"f": "aa"}), "(const 8'd170)")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape "
|
||||||
|
r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"):
|
||||||
|
sl.const({"f": "01"})
|
||||||
|
|
||||||
|
def test_signal_reset(self):
|
||||||
|
sl = StructLayout({
|
||||||
|
"a": unsigned(1),
|
||||||
|
"b": unsigned(2)
|
||||||
|
})
|
||||||
|
self.assertEqual(Signal(sl).reset, 0)
|
||||||
|
self.assertEqual(Signal(sl, reset={"a": 0b1, "b": 0b10}).reset, 5)
|
||||||
|
|
||||||
|
|
||||||
class ViewTestCase(FHDLTestCase):
|
class ViewTestCase(FHDLTestCase):
|
||||||
def test_construct(self):
|
def test_construct(self):
|
||||||
|
@ -434,7 +484,7 @@ class ViewTestCase(FHDLTestCase):
|
||||||
|
|
||||||
def test_signal_reset_wrong(self):
|
def test_signal_reset_wrong(self):
|
||||||
with self.assertRaisesRegex(TypeError,
|
with self.assertRaisesRegex(TypeError,
|
||||||
r"^Layout initializer must be a mapping or a sequence, not 1$"):
|
r"^Reset value must be a constant initializer of StructLayout\({}\)$"):
|
||||||
View(StructLayout({}), reset=0b1)
|
View(StructLayout({}), reset=0b1)
|
||||||
|
|
||||||
def test_target_signal_wrong(self):
|
def test_target_signal_wrong(self):
|
||||||
|
@ -483,6 +533,9 @@ class ViewTestCase(FHDLTestCase):
|
||||||
def __call__(self, value):
|
def __call__(self, value):
|
||||||
return value[::-1]
|
return value[::-1]
|
||||||
|
|
||||||
|
def const(self, init):
|
||||||
|
return Const(init, 2)
|
||||||
|
|
||||||
v = View(StructLayout({
|
v = View(StructLayout({
|
||||||
"f": Reverser()
|
"f": Reverser()
|
||||||
}))
|
}))
|
||||||
|
@ -497,13 +550,15 @@ class ViewTestCase(FHDLTestCase):
|
||||||
def __call__(self, value):
|
def __call__(self, value):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def const(self, init):
|
||||||
|
return Const(init, 2)
|
||||||
|
|
||||||
v = View(StructLayout({
|
v = View(StructLayout({
|
||||||
"f": WrongCastable()
|
"f": WrongCastable()
|
||||||
}))
|
}))
|
||||||
with self.assertRaisesRegex(TypeError,
|
with self.assertRaisesRegex(TypeError,
|
||||||
r"^<tests\.test_lib_data\.ViewTestCase\.test_getitem_custom_call_wrong\.<locals>"
|
r"^<.+?\.WrongCastable.+?>\.__call__\(\) must return a value or a value-castable "
|
||||||
r"\.WrongCastable object at 0x.+?>\.__call__\(\) must return a value or "
|
r"object, not None$"):
|
||||||
r"a value-castable object, not None$"):
|
|
||||||
v.f
|
v.f
|
||||||
|
|
||||||
def test_index_wrong_missing(self):
|
def test_index_wrong_missing(self):
|
||||||
|
|
|
@ -5,18 +5,12 @@ from .utils import *
|
||||||
|
|
||||||
|
|
||||||
class EnumTestCase(FHDLTestCase):
|
class EnumTestCase(FHDLTestCase):
|
||||||
def test_non_int_members(self):
|
def test_members_non_int(self):
|
||||||
# Mustn't raise to be a drop-in replacement for Enum.
|
# Mustn't raise to be a drop-in replacement for Enum.
|
||||||
class EnumA(Enum):
|
class EnumA(Enum):
|
||||||
A = "str"
|
A = "str"
|
||||||
|
|
||||||
def test_non_const_non_int_members_wrong(self):
|
def test_members_const_non_int(self):
|
||||||
with self.assertRaisesRegex(TypeError,
|
|
||||||
r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"):
|
|
||||||
class EnumA(Enum, shape=unsigned(1)):
|
|
||||||
A = "str"
|
|
||||||
|
|
||||||
def test_const_non_int_members(self):
|
|
||||||
class EnumA(Enum):
|
class EnumA(Enum):
|
||||||
A = C(0)
|
A = C(0)
|
||||||
B = C(1)
|
B = C(1)
|
||||||
|
@ -59,6 +53,12 @@ class EnumTestCase(FHDLTestCase):
|
||||||
B = -5
|
B = -5
|
||||||
self.assertEqual(Shape.cast(EnumD), signed(4))
|
self.assertEqual(Shape.cast(EnumD), signed(4))
|
||||||
|
|
||||||
|
def test_shape_members_non_const_non_int_wrong(self):
|
||||||
|
with self.assertRaisesRegex(TypeError,
|
||||||
|
r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"):
|
||||||
|
class EnumA(Enum, shape=unsigned(1)):
|
||||||
|
A = "str"
|
||||||
|
|
||||||
def test_shape_explicit_wrong_signed_mismatch(self):
|
def test_shape_explicit_wrong_signed_mismatch(self):
|
||||||
with self.assertWarnsRegex(SyntaxWarning,
|
with self.assertWarnsRegex(SyntaxWarning,
|
||||||
r"^Value -1 of enumeration member 'A' is signed, but the enumeration "
|
r"^Value -1 of enumeration member 'A' is signed, but the enumeration "
|
||||||
|
@ -88,6 +88,23 @@ class EnumTestCase(FHDLTestCase):
|
||||||
A = 1
|
A = 1
|
||||||
self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)")
|
self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)")
|
||||||
|
|
||||||
|
def test_const_no_shape(self):
|
||||||
|
class EnumA(Enum):
|
||||||
|
Z = 0
|
||||||
|
A = 10
|
||||||
|
B = 20
|
||||||
|
self.assertRepr(EnumA.const(None), "(const 5'd0)")
|
||||||
|
self.assertRepr(EnumA.const(10), "(const 5'd10)")
|
||||||
|
self.assertRepr(EnumA.const(EnumA.A), "(const 5'd10)")
|
||||||
|
|
||||||
|
def test_const_shape(self):
|
||||||
|
class EnumA(Enum, shape=8):
|
||||||
|
Z = 0
|
||||||
|
A = 10
|
||||||
|
self.assertRepr(EnumA.const(None), "(const 8'd0)")
|
||||||
|
self.assertRepr(EnumA.const(10), "(const 8'd10)")
|
||||||
|
self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)")
|
||||||
|
|
||||||
def test_shape_implicit_wrong_in_concat(self):
|
def test_shape_implicit_wrong_in_concat(self):
|
||||||
class EnumA(Enum):
|
class EnumA(Enum):
|
||||||
A = 0
|
A = 0
|
||||||
|
|
Loading…
Reference in a new issue