lib.data: implement RFC 1 "Aggregate data structure library".

See amaranth-lang/rfcs#1.
This commit is contained in:
Catherine 2022-04-06 04:00:22 +00:00
parent a7fec279aa
commit 7e3e10e733
2 changed files with 1192 additions and 0 deletions

432
amaranth/lib/data.py Normal file
View file

@ -0,0 +1,432 @@
from abc import ABCMeta, abstractmethod, abstractproperty
from collections.abc import Mapping, Sequence
from amaranth.hdl import *
from amaranth.hdl.ast import ShapeCastable, ValueCastable
__all__ = [
"Field", "Layout", "StructLayout", "UnionLayout", "ArrayLayout", "FlexibleLayout",
"View", "Struct", "Union",
]
class Field:
def __init__(self, shape, offset):
self.shape = shape
self.offset = offset
@property
def shape(self):
return self._shape
@shape.setter
def shape(self, shape):
try:
Shape.cast(shape)
except TypeError as e:
raise TypeError("Field shape must be a shape-castable object, not {!r}"
.format(shape)) from e
self._shape = shape
@property
def offset(self):
return self._offset
@offset.setter
def offset(self, offset):
if not isinstance(offset, int) or offset < 0:
raise TypeError("Field offset must be a non-negative integer, not {!r}"
.format(offset))
self._offset = offset
@property
def width(self):
return Shape.cast(self.shape).width
def __eq__(self, other):
return (isinstance(other, Field) and
self._shape == other.shape and self._offset == other.offset)
def __repr__(self):
return f"Field({self._shape!r}, {self._offset})"
class Layout(ShapeCastable, metaclass=ABCMeta):
@staticmethod
def cast(obj):
"""Cast a shape-castable object to a layout."""
while isinstance(obj, ShapeCastable):
if isinstance(obj, Layout):
return obj
new_obj = obj.as_shape()
if new_obj is obj:
break
obj = new_obj
Shape.cast(obj) # delegate non-layout-specific error handling to Shape
raise TypeError("Object {!r} cannot be converted to a data layout"
.format(obj))
@staticmethod
def of(obj):
"""Extract the layout from a view."""
if not isinstance(obj, View):
raise TypeError("Object {!r} is not a data view"
.format(obj))
return obj._View__orig_layout
@abstractmethod
def __iter__(self):
"""Iterate the layout, yielding ``(key, field)`` pairs. Keys may be strings or integers."""
@abstractmethod
def __getitem__(self, key):
"""Retrieve the :class:`Field` associated with the ``key``, or raise ``KeyError``."""
size = abstractproperty()
"""The number of bits in the representation defined by the layout."""
def as_shape(self):
"""Convert the representation defined by the layout to an unsigned :class:`Shape`."""
return unsigned(self.size)
def __eq__(self, other):
"""Compare the layout with another.
Two layouts are equal if they have the same size and the same fields under the same names.
The order of the fields is not considered.
"""
while isinstance(other, ShapeCastable) and not isinstance(other, Layout):
new_other = other.as_shape()
if new_other is other:
break
other = new_other
return (isinstance(other, Layout) and self.size == other.size and
dict(iter(self)) == dict(iter(other)))
def _convert_to_int(self, value):
"""Convert ``value``, which may be a dict or an array of field values, to an integer using
the representation defined by this layout.
This method is roughly equivalent to :meth:`Const.normalize`. It is private because
Amaranth does not currently have a concept of a constant initializer; this requires
an RFC. It will be renamed or removed in a future version."""
if isinstance(value, Mapping):
iterator = value.items()
elif isinstance(value, Sequence):
iterator = enumerate(value)
else:
raise TypeError("Layout initializer must be a mapping or a sequence, not {!r}"
.format(value))
int_value = 0
for key, key_value in iterator:
field = self[key]
if isinstance(field.shape, Layout):
key_value = field.shape._convert_to_int(key_value)
int_value |= Const.normalize(key_value, Shape.cast(field.shape)) << field.offset
return int_value
class StructLayout(Layout):
def __init__(self, members):
self.members = members
@property
def members(self):
return {key: field.shape for key, field in self._fields.items()}
@members.setter
def members(self, members):
offset = 0
self._fields = {}
if not isinstance(members, Mapping):
raise TypeError("Struct layout members must be provided as a mapping, not {!r}"
.format(members))
for key, shape in members.items():
if not isinstance(key, str):
raise TypeError("Struct layout member name must be a string, not {!r}"
.format(key))
try:
cast_shape = Shape.cast(shape)
except TypeError as e:
raise TypeError("Struct layout member shape must be a shape-castable object, "
"not {!r}"
.format(shape)) from e
self._fields[key] = Field(shape, offset)
offset += cast_shape.width
def __iter__(self):
return iter(self._fields.items())
def __getitem__(self, key):
return self._fields[key]
@property
def size(self):
return max((field.offset + field.width for field in self._fields.values()), default=0)
def __repr__(self):
return f"StructLayout({self.members!r})"
class UnionLayout(Layout):
def __init__(self, members):
self.members = members
@property
def members(self):
return {key: field.shape for key, field in self._fields.items()}
@members.setter
def members(self, members):
self._fields = {}
if not isinstance(members, Mapping):
raise TypeError("Union layout members must be provided as a mapping, not {!r}"
.format(members))
for key, shape in members.items():
if not isinstance(key, str):
raise TypeError("Union layout member name must be a string, not {!r}"
.format(key))
try:
cast_shape = Shape.cast(shape)
except TypeError as e:
raise TypeError("Union layout member shape must be a shape-castable object, "
"not {!r}"
.format(shape)) from e
self._fields[key] = Field(shape, 0)
def __iter__(self):
return iter(self._fields.items())
def __getitem__(self, key):
return self._fields[key]
@property
def size(self):
return max((field.width for field in self._fields.values()), default=0)
def __repr__(self):
return f"UnionLayout({self.members!r})"
class ArrayLayout(Layout):
def __init__(self, elem_shape, length):
self.elem_shape = elem_shape
self.length = length
@property
def elem_shape(self):
return self._elem_shape
@elem_shape.setter
def elem_shape(self, elem_shape):
try:
Shape.cast(elem_shape)
except TypeError as e:
raise TypeError("Array layout element shape must be a shape-castable object, "
"not {!r}"
.format(elem_shape)) from e
self._elem_shape = elem_shape
@property
def length(self):
return self._length
@length.setter
def length(self, length):
if not isinstance(length, int) or length < 0:
raise TypeError("Array layout length must be a non-negative integer, not {!r}"
.format(length))
self._length = length
def __iter__(self):
offset = 0
for index in range(self._length):
yield index, Field(self._elem_shape, offset)
offset += Shape.cast(self._elem_shape).width
def __getitem__(self, key):
if isinstance(key, int):
if key not in range(-self._length, self._length):
# Layout's interface requires us to raise KeyError, not IndexError
raise KeyError(key)
if key < 0:
key += self._length
return Field(self._elem_shape, key * Shape.cast(self._elem_shape).width)
raise TypeError("Cannot index array layout with {!r}".format(key))
@property
def size(self):
return Shape.cast(self._elem_shape).width * self.length
def __repr__(self):
return f"ArrayLayout({self._elem_shape!r}, {self.length})"
class FlexibleLayout(Layout):
def __init__(self, size, fields):
self.size = size
self.fields = fields
@property
def size(self):
return self._size
@size.setter
def size(self, size):
if not isinstance(size, int) or size < 0:
raise TypeError("Flexible layout size must be a non-negative integer, not {!r}"
.format(size))
if hasattr(self, "_fields") and self._fields:
endmost_name, endmost_field = max(self._fields.items(),
key=lambda pair: pair[1].offset + pair[1].width)
if endmost_field.offset + endmost_field.width > size:
raise ValueError("Flexible layout size {} does not cover the field '{}', which "
"ends at bit {}"
.format(size, endmost_name,
endmost_field.offset + endmost_field.width))
self._size = size
@property
def fields(self):
return {**self._fields}
@fields.setter
def fields(self, fields):
self._fields = {}
if not isinstance(fields, Mapping):
raise TypeError("Flexible layout fields must be provided as a mapping, not {!r}"
.format(fields))
for key, field in fields.items():
if not isinstance(key, (int, str)) or (isinstance(key, int) and key < 0):
raise TypeError("Flexible layout field name must be a non-negative integer or "
"a string, not {!r}"
.format(key))
if not isinstance(field, Field):
raise TypeError("Flexible layout field value must be a Field instance, not {!r}"
.format(field))
if field.offset + field.width > self._size:
raise ValueError("Flexible layout field '{}' ends at bit {}, exceeding "
"the size of {} bit(s)"
.format(key, field.offset + field.width, self._size))
self._fields[key] = field
def __iter__(self):
return iter(self._fields.items())
def __getitem__(self, key):
if isinstance(key, (int, str)):
return self._fields[key]
raise TypeError("Cannot index flexible layout with {!r}".format(key))
def __repr__(self):
return f"FlexibleLayout({self._size}, {self._fields!r})"
class View(ValueCastable):
def __init__(self, layout, target=None, *, name=None, reset=None, reset_less=None,
attrs=None, decoder=None, src_loc_at=0):
try:
cast_layout = Layout.cast(layout)
except TypeError as e:
raise TypeError("View layout must be a Layout instance, not {!r}"
.format(layout)) from e
if target is not None:
if (name is not None or reset is not None or reset_less is not None or
attrs is not None or decoder is not None):
raise ValueError("View target cannot be provided at the same time as any of "
"the Signal constructor arguments (name, reset, reset_less, "
"attrs, decoder)")
try:
cast_target = Value.cast(target)
except TypeError as e:
raise TypeError("View target must be a value-castable object, not {!r}"
.format(target)) from e
if len(cast_target) != cast_layout.size:
raise ValueError("View target is {} bit(s) wide, which is not compatible with "
"the {} bit(s) wide view layout"
.format(len(cast_target), cast_layout.size))
else:
if reset is None:
reset = 0
else:
reset = cast_layout._convert_to_int(reset)
if reset_less is None:
reset_less = False
cast_target = Signal(cast_layout, name=name, reset=reset, reset_less=reset_less,
attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1)
self.__orig_layout = layout
self.__layout = cast_layout
self.__target = cast_target
@ValueCastable.lowermethod
def as_value(self):
return self.__target
def eq(self, other):
return self.as_value().eq(other)
def __getitem__(self, key):
if isinstance(self.__layout, ArrayLayout):
shape = self.__layout.elem_shape
value = self.__target.word_select(key, Shape.cast(self.__layout.elem_shape).width)
else:
if isinstance(key, (Value, ValueCastable)):
raise TypeError("Only views with array layout, not {!r}, may be indexed "
"with a value"
.format(self.__layout))
field = self.__layout[key]
shape = field.shape
value = self.__target[field.offset:field.offset + field.width]
if isinstance(shape, _AggregateMeta):
return shape(value)
if isinstance(shape, Layout):
return View(shape, value)
if Shape.cast(shape).signed:
return value.as_signed()
else:
return value
def __getattr__(self, name):
try:
item = self[name]
except KeyError:
raise AttributeError("View of {!r} does not have a field {!r}; "
"did you mean one of: {}?"
.format(self.__target, name,
", ".join(repr(name)
for name, field in self.__layout)))
if name.startswith("_"):
raise AttributeError("View of {!r} field {!r} has a reserved name and may only be "
"accessed by indexing"
.format(self.__target, name))
return item
class _AggregateMeta(ShapeCastable, type):
def __new__(metacls, name, bases, namespace, *, _layout_cls=None, **kwargs):
cls = type.__new__(metacls, name, bases, namespace, **kwargs)
if _layout_cls is not None:
cls.__layout_cls = _layout_cls
if "__annotations__" in namespace:
cls.__layout = cls.__layout_cls(namespace["__annotations__"])
return cls
def as_shape(cls):
return cls.__layout
class _Aggregate(View, metaclass=_AggregateMeta):
def __init__(self, target=None, *, name=None, reset=None, reset_less=None,
attrs=None, decoder=None, src_loc_at=0):
super().__init__(self.__class__, target, name=name, reset=reset, reset_less=reset_less,
attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1)
class Struct(_Aggregate, _layout_cls=StructLayout):
pass
class Union(_Aggregate, _layout_cls=UnionLayout):
pass

