diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index 151696e..601fdcd 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -580,6 +580,10 @@ class View(ValueCastable): an Amaranth value instead of a constant integer. The returned element is chosen dynamically in that case. + A view can only be compared for equality with another view of the same layout, + returning a single-bit value. No other operators are supported on views. If required, + a view can be converted back to its underlying value via :meth:`as_value`. + Custom view classes ################### @@ -609,7 +613,7 @@ class View(ValueCastable): warnings.warn("View layout includes a field {!r} that will be shadowed by " "the view attribute '{}.{}.{}'" .format(name, type(self).__module__, type(self).__qualname__, name), - SyntaxWarning, stacklevel=1) + SyntaxWarning, stacklevel=2) self.__orig_layout = layout self.__layout = cast_layout self.__target = cast_target @@ -732,6 +736,49 @@ class View(ValueCastable): .format(self.__target, name)) return item + def __eq__(self, other): + if not isinstance(other, View) or self.__layout != other.__layout: + raise TypeError(f"View of {self.__layout!r} can only be compared to another view of the same layout, not {other!r}") + return self.__target == other.__target + + def __ne__(self, other): + if not isinstance(other, View) or self.__layout != other.__layout: + raise TypeError(f"View of {self.__layout!r} can only be compared to another view of the same layout, not {other!r}") + return self.__target != other.__target + + def __add__(self, other): + raise TypeError("Cannot perform arithmetic operations on a View") + + __radd__ = __add__ + __sub__ = __add__ + __rsub__ = __add__ + __mul__ = __add__ + __rmul__ = __add__ + __floordiv__ = __add__ + __rfloordiv__ = __add__ + __mod__ = __add__ + __rmod__ = __add__ + __lshift__ = __add__ + __rlshift__ = __add__ + __rshift__ = __add__ + __rrshift__ = __add__ + __lt__ = __add__ + __le__ = __add__ + __gt__ = __add__ + __ge__ = __add__ + + def __and__(self, other): + raise TypeError("Cannot perform bitwise operations on a View") + + __rand__ = __and__ + __or__ = __and__ + __ror__ = __and__ + __xor__ = __and__ + __rxor__ = __and__ + + def __repr__(self): + return f"{self.__class__.__name__}({self.__layout!r}, {self.__target!r})" + class _AggregateMeta(ShapeCastable, type): def __new__(metacls, name, bases, namespace): diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index 2aad149..19b88dc 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -1,4 +1,5 @@ from enum import Enum +import operator from unittest import TestCase from amaranth.hdl import * @@ -632,6 +633,63 @@ class ViewTestCase(FHDLTestCase): r"^View of \(sig \$signal\) with an array layout does not have fields$"): Signal(ArrayLayout(unsigned(1), 1), reset=[0]).reset + def test_eq(self): + s1 = Signal(StructLayout({"a": unsigned(2)})) + s2 = Signal(StructLayout({"a": unsigned(2)})) + s3 = Signal(StructLayout({"a": unsigned(1), "b": unsigned(1)})) + self.assertRepr(s1 == s2, "(== (sig s1) (sig s2))") + self.assertRepr(s1 != s2, "(!= (sig s1) (sig s2))") + with self.assertRaisesRegex(TypeError, + r"^View of .* can only be compared to another view of the same layout, not .*$"): + s1 == s3 + with self.assertRaisesRegex(TypeError, + r"^View of .* can only be compared to another view of the same layout, not .*$"): + s1 != s3 + with self.assertRaisesRegex(TypeError, + r"^View of .* can only be compared to another view of the same layout, not .*$"): + s1 == Const(0, 2) + with self.assertRaisesRegex(TypeError, + r"^View of .* can only be compared to another view of the same layout, not .*$"): + s1 != Const(0, 2) + + def test_operator(self): + s1 = Signal(StructLayout({"a": unsigned(2)})) + s2 = Signal(unsigned(2)) + for op in [ + operator.__add__, + operator.__sub__, + operator.__mul__, + operator.__floordiv__, + operator.__mod__, + operator.__lshift__, + operator.__rshift__, + operator.__lt__, + operator.__le__, + operator.__gt__, + operator.__ge__, + ]: + with self.assertRaisesRegex(TypeError, + r"^Cannot perform arithmetic operations on a View$"): + op(s1, s2) + with self.assertRaisesRegex(TypeError, + r"^Cannot perform arithmetic operations on a View$"): + op(s2, s1) + for op in [ + operator.__and__, + operator.__or__, + operator.__xor__, + ]: + with self.assertRaisesRegex(TypeError, + r"^Cannot perform bitwise operations on a View$"): + op(s1, s2) + with self.assertRaisesRegex(TypeError, + r"^Cannot perform bitwise operations on a View$"): + op(s2, s1) + + def test_repr(self): + s1 = Signal(StructLayout({"a": unsigned(2)})) + self.assertRepr(s1, "View(StructLayout({'a': unsigned(2)}), (sig s1))") + class StructTestCase(FHDLTestCase): def test_construct(self):