From 7e3e10e7333630999c5b51440f0e5f0daef0358c Mon Sep 17 00:00:00 2001 From: Catherine Date: Wed, 6 Apr 2022 04:00:22 +0000 Subject: [PATCH] lib.data: implement RFC 1 "Aggregate data structure library". See amaranth-lang/rfcs#1. --- amaranth/lib/data.py | 432 +++++++++++++++++++++++ tests/test_lib_data.py | 760 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1192 insertions(+) create mode 100644 amaranth/lib/data.py create mode 100644 tests/test_lib_data.py diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py new file mode 100644 index 0000000..01b77d9 --- /dev/null +++ b/amaranth/lib/data.py @@ -0,0 +1,432 @@ +from abc import ABCMeta, abstractmethod, abstractproperty +from collections.abc import Mapping, Sequence + +from amaranth.hdl import * +from amaranth.hdl.ast import ShapeCastable, ValueCastable + + +__all__ = [ + "Field", "Layout", "StructLayout", "UnionLayout", "ArrayLayout", "FlexibleLayout", + "View", "Struct", "Union", +] + + +class Field: + def __init__(self, shape, offset): + self.shape = shape + self.offset = offset + + @property + def shape(self): + return self._shape + + @shape.setter + def shape(self, shape): + try: + Shape.cast(shape) + except TypeError as e: + raise TypeError("Field shape must be a shape-castable object, not {!r}" + .format(shape)) from e + self._shape = shape + + @property + def offset(self): + return self._offset + + @offset.setter + def offset(self, offset): + if not isinstance(offset, int) or offset < 0: + raise TypeError("Field offset must be a non-negative integer, not {!r}" + .format(offset)) + self._offset = offset + + @property + def width(self): + return Shape.cast(self.shape).width + + def __eq__(self, other): + return (isinstance(other, Field) and + self._shape == other.shape and self._offset == other.offset) + + def __repr__(self): + return f"Field({self._shape!r}, {self._offset})" + + +class Layout(ShapeCastable, metaclass=ABCMeta): + @staticmethod + def cast(obj): + """Cast a shape-castable object to a layout.""" + while isinstance(obj, ShapeCastable): + if isinstance(obj, Layout): + return obj + new_obj = obj.as_shape() + if new_obj is obj: + break + obj = new_obj + Shape.cast(obj) # delegate non-layout-specific error handling to Shape + raise TypeError("Object {!r} cannot be converted to a data layout" + .format(obj)) + + @staticmethod + def of(obj): + """Extract the layout from a view.""" + if not isinstance(obj, View): + raise TypeError("Object {!r} is not a data view" + .format(obj)) + return obj._View__orig_layout + + @abstractmethod + def __iter__(self): + """Iterate the layout, yielding ``(key, field)`` pairs. Keys may be strings or integers.""" + + @abstractmethod + def __getitem__(self, key): + """Retrieve the :class:`Field` associated with the ``key``, or raise ``KeyError``.""" + + size = abstractproperty() + """The number of bits in the representation defined by the layout.""" + + def as_shape(self): + """Convert the representation defined by the layout to an unsigned :class:`Shape`.""" + return unsigned(self.size) + + def __eq__(self, other): + """Compare the layout with another. + + Two layouts are equal if they have the same size and the same fields under the same names. + The order of the fields is not considered. + """ + while isinstance(other, ShapeCastable) and not isinstance(other, Layout): + new_other = other.as_shape() + if new_other is other: + break + other = new_other + return (isinstance(other, Layout) and self.size == other.size and + dict(iter(self)) == dict(iter(other))) + + def _convert_to_int(self, value): + """Convert ``value``, which may be a dict or an array of field values, to an integer using + the representation defined by this layout. + + This method is roughly equivalent to :meth:`Const.normalize`. It is private because + Amaranth does not currently have a concept of a constant initializer; this requires + an RFC. It will be renamed or removed in a future version.""" + if isinstance(value, Mapping): + iterator = value.items() + elif isinstance(value, Sequence): + iterator = enumerate(value) + else: + raise TypeError("Layout initializer must be a mapping or a sequence, not {!r}" + .format(value)) + + int_value = 0 + for key, key_value in iterator: + field = self[key] + if isinstance(field.shape, Layout): + key_value = field.shape._convert_to_int(key_value) + int_value |= Const.normalize(key_value, Shape.cast(field.shape)) << field.offset + return int_value + + +class StructLayout(Layout): + def __init__(self, members): + self.members = members + + @property + def members(self): + return {key: field.shape for key, field in self._fields.items()} + + @members.setter + def members(self, members): + offset = 0 + self._fields = {} + if not isinstance(members, Mapping): + raise TypeError("Struct layout members must be provided as a mapping, not {!r}" + .format(members)) + for key, shape in members.items(): + if not isinstance(key, str): + raise TypeError("Struct layout member name must be a string, not {!r}" + .format(key)) + try: + cast_shape = Shape.cast(shape) + except TypeError as e: + raise TypeError("Struct layout member shape must be a shape-castable object, " + "not {!r}" + .format(shape)) from e + self._fields[key] = Field(shape, offset) + offset += cast_shape.width + + def __iter__(self): + return iter(self._fields.items()) + + def __getitem__(self, key): + return self._fields[key] + + @property + def size(self): + return max((field.offset + field.width for field in self._fields.values()), default=0) + + def __repr__(self): + return f"StructLayout({self.members!r})" + + +class UnionLayout(Layout): + def __init__(self, members): + self.members = members + + @property + def members(self): + return {key: field.shape for key, field in self._fields.items()} + + @members.setter + def members(self, members): + self._fields = {} + if not isinstance(members, Mapping): + raise TypeError("Union layout members must be provided as a mapping, not {!r}" + .format(members)) + for key, shape in members.items(): + if not isinstance(key, str): + raise TypeError("Union layout member name must be a string, not {!r}" + .format(key)) + try: + cast_shape = Shape.cast(shape) + except TypeError as e: + raise TypeError("Union layout member shape must be a shape-castable object, " + "not {!r}" + .format(shape)) from e + self._fields[key] = Field(shape, 0) + + def __iter__(self): + return iter(self._fields.items()) + + def __getitem__(self, key): + return self._fields[key] + + @property + def size(self): + return max((field.width for field in self._fields.values()), default=0) + + def __repr__(self): + return f"UnionLayout({self.members!r})" + + +class ArrayLayout(Layout): + def __init__(self, elem_shape, length): + self.elem_shape = elem_shape + self.length = length + + @property + def elem_shape(self): + return self._elem_shape + + @elem_shape.setter + def elem_shape(self, elem_shape): + try: + Shape.cast(elem_shape) + except TypeError as e: + raise TypeError("Array layout element shape must be a shape-castable object, " + "not {!r}" + .format(elem_shape)) from e + self._elem_shape = elem_shape + + @property + def length(self): + return self._length + + @length.setter + def length(self, length): + if not isinstance(length, int) or length < 0: + raise TypeError("Array layout length must be a non-negative integer, not {!r}" + .format(length)) + self._length = length + + def __iter__(self): + offset = 0 + for index in range(self._length): + yield index, Field(self._elem_shape, offset) + offset += Shape.cast(self._elem_shape).width + + def __getitem__(self, key): + if isinstance(key, int): + if key not in range(-self._length, self._length): + # Layout's interface requires us to raise KeyError, not IndexError + raise KeyError(key) + if key < 0: + key += self._length + return Field(self._elem_shape, key * Shape.cast(self._elem_shape).width) + raise TypeError("Cannot index array layout with {!r}".format(key)) + + @property + def size(self): + return Shape.cast(self._elem_shape).width * self.length + + def __repr__(self): + return f"ArrayLayout({self._elem_shape!r}, {self.length})" + + +class FlexibleLayout(Layout): + def __init__(self, size, fields): + self.size = size + self.fields = fields + + @property + def size(self): + return self._size + + @size.setter + def size(self, size): + if not isinstance(size, int) or size < 0: + raise TypeError("Flexible layout size must be a non-negative integer, not {!r}" + .format(size)) + if hasattr(self, "_fields") and self._fields: + endmost_name, endmost_field = max(self._fields.items(), + key=lambda pair: pair[1].offset + pair[1].width) + if endmost_field.offset + endmost_field.width > size: + raise ValueError("Flexible layout size {} does not cover the field '{}', which " + "ends at bit {}" + .format(size, endmost_name, + endmost_field.offset + endmost_field.width)) + self._size = size + + @property + def fields(self): + return {**self._fields} + + @fields.setter + def fields(self, fields): + self._fields = {} + if not isinstance(fields, Mapping): + raise TypeError("Flexible layout fields must be provided as a mapping, not {!r}" + .format(fields)) + for key, field in fields.items(): + if not isinstance(key, (int, str)) or (isinstance(key, int) and key < 0): + raise TypeError("Flexible layout field name must be a non-negative integer or " + "a string, not {!r}" + .format(key)) + if not isinstance(field, Field): + raise TypeError("Flexible layout field value must be a Field instance, not {!r}" + .format(field)) + if field.offset + field.width > self._size: + raise ValueError("Flexible layout field '{}' ends at bit {}, exceeding " + "the size of {} bit(s)" + .format(key, field.offset + field.width, self._size)) + self._fields[key] = field + + def __iter__(self): + return iter(self._fields.items()) + + def __getitem__(self, key): + if isinstance(key, (int, str)): + return self._fields[key] + raise TypeError("Cannot index flexible layout with {!r}".format(key)) + + def __repr__(self): + return f"FlexibleLayout({self._size}, {self._fields!r})" + + +class View(ValueCastable): + def __init__(self, layout, target=None, *, name=None, reset=None, reset_less=None, + attrs=None, decoder=None, src_loc_at=0): + try: + cast_layout = Layout.cast(layout) + except TypeError as e: + raise TypeError("View layout must be a Layout instance, not {!r}" + .format(layout)) from e + if target is not None: + if (name is not None or reset is not None or reset_less is not None or + attrs is not None or decoder is not None): + raise ValueError("View target cannot be provided at the same time as any of " + "the Signal constructor arguments (name, reset, reset_less, " + "attrs, decoder)") + try: + cast_target = Value.cast(target) + except TypeError as e: + raise TypeError("View target must be a value-castable object, not {!r}" + .format(target)) from e + if len(cast_target) != cast_layout.size: + raise ValueError("View target is {} bit(s) wide, which is not compatible with " + "the {} bit(s) wide view layout" + .format(len(cast_target), cast_layout.size)) + else: + if reset is None: + reset = 0 + else: + reset = cast_layout._convert_to_int(reset) + if reset_less is None: + reset_less = False + cast_target = Signal(cast_layout, name=name, reset=reset, reset_less=reset_less, + attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1) + self.__orig_layout = layout + self.__layout = cast_layout + self.__target = cast_target + + @ValueCastable.lowermethod + def as_value(self): + return self.__target + + def eq(self, other): + return self.as_value().eq(other) + + def __getitem__(self, key): + if isinstance(self.__layout, ArrayLayout): + shape = self.__layout.elem_shape + value = self.__target.word_select(key, Shape.cast(self.__layout.elem_shape).width) + else: + if isinstance(key, (Value, ValueCastable)): + raise TypeError("Only views with array layout, not {!r}, may be indexed " + "with a value" + .format(self.__layout)) + field = self.__layout[key] + shape = field.shape + value = self.__target[field.offset:field.offset + field.width] + if isinstance(shape, _AggregateMeta): + return shape(value) + if isinstance(shape, Layout): + return View(shape, value) + if Shape.cast(shape).signed: + return value.as_signed() + else: + return value + + def __getattr__(self, name): + try: + item = self[name] + except KeyError: + raise AttributeError("View of {!r} does not have a field {!r}; " + "did you mean one of: {}?" + .format(self.__target, name, + ", ".join(repr(name) + for name, field in self.__layout))) + if name.startswith("_"): + raise AttributeError("View of {!r} field {!r} has a reserved name and may only be " + "accessed by indexing" + .format(self.__target, name)) + return item + + +class _AggregateMeta(ShapeCastable, type): + def __new__(metacls, name, bases, namespace, *, _layout_cls=None, **kwargs): + cls = type.__new__(metacls, name, bases, namespace, **kwargs) + if _layout_cls is not None: + cls.__layout_cls = _layout_cls + if "__annotations__" in namespace: + cls.__layout = cls.__layout_cls(namespace["__annotations__"]) + return cls + + def as_shape(cls): + return cls.__layout + + +class _Aggregate(View, metaclass=_AggregateMeta): + def __init__(self, target=None, *, name=None, reset=None, reset_less=None, + attrs=None, decoder=None, src_loc_at=0): + super().__init__(self.__class__, target, name=name, reset=reset, reset_less=reset_less, + attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1) + + +class Struct(_Aggregate, _layout_cls=StructLayout): + pass + + +class Union(_Aggregate, _layout_cls=UnionLayout): + pass diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py new file mode 100644 index 0000000..fbc5b1b --- /dev/null +++ b/tests/test_lib_data.py @@ -0,0 +1,760 @@ +from enum import Enum +from unittest import TestCase + +from amaranth.hdl import * +from amaranth.hdl.ast import ShapeCastable +from amaranth.lib.data import * +from amaranth.sim import Simulator + +from .utils import * + + +class MockShapeCastable(ShapeCastable): + def __init__(self, shape): + self.shape = shape + + def as_shape(self): + return self.shape + + +class FieldTestCase(TestCase): + def test_construct(self): + f = Field(unsigned(2), 1) + self.assertEqual(f.shape, unsigned(2)) + self.assertEqual(f.offset, 1) + self.assertEqual(f.width, 2) + + def test_repr(self): + f = Field(unsigned(2), 1) + self.assertEqual(repr(f), "Field(unsigned(2), 1)") + + def test_equal(self): + f1 = Field(unsigned(2), 1) + f2 = Field(unsigned(2), 0) + self.assertNotEqual(f1, f2) + f3 = Field(unsigned(2), 1) + self.assertEqual(f1, f3) + f4 = Field(2, 1) + self.assertEqual(f1, f4) + f5 = Field(MockShapeCastable(unsigned(2)), 1) + self.assertEqual(f1, f5) + self.assertNotEqual(f1, object()) + + def test_preserve_shape(self): + sc = MockShapeCastable(unsigned(2)) + f = Field(sc, 0) + self.assertEqual(f.shape, sc) + self.assertEqual(f.width, 2) + + def test_shape_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Field shape must be a shape-castable object, not <.+>$"): + Field(object(), 0) + + def test_offset_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Field offset must be a non-negative integer, not <.+>$"): + Field(unsigned(2), object()) + with self.assertRaisesRegex(TypeError, + r"^Field offset must be a non-negative integer, not -1$"): + Field(unsigned(2), -1) + + +class StructLayoutTestCase(TestCase): + def test_construct(self): + sl = StructLayout({ + "a": unsigned(1), + "b": 2 + }) + self.assertEqual(sl.members, { + "a": unsigned(1), + "b": 2 + }) + self.assertEqual(sl.size, 3) + self.assertEqual(list(iter(sl)), [ + ("a", Field(unsigned(1), 0)), + ("b", Field(2, 1)) + ]) + self.assertEqual(sl["a"], Field(unsigned(1), 0)) + self.assertEqual(sl["b"], Field(2, 1)) + + def test_size_empty(self): + self.assertEqual(StructLayout({}).size, 0) + + def test_eq(self): + self.assertEqual(StructLayout({"a": unsigned(1), "b": 2}), + StructLayout({"a": unsigned(1), "b": unsigned(2)})) + self.assertNotEqual(StructLayout({"a": unsigned(1), "b": 2}), + StructLayout({"b": unsigned(2), "a": unsigned(1)})) + self.assertNotEqual(StructLayout({"a": unsigned(1), "b": 2}), + StructLayout({"a": unsigned(1)})) + + def test_repr(self): + sl = StructLayout({ + "a": unsigned(1), + "b": 2 + }) + self.assertEqual(repr(sl), "StructLayout({'a': unsigned(1), 'b': 2})") + + def test_members_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Struct layout members must be provided as a mapping, not <.+>$"): + StructLayout(object()) + + def test_member_key_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Struct layout member name must be a string, not 1\.0$"): + StructLayout({1.0: unsigned(1)}) + + def test_member_value_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Struct layout member shape must be a shape-castable object, not 1\.0$"): + StructLayout({"a": 1.0}) + + +class UnionLayoutTestCase(TestCase): + def test_construct(self): + ul = UnionLayout({ + "a": unsigned(1), + "b": 2 + }) + self.assertEqual(ul.members, { + "a": unsigned(1), + "b": 2 + }) + self.assertEqual(ul.size, 2) + self.assertEqual(list(iter(ul)), [ + ("a", Field(unsigned(1), 0)), + ("b", Field(2, 0)) + ]) + self.assertEqual(ul["a"], Field(unsigned(1), 0)) + self.assertEqual(ul["b"], Field(2, 0)) + + def test_size_empty(self): + self.assertEqual(UnionLayout({}).size, 0) + + def test_eq(self): + self.assertEqual(UnionLayout({"a": unsigned(1), "b": 2}), + UnionLayout({"a": unsigned(1), "b": unsigned(2)})) + self.assertEqual(UnionLayout({"a": unsigned(1), "b": 2}), + UnionLayout({"b": unsigned(2), "a": unsigned(1)})) + self.assertNotEqual(UnionLayout({"a": unsigned(1), "b": 2}), + UnionLayout({"a": unsigned(1)})) + + def test_repr(self): + ul = UnionLayout({ + "a": unsigned(1), + "b": 2 + }) + self.assertEqual(repr(ul), "UnionLayout({'a': unsigned(1), 'b': 2})") + + def test_members_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Union layout members must be provided as a mapping, not <.+>$"): + UnionLayout(object()) + + def test_member_key_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Union layout member name must be a string, not 1\.0$"): + UnionLayout({1.0: unsigned(1)}) + + def test_member_value_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Union layout member shape must be a shape-castable object, not 1\.0$"): + UnionLayout({"a": 1.0}) + + +class ArrayLayoutTestCase(TestCase): + def test_construct(self): + al = ArrayLayout(unsigned(2), 3) + self.assertEqual(al.elem_shape, unsigned(2)) + self.assertEqual(al.length, 3) + self.assertEqual(list(iter(al)), [ + (0, Field(unsigned(2), 0)), + (1, Field(unsigned(2), 2)), + (2, Field(unsigned(2), 4)), + ]) + self.assertEqual(al[0], Field(unsigned(2), 0)) + self.assertEqual(al[1], Field(unsigned(2), 2)) + self.assertEqual(al[2], Field(unsigned(2), 4)) + self.assertEqual(al[-1], Field(unsigned(2), 4)) + self.assertEqual(al[-2], Field(unsigned(2), 2)) + self.assertEqual(al[-3], Field(unsigned(2), 0)) + self.assertEqual(al.size, 6) + + def test_shape_castable(self): + al = ArrayLayout(2, 3) + self.assertEqual(al.size, 6) + + def test_eq(self): + self.assertEqual(ArrayLayout(unsigned(2), 3), + ArrayLayout(unsigned(2), 3)) + self.assertNotEqual(ArrayLayout(unsigned(2), 3), + ArrayLayout(unsigned(2), 4)) + + def test_repr(self): + al = ArrayLayout(unsigned(2), 3) + self.assertEqual(repr(al), "ArrayLayout(unsigned(2), 3)") + + def test_elem_shape_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Array layout element shape must be a shape-castable object, not <.+>$"): + ArrayLayout(object(), 1) + + def test_length_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Array layout length must be a non-negative integer, not <.+>$"): + ArrayLayout(unsigned(1), object()) + with self.assertRaisesRegex(TypeError, + r"^Array layout length must be a non-negative integer, not -1$"): + ArrayLayout(unsigned(1), -1) + + def test_key_wrong_bounds(self): + al = ArrayLayout(unsigned(2), 3) + with self.assertRaisesRegex(KeyError, r"^4$"): + al[4] + with self.assertRaisesRegex(KeyError, r"^-4$"): + al[-4] + + def test_key_wrong_type(self): + al = ArrayLayout(unsigned(2), 3) + with self.assertRaisesRegex(TypeError, + r"^Cannot index array layout with 'a'$"): + al["a"] + + +class FlexibleLayoutTestCase(TestCase): + def test_construct(self): + il = FlexibleLayout(8, { + "a": Field(unsigned(1), 1), + "b": Field(unsigned(3), 0), + 0: Field(unsigned(2), 5) + }) + self.assertEqual(il.size, 8) + self.assertEqual(il.fields, { + "a": Field(unsigned(1), 1), + "b": Field(unsigned(3), 0), + 0: Field(unsigned(2), 5) + }) + self.assertEqual(list(iter(il)), [ + ("a", Field(unsigned(1), 1)), + ("b", Field(unsigned(3), 0)), + (0, Field(unsigned(2), 5)) + ]) + self.assertEqual(il["a"], Field(unsigned(1), 1)) + self.assertEqual(il["b"], Field(unsigned(3), 0)) + self.assertEqual(il[0], Field(unsigned(2), 5)) + + def test_eq(self): + self.assertEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0)}), + FlexibleLayout(3, {"a": Field(unsigned(1), 0)})) + self.assertNotEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0)}), + FlexibleLayout(4, {"a": Field(unsigned(1), 0)})) + self.assertNotEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0)}), + FlexibleLayout(3, {"a": Field(unsigned(1), 1)})) + + def test_eq_duck(self): + self.assertEqual(FlexibleLayout(3, {"a": Field(unsigned(1), 0), + "b": Field(unsigned(2), 1)}), + StructLayout({"a": unsigned(1), + "b": unsigned(2)})) + self.assertEqual(FlexibleLayout(2, {"a": Field(unsigned(1), 0), + "b": Field(unsigned(2), 0)}), + UnionLayout({"a": unsigned(1), + "b": unsigned(2)})) + + def test_repr(self): + il = FlexibleLayout(8, { + "a": Field(unsigned(1), 1), + "b": Field(unsigned(3), 0), + 0: Field(unsigned(2), 5) + }) + self.assertEqual(repr(il), "FlexibleLayout(8, {" + "'a': Field(unsigned(1), 1), " + "'b': Field(unsigned(3), 0), " + "0: Field(unsigned(2), 5)})") + + def test_fields_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Flexible layout fields must be provided as a mapping, not <.+>$"): + FlexibleLayout(8, object()) + + def test_field_key_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Flexible layout field name must be a non-negative integer or a string, " + r"not 1\.0$"): + FlexibleLayout(8, {1.0: unsigned(1)}) + with self.assertRaisesRegex(TypeError, + r"^Flexible layout field name must be a non-negative integer or a string, " + r"not -1$"): + FlexibleLayout(8, {-1: unsigned(1)}) + + def test_field_value_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Flexible layout field value must be a Field instance, not 1\.0$"): + FlexibleLayout(8, {"a": 1.0}) + + def test_size_wrong_negative(self): + with self.assertRaisesRegex(TypeError, + r"^Flexible layout size must be a non-negative integer, not -1$"): + FlexibleLayout(-1, {}) + + def test_size_wrong_small(self): + with self.assertRaisesRegex(ValueError, + r"^Flexible layout field 'a' ends at bit 8, exceeding the size of 4 bit\(s\)$"): + FlexibleLayout(4, {"a": Field(unsigned(8), 0)}) + with self.assertRaisesRegex(ValueError, + r"^Flexible layout field 'a' ends at bit 5, exceeding the size of 4 bit\(s\)$"): + FlexibleLayout(4, {"a": Field(unsigned(2), 3)}) + + def test_size_wrong_shrink(self): + il = FlexibleLayout(8, {"a": Field(unsigned(2), 3)}) + with self.assertRaisesRegex(ValueError, + r"^Flexible layout size 4 does not cover the field 'a', which ends at bit 5$"): + il.size = 4 + + def test_key_wrong_missing(self): + il = FlexibleLayout(8, {"a": Field(unsigned(2), 3)}) + with self.assertRaisesRegex(KeyError, + r"^0$"): + il[0] + + def test_key_wrong_type(self): + il = FlexibleLayout(8, {"a": Field(unsigned(2), 3)}) + with self.assertRaisesRegex(TypeError, + r"^Cannot index flexible layout with <.+>$"): + il[object()] + + +class LayoutTestCase(TestCase): + def test_cast(self): + sl = StructLayout({}) + self.assertIs(Layout.cast(sl), sl) + + def test_cast_wrong_not_layout(self): + with self.assertRaisesRegex(TypeError, + r"^Object unsigned\(1\) cannot be converted to a data layout$"): + Layout.cast(unsigned(1)) + + def test_cast_wrong_type(self): + with self.assertRaisesRegex(TypeError, + r"^Object <.+> cannot be converted to an Amaranth shape$"): + Layout.cast(object()) + + def test_cast_wrong_recur(self): + sc = MockShapeCastable(None) + sc.shape = sc + with self.assertRaisesRegex(RecursionError, + r"^Shape-castable object <.+> casts to itself$"): + Layout.cast(sc) + + def test_of_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Object <.+> is not a data view$"): + Layout.of(object()) + + def test_eq_wrong_recur(self): + sc = MockShapeCastable(None) + sc.shape = sc + self.assertNotEqual(StructLayout({}), sc) + + +class ViewTestCase(FHDLTestCase): + def test_construct(self): + s = Signal(3) + v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), s) + self.assertIs(Value.cast(v), s) + self.assertRepr(v["a"], "(slice (sig s) 0:1)") + self.assertRepr(v["b"], "(slice (sig s) 1:3)") + + def test_construct_signal(self): + v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)})) + cv = Value.cast(v) + self.assertIsInstance(cv, Signal) + self.assertEqual(cv.shape(), unsigned(3)) + self.assertEqual(cv.name, "v") + + def test_construct_signal_name(self): + v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), name="named") + self.assertEqual(Value.cast(v).name, "named") + + def test_construct_signal_reset(self): + v1 = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), + reset={"a": 0b1, "b": 0b10}) + self.assertEqual(Value.cast(v1).reset, 0b101) + v2 = View(StructLayout({"a": unsigned(1), + "b": StructLayout({"x": unsigned(1), "y": unsigned(1)})}), + reset={"a": 0b1, "b": {"x": 0b0, "y": 0b1}}) + self.assertEqual(Value.cast(v2).reset, 0b101) + v3 = View(ArrayLayout(unsigned(2), 2), + reset=[0b01, 0b10]) + self.assertEqual(Value.cast(v3).reset, 0b1001) + + def test_construct_signal_reset_less(self): + v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), reset_less=True) + self.assertEqual(Value.cast(v).reset_less, True) + + def test_construct_signal_attrs(self): + v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), attrs={"debug": 1}) + self.assertEqual(Value.cast(v).attrs, {"debug": 1}) + + def test_construct_signal_decoder(self): + decoder = lambda x: f"{x}" + v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), decoder=decoder) + self.assertEqual(Value.cast(v).decoder, decoder) + + def test_layout_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^View layout must be a Layout instance, not <.+?>$"): + View(object(), Signal(1)) + + def test_target_wrong_type(self): + with self.assertRaisesRegex(TypeError, + r"^View target must be a value-castable object, not <.+?>$"): + View(StructLayout({}), object()) + + def test_target_wrong_size(self): + with self.assertRaisesRegex(ValueError, + r"^View target is 2 bit\(s\) wide, which is not compatible with the 1 bit\(s\) " + r"wide view layout$"): + View(StructLayout({"a": unsigned(1)}), Signal(2)) + + def test_signal_reset_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Layout initializer must be a mapping or a sequence, not 1$"): + View(StructLayout({}), reset=0b1) + + def test_target_signal_wrong(self): + with self.assertRaisesRegex(ValueError, + r"^View target cannot be provided at the same time as any of the Signal " + r"constructor arguments \(name, reset, reset_less, attrs, decoder\)$"): + View(StructLayout({}), Signal(), reset=0b1) + + def test_getitem(self): + v = View(UnionLayout({ + "a": unsigned(2), + "s": StructLayout({ + "b": unsigned(1), + "c": unsigned(3) + }), + "p": 1, + "q": signed(1), + "r": ArrayLayout(unsigned(2), 2), + "t": ArrayLayout(StructLayout({ + "u": unsigned(1), + "v": unsigned(1) + }), 2), + })) + cv = Value.cast(v) + i = Signal(1) + self.assertEqual(cv.shape(), unsigned(4)) + self.assertRepr(v["a"], "(slice (sig v) 0:2)") + self.assertEqual(v["a"].shape(), unsigned(2)) + self.assertRepr(v["s"]["b"], "(slice (slice (sig v) 0:4) 0:1)") + self.assertRepr(v["s"]["c"], "(slice (slice (sig v) 0:4) 1:4)") + self.assertRepr(v["p"], "(slice (sig v) 0:1)") + self.assertEqual(v["p"].shape(), unsigned(1)) + self.assertRepr(v["q"], "(s (slice (sig v) 0:1))") + self.assertEqual(v["q"].shape(), signed(1)) + self.assertRepr(v["r"][0], "(slice (slice (sig v) 0:4) 0:2)") + self.assertRepr(v["r"][1], "(slice (slice (sig v) 0:4) 2:4)") + self.assertRepr(v["r"][i], "(part (slice (sig v) 0:4) (sig i) 2 2)") + self.assertRepr(v["t"][0]["u"], "(slice (slice (slice (sig v) 0:4) 0:2) 0:1)") + self.assertRepr(v["t"][1]["v"], "(slice (slice (slice (sig v) 0:4) 2:4) 1:2)") + + def test_index_wrong_missing(self): + with self.assertRaisesRegex(KeyError, + r"^'a'$"): + View(StructLayout({}))["a"] + + def test_index_wrong_struct_dynamic(self): + with self.assertRaisesRegex(TypeError, + r"^Only views with array layout, not StructLayout\(\{\}\), may be indexed " + r"with a value$"): + View(StructLayout({}))[Signal(1)] + + def test_getattr(self): + v = View(UnionLayout({ + "a": unsigned(2), + "s": StructLayout({ + "b": unsigned(1), + "c": unsigned(3) + }), + "p": 1, + "q": signed(1), + })) + cv = Value.cast(v) + i = Signal(1) + self.assertEqual(cv.shape(), unsigned(4)) + self.assertRepr(v.a, "(slice (sig v) 0:2)") + self.assertEqual(v.a.shape(), unsigned(2)) + self.assertRepr(v.s.b, "(slice (slice (sig v) 0:4) 0:1)") + self.assertRepr(v.s.c, "(slice (slice (sig v) 0:4) 1:4)") + self.assertRepr(v.p, "(slice (sig v) 0:1)") + self.assertEqual(v.p.shape(), unsigned(1)) + self.assertRepr(v.q, "(s (slice (sig v) 0:1))") + self.assertEqual(v.q.shape(), signed(1)) + + def test_getattr_reserved(self): + v = View(UnionLayout({ + "_a": unsigned(2) + })) + self.assertRepr(v["_a"], "(slice (sig v) 0:2)") + + def test_attr_wrong_missing(self): + with self.assertRaisesRegex(AttributeError, + r"^View of \(sig \$signal\) does not have a field 'a'; " + r"did you mean one of: 'b', 'c'\?$"): + View(StructLayout({"b": unsigned(1), "c": signed(1)})).a + + def test_attr_wrong_reserved(self): + with self.assertRaisesRegex(AttributeError, + r"^View of \(sig \$signal\) field '_c' has a reserved name " + r"and may only be accessed by indexing$"): + View(StructLayout({"_c": signed(1)}))._c + + +class StructTestCase(FHDLTestCase): + def test_construct(self): + class S(Struct): + a: unsigned(1) + b: signed(3) + + self.assertEqual(Shape.cast(S), unsigned(4)) + self.assertEqual(Layout.cast(S), StructLayout({ + "a": unsigned(1), + "b": signed(3) + })) + + v = S() + self.assertEqual(Layout.of(v), S) + self.assertEqual(Value.cast(v).shape(), S) + self.assertEqual(Value.cast(v).name, "v") + self.assertRepr(v.a, "(slice (sig v) 0:1)") + self.assertRepr(v.b, "(s (slice (sig v) 1:4))") + + def test_construct_nested(self): + Q = StructLayout({"r": signed(2), "s": signed(2)}) + + class R(Struct): + p: 4 + q: Q + + class S(Struct): + a: unsigned(1) + b: R + + self.assertEqual(S, unsigned(9)) + + v = S() + self.assertIs(Layout.of(v), S) + self.assertIsInstance(v, S) + self.assertIs(Layout.of(v.b), R) + self.assertIsInstance(v.b, R) + self.assertIs(Layout.of(v.b.q), Q) + self.assertIsInstance(v.b.q, View) + self.assertRepr(v.b.p, "(slice (slice (sig v) 1:9) 0:4)") + self.assertRepr(v.b.q.as_value(), "(slice (slice (sig v) 1:9) 4:8)") + self.assertRepr(v.b.q.r, "(s (slice (slice (slice (sig v) 1:9) 4:8) 0:2))") + self.assertRepr(v.b.q.s, "(s (slice (slice (slice (sig v) 1:9) 4:8) 2:4))") + + def test_construct_signal_kwargs(self): + decoder = lambda x: f"{x}" + v = View(StructLayout({"a": unsigned(1), "b": unsigned(2)}), + name="named", reset={"b": 0b1}, reset_less=True, attrs={"debug": 1}, decoder=decoder) + s = Value.cast(v) + self.assertEqual(s.name, "named") + self.assertEqual(s.reset, 0b010) + self.assertEqual(s.reset_less, True) + self.assertEqual(s.attrs, {"debug": 1}) + self.assertEqual(s.decoder, decoder) + + +class UnionTestCase(FHDLTestCase): + def test_construct(self): + class U(Union): + a: unsigned(1) + b: signed(3) + + self.assertEqual(Shape.cast(U), unsigned(3)) + self.assertEqual(Layout.cast(U), UnionLayout({ + "a": unsigned(1), + "b": signed(3) + })) + + v = U() + self.assertEqual(Layout.of(v), U) + self.assertEqual(Value.cast(v).shape(), U) + self.assertRepr(v.a, "(slice (sig v) 0:1)") + self.assertRepr(v.b, "(s (slice (sig v) 0:3))") + + def test_construct_signal_kwargs(self): + decoder = lambda x: f"{x}" + v = View(UnionLayout({"a": unsigned(1), "b": unsigned(2)}), + name="named", reset={"b": 0b1}, reset_less=True, attrs={"debug": 1}, decoder=decoder) + s = Value.cast(v) + self.assertEqual(s.name, "named") + self.assertEqual(s.reset, 0b01) + self.assertEqual(s.reset_less, True) + self.assertEqual(s.attrs, {"debug": 1}) + self.assertEqual(s.decoder, decoder) + + +# Examples from https://github.com/amaranth-lang/amaranth/issues/693 +class RFCExamplesTestCase(TestCase): + @staticmethod + def simulate(m): + def wrapper(fn): + sim = Simulator(m) + sim.add_process(fn) + sim.run() + return wrapper + + def test_rfc_example_1(self): + class Float32(Struct): + fraction: unsigned(23) + exponent: unsigned(8) + sign: unsigned(1) + + self.assertEqual(Float32.as_shape().size, 32) + + flt_a = Float32() + flt_b = Float32(Const(0b00111110001000000000000000000000, 32)) + + m1 = Module() + with m1.If(flt_b.fraction > 0): + m1.d.comb += [ + flt_a.sign.eq(1), + flt_a.exponent.eq(127) + ] + + @self.simulate(m1) + def check_m1(): + self.assertEqual((yield flt_a.as_value()), 0xbf800000) + + class FloatOrInt32(Union): + float: Float32 + int: signed(32) + + f_or_i = FloatOrInt32() + is_gt_1 = Signal() + m2 = Module() + m2.d.comb += [ + f_or_i.int.eq(0x41C80000), + is_gt_1.eq(f_or_i.float.exponent >= 127) # => 1 + ] + + @self.simulate(m2) + def check_m2(): + self.assertEqual((yield is_gt_1), 1) + + class Op(Enum): + ADD = 0 + SUB = 1 + + adder_op_layout = StructLayout({ + "op": Op, + "a": Float32, + "b": Float32 + }) + + adder_op_storage = Signal(adder_op_layout) + self.assertEqual(len(adder_op_storage), 65) + + adder_op = View(adder_op_layout, adder_op_storage) + m3 = Module() + m3.d.comb += [ + adder_op.eq(Op.SUB), + adder_op.a.eq(flt_a), + adder_op.b.eq(flt_b) + ] + + @self.simulate(m3) + def check_m3(): + self.assertEqual((yield adder_op.as_value()), 0x7c40000000000001) + + def test_rfc_example_2(self): + class Kind(Enum): + ONE_SIGNED = 0 + TWO_UNSIGNED = 1 + + layout1 = StructLayout({ + "kind": Kind, + "value": UnionLayout({ + "one_signed": signed(2), + "two_unsigned": ArrayLayout(unsigned(1), 2) + }) + }) + self.assertEqual(layout1.size, 3) + + sig1 = Signal(layout1) + self.assertEqual(sig1.shape(), unsigned(3)) + + view1 = View(layout1, sig1) + self.assertIs(Value.cast(view1), sig1) + + view2 = View(layout1) + self.assertIsInstance(Value.cast(view2), Signal) + self.assertEqual(Value.cast(view2).shape(), unsigned(3)) + + m1 = Module() + m1.d.comb += [ + view1.kind.eq(Kind.TWO_UNSIGNED), + view1.value.two_unsigned[0].eq(1), + ] + + @self.simulate(m1) + def check_m1(): + self.assertEqual((yield view1.as_value()), 0b011) + + class SomeVariant(Struct): + class Value(Union): + one_signed: signed(2) + two_unsigned: ArrayLayout(unsigned(1), 2) + + kind: Kind + value: Value + + self.assertEqual(SomeVariant, unsigned(3)) + + view3 = SomeVariant() + self.assertIsInstance(Value.cast(view3), Signal) + self.assertEqual(Value.cast(view3).shape(), unsigned(3)) + + m2 = Module() + m2.submodules += m1 + m2.d.comb += [ + view3.kind.eq(Kind.ONE_SIGNED), + view3.value.eq(view1.value) + ] + + @self.simulate(m2) + def check_m2(): + self.assertEqual((yield view3.as_value()), 0b010) + + sig2 = Signal(SomeVariant) + self.assertEqual(sig2.shape(), unsigned(3)) + + layout2 = StructLayout({ + "ready": unsigned(1), + "payload": SomeVariant + }) + self.assertEqual(layout2.size, 4) + + self.assertEqual(layout1, Layout.cast(SomeVariant)) + + self.assertIs(SomeVariant, Layout.of(view3)) + + def test_rfc_example_3(self): + class Stream8b10b(View): + data: Signal + ctrl: Signal + + def __init__(self, value=None, *, width: int): + super().__init__(StructLayout({ + "data": unsigned(8 * width), + "ctrl": unsigned(width) + }), value) + + self.assertEqual(len(Stream8b10b(width=1).data), 8) + self.assertEqual(len(Stream8b10b(width=4).data), 32)