760
tests/test_lib_data.py Normal file
View file

@ -0,0 +1,760 @@
from enum import Enum
from unittest import TestCase
from amaranth.hdl import *
from amaranth.hdl.ast import ShapeCastable
from amaranth.lib.data import *
from amaranth.sim import Simulator
from .utils import *
class MockShapeCastable(ShapeCastable):
def __init__(self, shape):
self.shape = shape
def as_shape(self):
return self.shape
class FieldTestCase(TestCase):
def test_construct(self):
f = Field(unsigned(2), 1)
self.assertEqual(f.shape, unsigned(2))
self.assertEqual(f.offset, 1)
self.assertEqual(f.width, 2)
def test_repr(self):
f = Field(unsigned(2), 1)
self.assertEqual(repr(f), "Field(unsigned(2), 1)")
def test_equal(self):
f1 = Field(unsigned(2), 1)
f2 = Field(unsigned(2), 0)
self.assertNotEqual(f1, f2)
f3 = Field(unsigned(2), 1)
self.assertEqual(f1, f3)
f4 = Field(2, 1)
self.assertEqual(f1, f4)
f5 = Field(MockShapeCastable(unsigned(2)), 1)
self.assertEqual(f1, f5)
self.assertNotEqual(f1, object())
def test_preserve_shape(self):
sc = MockShapeCastable(unsigned(2))
f = Field(sc, 0)
self.assertEqual(f.shape, sc)
self.assertEqual(f.width, 2)
def test_shape_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Field shape must be a shape-castable object, not <.+>$"):
Field(object(), 0)
def test_offset_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Field offset must be a non-negative integer, not <.+>$"):
Field(unsigned(2), object())
with self.assertRaisesRegex(TypeError,
r"^Field offset must be a non-negative integer, not -1$"):
Field(unsigned(2), -1)
class StructLayoutTestCase(TestCase):
def test_construct(self):
sl = StructLayout({
"a": unsigned(1),
"b": 2
})
self.assertEqual(sl.members, {
"a": unsigned(1),
"b": 2
})
self.assertEqual(sl.size, 3)
self.assertEqual(list(iter(sl)), [
("a", Field(unsigned(1), 0)),
("b", Field(2, 1))
])
self.assertEqual(sl["a"], Field(unsigned(1), 0))
self.assertEqual(sl["b"], Field(2, 1))
def test_size_empty(self):
self.assertEqual(StructLayout({}).size, 0)
def test_eq(self):
self.assertEqual(StructLayout({"a": unsigned(1), "b": 2}),
StructLayout({"a": unsigned(1), "b": unsigned(2)}))
self.assertNotEqual(StructLayout({"a": unsigned(1), "b": 2}),
StructLayout({"b": unsigned(2), "a": unsigned(1)}))
self.assertNotEqual(StructLayout({"a": unsigned(1), "b": 2}),
StructLayout({"a": unsigned(1)}))
def test_repr(self):
sl = StructLayout({
"a": unsigned(1),
"b": 2
})
self.assertEqual(repr(sl), "StructLayout({'a': unsigned(1), 'b': 2})")
def test_members_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Struct layout members must be provided as a mapping, not <.+>$"):
StructLayout(object())
def test_member_key_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Struct layout member name must be a string, not 1\.0$"):
StructLayout({1.0: unsigned(1)})
def test_member_value_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Struct layout member shape must be a shape-castable object, not 1\.0$"):
StructLayout({"a": 1.0})
class UnionLayoutTestCase(TestCase):
def test_construct(self):
ul = UnionLayout({
"a": unsigned(1),
"b": 2
})
self.assertEqual(ul.members, {
"a": unsigned(1),
"b": 2
})
self.assertEqual(ul.size, 2)
self.assertEqual(list(iter(ul)), [
("a", Field(unsigned(1), 0)),
("b", Field(2, 0))
])
self.assertEqual(ul["a"], Field(unsigned(1), 0))
self.assertEqual(ul["b"], Field(2, 0))
def test_size_empty(self):
self.assertEqual(UnionLayout({}).size, 0)
def test_eq(self):
self.assertEqual(UnionLayout({"a": unsigned(1), "b": 2}),
UnionLayout({"a": unsigned(1), "b": unsigned(2)}))
self.assertEqual(UnionLayout({"a": unsigned(1), "b": 2}),
UnionLayout({"b": unsigned(2), "a": unsigned(1)}))
self.assertNotEqual(UnionLayout({"a": unsigned(1), "b": 2}),
UnionLayout({"a": unsigned(1)}))
def test_repr(self):
ul = UnionLayout({
"a": unsigned(1),
"b": 2
})
self.assertEqual(repr(ul), "UnionLayout({'a': unsigned(1), 'b': 2})")
def test_members_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Union layout members must be provided as a mapping, not <.+>$"):
UnionLayout(object())
def test_member_key_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Union layout member name must be a string, not 1\.0$"):
UnionLayout({1.0: unsigned(1)})
def test_member_value_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Union layout member shape must be a shape-castable object, not 1\.0$"):
UnionLayout({"a": 1.0})
class ArrayLayoutTestCase(TestCase):
def test_construct(self):
al = ArrayLayout(unsigned(2), 3)
self.assertEqual(al.elem_shape, unsigned(2))
self.assertEqual(al.length, 3)
self.assertEqual(list(iter(al)), [
(0, Field(unsigned(2), 0)),
(1, Field(unsigned(2), 2)),
(2, Field(unsigned(2), 4)),
])
self.assertEqual(al[0], Field(unsigned(2), 0))
self.assertEqual(al[1], Field(unsigned(2), 2))
self.assertEqual(al[2], Field(unsigned(2), 4))
self.assertEqual(al[-1], Field(unsigned(2), 4))
self.assertEqual(al[-2], Field(unsigned(2), 2))
self.assertEqual(al[-3], Field(unsigned(2), 0))
self.assertEqual(al.size, 6)
def test_shape_castable(self):
al = ArrayLayout(2, 3)
self.assertEqual(al.size, 6)
def test_eq(self):
self.assertEqual(ArrayLayout(unsigned(2), 3),
ArrayLayout(unsigned(2), 3))
self.assertNotEqual(ArrayLayout(unsigned(2), 3),
ArrayLayout(unsigned(2), 4))
def test_repr(self):
al = ArrayLayout(unsigned(2), 3)
self.assertEqual(repr(al), "ArrayLayout(unsigned(2), 3)")
def test_elem_shape_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Array layout element shape must be a shape-castable object, not <.+>$"):
ArrayLayout(object(), 1)
def test_length_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Array layout length must be a non-negative integer, not <.+>$"):
ArrayLayout(unsigned(1), object())
with self.assertRaisesRegex(TypeError,
r"^Array layout length must be a non-negative integer, not -1$"):
ArrayLayout(unsigned(1), -1)
def test_key_wrong_bounds(self):
al = ArrayLayout(unsigned(2), 3)
with self.assertRaisesRegex(KeyError, r"^4$"):
al[4]
with self.assertRaisesRegex(KeyError, r"^-4$"):
al[-4]
def test_key_wrong_type(self):
al = ArrayLayout(unsigned(2), 3)
with self.assertRaisesRegex(TypeError,
r"^Cannot index array layout with 'a'$"):
al["a"]
class FlexibleLayoutTestCase(TestCase):
def test_construct(self):
il = FlexibleLayout(8, {
"a": Field(unsigned(1), 1),
"b": Field(unsigned(3), 0),
0: Field(unsigned(2), 5)
})
self.assertEqual(il.size, 8)
self.assertEqual(il.fields, {
"a": Field(unsigned(1), 1),
"b": Field(unsigned(3), 0),
0: Field(unsigned(2), 5)
})
self.assertEqual(list(iter(il)), [
("a", Field(unsigned(1), 1)),
("b", Field(unsigned(3), 0)),
(0, Field(unsigned(2), 5))
])
self.assertEqual(il["a"], Field(unsigned(1), 1))
self.assertEqual(il["b"], Field(unsigned(3), 0))
self.assertEqual(il[0], Field(unsigned(2), 5))
def test_eq(self):
self.assertEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0)}),
FlexibleLayout(3, {"a": Field(unsigned(1), 0)}))
self.assertNotEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0)}),
FlexibleLayout(4, {"a": Field(unsigned(1), 0)}))
self.assertNotEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0)}),
FlexibleLayout(3, {"a": Field(unsigned(1), 1)}))
def test_eq_duck(self):
self.assertEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0),
"b": Field(unsigned(2), 1)}),
StructLayout({"a": unsigned(1),
"b": unsigned(2)}))
self.assertEqual(FlexibleLayout(2, {"a": Field(unsigned(1), 0),
"b": Field(unsigned(2), 0)}),
UnionLayout({"a": unsigned(1),
"b": unsigned(2)}))
def test_repr(self):
il = FlexibleLayout(8, {
"a": Field(unsigned(1), 1),
"b": Field(unsigned(3), 0),
0: Field(unsigned(2), 5)
})
self.assertEqual(repr(il), "FlexibleLayout(8, {"
"'a': Field(unsigned(1), 1), "
"'b': Field(unsigned(3), 0), "
"0: Field(unsigned(2), 5)})")
def test_fields_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Flexible layout fields must be provided as a mapping, not <.+>$"):
FlexibleLayout(8, object())
def test_field_key_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Flexible layout field name must be a non-negative integer or a string, "
r"not 1\.0$"):
FlexibleLayout(8, {1.0: unsigned(1)})
with self.assertRaisesRegex(TypeError,
r"^Flexible layout field name must be a non-negative integer or a string, "
r"not -1$"):
FlexibleLayout(8, {-1: unsigned(1)})
def test_field_value_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Flexible layout field value must be a Field instance, not 1\.0$"):
FlexibleLayout(8, {"a": 1.0})
def test_size_wrong_negative(self):
with self.assertRaisesRegex(TypeError,
r"^Flexible layout size must be a non-negative integer, not -1$"):
FlexibleLayout(-1, {})
def test_size_wrong_small(self):
with self.assertRaisesRegex(ValueError,
r"^Flexible layout field 'a' ends at bit 8, exceeding the size of 4 bit\(s\)$"):
FlexibleLayout(4, {"a": Field(unsigned(8), 0)})
with self.assertRaisesRegex(ValueError,
r"^Flexible layout field 'a' ends at bit 5, exceeding the size of 4 bit\(s\)$"):
FlexibleLayout(4, {"a": Field(unsigned(2), 3)})
def test_size_wrong_shrink(self):
il = FlexibleLayout(8, {"a": Field(unsigned(2), 3)})
with self.assertRaisesRegex(ValueError,
r"^Flexible layout size 4 does not cover the field 'a', which ends at bit 5$"):
il.size = 4
def test_key_wrong_missing(self):
il = FlexibleLayout(8, {"a": Field(unsigned(2), 3)})
with self.assertRaisesRegex(KeyError,
r"^0$"):
il[0]
def test_key_wrong_type(self):
il = FlexibleLayout(8, {"a": Field(unsigned(2), 3)})
with self.assertRaisesRegex(TypeError,
r"^Cannot index flexible layout with <.+>$"):
il[object()]
class LayoutTestCase(TestCase):
def test_cast(self):
sl = StructLayout({})
self.assertIs(Layout.cast(sl), sl)
def test_cast_wrong_not_layout(self):
with self.assertRaisesRegex(TypeError,
r"^Object unsigned\(1\) cannot be converted to a data layout$"):
Layout.cast(unsigned(1))
def test_cast_wrong_type(self):
with self.assertRaisesRegex(TypeError,
r"^Object <.+> cannot be converted to an Amaranth shape$"):
Layout.cast(object())
def test_cast_wrong_recur(self):
sc = MockShapeCastable(None)
sc.shape = sc
with self.assertRaisesRegex(RecursionError,
r"^Shape-castable object <.+> casts to itself$"):
Layout.cast(sc)
def test_of_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Object <.+> is not a data view$"):
Layout.of(object())
def test_eq_wrong_recur(self):
sc = MockShapeCastable(None)
sc.shape = sc
self.assertNotEqual(StructLayout({}), sc)
class ViewTestCase(FHDLTestCase):
def test_construct(self):
s = Signal(3)
v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), s)
self.assertIs(Value.cast(v), s)
self.assertRepr(v["a"], "(slice (sig s) 0:1)")
self.assertRepr(v["b"], "(slice (sig s) 1:3)")
def test_construct_signal(self):
v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}))
cv = Value.cast(v)
self.assertIsInstance(cv, Signal)
self.assertEqual(cv.shape(), unsigned(3))
self.assertEqual(cv.name, "v")
def test_construct_signal_name(self):
v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), name="named")
self.assertEqual(Value.cast(v).name, "named")
def test_construct_signal_reset(self):
v1 = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}),
reset={"a": 0b1, "b": 0b10})
self.assertEqual(Value.cast(v1).reset, 0b101)
v2 = View(StructLayout({"a": unsigned(1),
"b": StructLayout({"x": unsigned(1), "y": unsigned(1)})}),
reset={"a": 0b1, "b": {"x": 0b0, "y": 0b1}})
self.assertEqual(Value.cast(v2).reset, 0b101)
v3 = View(ArrayLayout(unsigned(2), 2),
reset=[0b01, 0b10])
self.assertEqual(Value.cast(v3).reset, 0b1001)
def test_construct_signal_reset_less(self):
v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), reset_less=True)
self.assertEqual(Value.cast(v).reset_less, True)
def test_construct_signal_attrs(self):
v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), attrs={"debug": 1})
self.assertEqual(Value.cast(v).attrs, {"debug": 1})
def test_construct_signal_decoder(self):
decoder = lambda x: f"{x}"
v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), decoder=decoder)
self.assertEqual(Value.cast(v).decoder, decoder)
def test_layout_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^View layout must be a Layout instance, not <.+?>$"):
View(object(), Signal(1))
def test_target_wrong_type(self):
with self.assertRaisesRegex(TypeError,
r"^View target must be a value-castable object, not <.+?>$"):
View(StructLayout({}), object())
def test_target_wrong_size(self):
with self.assertRaisesRegex(ValueError,
r"^View target is 2 bit\(s\) wide, which is not compatible with the 1 bit\(s\) "
r"wide view layout$"):
View(StructLayout({"a": unsigned(1)}), Signal(2))
def test_signal_reset_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Layout initializer must be a mapping or a sequence, not 1$"):
View(StructLayout({}), reset=0b1)
def test_target_signal_wrong(self):
with self.assertRaisesRegex(ValueError,
r"^View target cannot be provided at the same time as any of the Signal "
r"constructor arguments \(name, reset, reset_less, attrs, decoder\)$"):
View(StructLayout({}), Signal(), reset=0b1)
def test_getitem(self):
v = View(UnionLayout({
"a": unsigned(2),
"s": StructLayout({
"b": unsigned(1),
"c": unsigned(3)
}),
"p": 1,
"q": signed(1),
"r": ArrayLayout(unsigned(2), 2),
"t": ArrayLayout(StructLayout({
"u": unsigned(1),
"v": unsigned(1)
}), 2),
}))
cv = Value.cast(v)
i = Signal(1)
self.assertEqual(cv.shape(), unsigned(4))
self.assertRepr(v["a"], "(slice (sig v) 0:2)")
self.assertEqual(v["a"].shape(), unsigned(2))
self.assertRepr(v["s"]["b"], "(slice (slice (sig v) 0:4) 0:1)")
self.assertRepr(v["s"]["c"], "(slice (slice (sig v) 0:4) 1:4)")
self.assertRepr(v["p"], "(slice (sig v) 0:1)")
self.assertEqual(v["p"].shape(), unsigned(1))
self.assertRepr(v["q"], "(s (slice (sig v) 0:1))")
self.assertEqual(v["q"].shape(), signed(1))
self.assertRepr(v["r"][0], "(slice (slice (sig v) 0:4) 0:2)")
self.assertRepr(v["r"][1], "(slice (slice (sig v) 0:4) 2:4)")
self.assertRepr(v["r"][i], "(part (slice (sig v) 0:4) (sig i) 2 2)")
self.assertRepr(v["t"][0]["u"], "(slice (slice (slice (sig v) 0:4) 0:2) 0:1)")
self.assertRepr(v["t"][1]["v"], "(slice (slice (slice (sig v) 0:4) 2:4) 1:2)")
def test_index_wrong_missing(self):
with self.assertRaisesRegex(KeyError,
r"^'a'$"):
View(StructLayout({}))["a"]
def test_index_wrong_struct_dynamic(self):
with self.assertRaisesRegex(TypeError,
r"^Only views with array layout, not StructLayout\(\{\}\), may be indexed "
r"with a value$"):
View(StructLayout({}))[Signal(1)]
def test_getattr(self):
v = View(UnionLayout({
"a": unsigned(2),
"s": StructLayout({
"b": unsigned(1),
"c": unsigned(3)
}),
"p": 1,
"q": signed(1),
}))
cv = Value.cast(v)
i = Signal(1)
self.assertEqual(cv.shape(), unsigned(4))
self.assertRepr(v.a, "(slice (sig v) 0:2)")
self.assertEqual(v.a.shape(), unsigned(2))
self.assertRepr(v.s.b, "(slice (slice (sig v) 0:4) 0:1)")
self.assertRepr(v.s.c, "(slice (slice (sig v) 0:4) 1:4)")
self.assertRepr(v.p, "(slice (sig v) 0:1)")
self.assertEqual(v.p.shape(), unsigned(1))
self.assertRepr(v.q, "(s (slice (sig v) 0:1))")
self.assertEqual(v.q.shape(), signed(1))
def test_getattr_reserved(self):
v = View(UnionLayout({
"_a": unsigned(2)
}))
self.assertRepr(v["_a"], "(slice (sig v) 0:2)")
def test_attr_wrong_missing(self):
with self.assertRaisesRegex(AttributeError,
r"^View of \(sig \$signal\) does not have a field 'a'; "
r"did you mean one of: 'b', 'c'\?$"):
View(StructLayout({"b": unsigned(1), "c": signed(1)})).a
def test_attr_wrong_reserved(self):
with self.assertRaisesRegex(AttributeError,
r"^View of \(sig \$signal\) field '_c' has a reserved name "
r"and may only be accessed by indexing$"):
View(StructLayout({"_c": signed(1)}))._c
class StructTestCase(FHDLTestCase):
def test_construct(self):
class S(Struct):
a: unsigned(1)
b: signed(3)
self.assertEqual(Shape.cast(S), unsigned(4))
self.assertEqual(Layout.cast(S), StructLayout({
"a": unsigned(1),
"b": signed(3)
}))
v = S()
self.assertEqual(Layout.of(v), S)
self.assertEqual(Value.cast(v).shape(), S)
self.assertEqual(Value.cast(v).name, "v")
self.assertRepr(v.a, "(slice (sig v) 0:1)")
self.assertRepr(v.b, "(s (slice (sig v) 1:4))")
def test_construct_nested(self):
Q = StructLayout({"r": signed(2), "s": signed(2)})
class R(Struct):
p: 4
q: Q
class S(Struct):
a: unsigned(1)
b: R
self.assertEqual(S, unsigned(9))
v = S()
self.assertIs(Layout.of(v), S)
self.assertIsInstance(v, S)
self.assertIs(Layout.of(v.b), R)
self.assertIsInstance(v.b, R)
self.assertIs(Layout.of(v.b.q), Q)
self.assertIsInstance(v.b.q, View)
self.assertRepr(v.b.p, "(slice (slice (sig v) 1:9) 0:4)")
self.assertRepr(v.b.q.as_value(), "(slice (slice (sig v) 1:9) 4:8)")
self.assertRepr(v.b.q.r, "(s (slice (slice (slice (sig v) 1:9) 4:8) 0:2))")
self.assertRepr(v.b.q.s, "(s (slice (slice (slice (sig v) 1:9) 4:8) 2:4))")
def test_construct_signal_kwargs(self):
decoder = lambda x: f"{x}"
v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}),
name="named", reset={"b": 0b1}, reset_less=True, attrs={"debug": 1}, decoder=decoder)
s = Value.cast(v)
self.assertEqual(s.name, "named")
self.assertEqual(s.reset, 0b010)
self.assertEqual(s.reset_less, True)
self.assertEqual(s.attrs, {"debug": 1})
self.assertEqual(s.decoder, decoder)
class UnionTestCase(FHDLTestCase):
def test_construct(self):
class U(Union):
a: unsigned(1)
b: signed(3)
self.assertEqual(Shape.cast(U), unsigned(3))
self.assertEqual(Layout.cast(U), UnionLayout({
"a": unsigned(1),
"b": signed(3)
}))
v = U()
self.assertEqual(Layout.of(v), U)
self.assertEqual(Value.cast(v).shape(), U)
self.assertRepr(v.a, "(slice (sig v) 0:1)")
self.assertRepr(v.b, "(s (slice (sig v) 0:3))")
def test_construct_signal_kwargs(self):
decoder = lambda x: f"{x}"
v = View(UnionLayout({"a": unsigned(1), "b": unsigned(2)}),
name="named", reset={"b": 0b1}, reset_less=True, attrs={"debug": 1}, decoder=decoder)
s = Value.cast(v)
self.assertEqual(s.name, "named")
self.assertEqual(s.reset, 0b01)
self.assertEqual(s.reset_less, True)
self.assertEqual(s.attrs, {"debug": 1})
self.assertEqual(s.decoder, decoder)
# Examples from https://github.com/amaranth-lang/amaranth/issues/693
class RFCExamplesTestCase(TestCase):
@staticmethod
def simulate(m):
def wrapper(fn):
sim = Simulator(m)
sim.add_process(fn)
sim.run()
return wrapper
def test_rfc_example_1(self):
class Float32(Struct):
fraction: unsigned(23)
exponent: unsigned(8)
sign: unsigned(1)
self.assertEqual(Float32.as_shape().size, 32)
flt_a = Float32()
flt_b = Float32(Const(0b00111110001000000000000000000000, 32))
m1 = Module()
with m1.If(flt_b.fraction > 0):
m1.d.comb += [
flt_a.sign.eq(1),
flt_a.exponent.eq(127)
]
@self.simulate(m1)
def check_m1():
self.assertEqual((yield flt_a.as_value()), 0xbf800000)
class FloatOrInt32(Union):
float: Float32
int: signed(32)
f_or_i = FloatOrInt32()
is_gt_1 = Signal()
m2 = Module()
m2.d.comb += [
f_or_i.int.eq(0x41C80000),
is_gt_1.eq(f_or_i.float.exponent >= 127) # => 1
]
@self.simulate(m2)
def check_m2():
self.assertEqual((yield is_gt_1), 1)
class Op(Enum):
ADD = 0
SUB = 1
adder_op_layout = StructLayout({
"op": Op,
"a": Float32,
"b": Float32
})
adder_op_storage = Signal(adder_op_layout)
self.assertEqual(len(adder_op_storage), 65)
adder_op = View(adder_op_layout, adder_op_storage)
m3 = Module()
m3.d.comb += [
adder_op.eq(Op.SUB),
adder_op.a.eq(flt_a),
adder_op.b.eq(flt_b)
]
@self.simulate(m3)
def check_m3():
self.assertEqual((yield adder_op.as_value()), 0x7c40000000000001)
def test_rfc_example_2(self):
class Kind(Enum):
ONE_SIGNED = 0
TWO_UNSIGNED = 1
layout1 = StructLayout({
"kind": Kind,
"value": UnionLayout({
"one_signed": signed(2),
"two_unsigned": ArrayLayout(unsigned(1), 2)
})
})
self.assertEqual(layout1.size, 3)
sig1 = Signal(layout1)
self.assertEqual(sig1.shape(), unsigned(3))
view1 = View(layout1, sig1)
self.assertIs(Value.cast(view1), sig1)
view2 = View(layout1)
self.assertIsInstance(Value.cast(view2), Signal)
self.assertEqual(Value.cast(view2).shape(), unsigned(3))
m1 = Module()
m1.d.comb += [
view1.kind.eq(Kind.TWO_UNSIGNED),
view1.value.two_unsigned[0].eq(1),
]
@self.simulate(m1)
def check_m1():
self.assertEqual((yield view1.as_value()), 0b011)
class SomeVariant(Struct):
class Value(Union):
one_signed: signed(2)
two_unsigned: ArrayLayout(unsigned(1), 2)
kind: Kind
value: Value
self.assertEqual(SomeVariant, unsigned(3))
view3 = SomeVariant()
self.assertIsInstance(Value.cast(view3), Signal)
self.assertEqual(Value.cast(view3).shape(), unsigned(3))
m2 = Module()
m2.submodules += m1
m2.d.comb += [
view3.kind.eq(Kind.ONE_SIGNED),
view3.value.eq(view1.value)
]
@self.simulate(m2)
def check_m2():
self.assertEqual((yield view3.as_value()), 0b010)
sig2 = Signal(SomeVariant)
self.assertEqual(sig2.shape(), unsigned(3))
layout2 = StructLayout({
"ready": unsigned(1),
"payload": SomeVariant
})
self.assertEqual(layout2.size, 4)
self.assertEqual(layout1, Layout.cast(SomeVariant))
self.assertIs(SomeVariant, Layout.of(view3))
def test_rfc_example_3(self):
class Stream8b10b(View):
data: Signal
ctrl: Signal
def __init__(self, value=None, *, width: int):
super().__init__(StructLayout({
"data": unsigned(8 * width),
"ctrl": unsigned(width)
}), value)
self.assertEqual(len(Stream8b10b(width=1).data), 8)
self.assertEqual(len(Stream8b10b(width=4).data), 32)