lib.data: implement equality for View, reject all other operators.

This commit is contained in:
Wanda 2023-11-25 04:15:45 +01:00 committed by Catherine
parent 4bfe2cde6f
commit c6000b1097
2 changed files with 106 additions and 1 deletions

View file

@ -580,6 +580,10 @@ class View(ValueCastable):
an Amaranth value instead of a constant integer. The returned element is chosen dynamically an Amaranth value instead of a constant integer. The returned element is chosen dynamically
in that case. in that case.
A view can only be compared for equality with another view of the same layout,
returning a single-bit value. No other operators are supported on views. If required,
a view can be converted back to its underlying value via :meth:`as_value`.
Custom view classes Custom view classes
################### ###################
@ -609,7 +613,7 @@ class View(ValueCastable):
warnings.warn("View layout includes a field {!r} that will be shadowed by " warnings.warn("View layout includes a field {!r} that will be shadowed by "
"the view attribute '{}.{}.{}'" "the view attribute '{}.{}.{}'"
.format(name, type(self).__module__, type(self).__qualname__, name), .format(name, type(self).__module__, type(self).__qualname__, name),
SyntaxWarning, stacklevel=1) SyntaxWarning, stacklevel=2)
self.__orig_layout = layout self.__orig_layout = layout
self.__layout = cast_layout self.__layout = cast_layout
self.__target = cast_target self.__target = cast_target
@ -732,6 +736,49 @@ class View(ValueCastable):
.format(self.__target, name)) .format(self.__target, name))
return item return item
def __eq__(self, other):
if not isinstance(other, View) or self.__layout != other.__layout:
raise TypeError(f"View of {self.__layout!r} can only be compared to another view of the same layout, not {other!r}")
return self.__target == other.__target
def __ne__(self, other):
if not isinstance(other, View) or self.__layout != other.__layout:
raise TypeError(f"View of {self.__layout!r} can only be compared to another view of the same layout, not {other!r}")
return self.__target != other.__target
def __add__(self, other):
raise TypeError("Cannot perform arithmetic operations on a View")
__radd__ = __add__
__sub__ = __add__
__rsub__ = __add__
__mul__ = __add__
__rmul__ = __add__
__floordiv__ = __add__
__rfloordiv__ = __add__
__mod__ = __add__
__rmod__ = __add__
__lshift__ = __add__
__rlshift__ = __add__
__rshift__ = __add__
__rrshift__ = __add__
__lt__ = __add__
__le__ = __add__
__gt__ = __add__
__ge__ = __add__
def __and__(self, other):
raise TypeError("Cannot perform bitwise operations on a View")
__rand__ = __and__
__or__ = __and__
__ror__ = __and__
__xor__ = __and__
__rxor__ = __and__
def __repr__(self):
return f"{self.__class__.__name__}({self.__layout!r}, {self.__target!r})"
class _AggregateMeta(ShapeCastable, type): class _AggregateMeta(ShapeCastable, type):
def __new__(metacls, name, bases, namespace): def __new__(metacls, name, bases, namespace):

View file

@ -1,4 +1,5 @@
from enum import Enum from enum import Enum
import operator
from unittest import TestCase from unittest import TestCase
from amaranth.hdl import * from amaranth.hdl import *
@ -632,6 +633,63 @@ class ViewTestCase(FHDLTestCase):
r"^View of \(sig \$signal\) with an array layout does not have fields$"): r"^View of \(sig \$signal\) with an array layout does not have fields$"):
Signal(ArrayLayout(unsigned(1), 1), reset=[0]).reset Signal(ArrayLayout(unsigned(1), 1), reset=[0]).reset
def test_eq(self):
s1 = Signal(StructLayout({"a": unsigned(2)}))
s2 = Signal(StructLayout({"a": unsigned(2)}))
s3 = Signal(StructLayout({"a": unsigned(1), "b": unsigned(1)}))
self.assertRepr(s1 == s2, "(== (sig s1) (sig s2))")
self.assertRepr(s1 != s2, "(!= (sig s1) (sig s2))")
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
s1 == s3
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
s1 != s3
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
s1 == Const(0, 2)
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
s1 != Const(0, 2)
def test_operator(self):
s1 = Signal(StructLayout({"a": unsigned(2)}))
s2 = Signal(unsigned(2))
for op in [
operator.__add__,
operator.__sub__,
operator.__mul__,
operator.__floordiv__,
operator.__mod__,
operator.__lshift__,
operator.__rshift__,
operator.__lt__,
operator.__le__,
operator.__gt__,
operator.__ge__,
]:
with self.assertRaisesRegex(TypeError,
r"^Cannot perform arithmetic operations on a View$"):
op(s1, s2)
with self.assertRaisesRegex(TypeError,
r"^Cannot perform arithmetic operations on a View$"):
op(s2, s1)
for op in [
operator.__and__,
operator.__or__,
operator.__xor__,
]:
with self.assertRaisesRegex(TypeError,
r"^Cannot perform bitwise operations on a View$"):
op(s1, s2)
with self.assertRaisesRegex(TypeError,
r"^Cannot perform bitwise operations on a View$"):
op(s2, s1)
def test_repr(self):
s1 = Signal(StructLayout({"a": unsigned(2)}))
self.assertRepr(s1, "View(StructLayout({'a': unsigned(2)}), (sig s1))")
class StructTestCase(FHDLTestCase): class StructTestCase(FHDLTestCase):
def test_construct(self): def test_construct(self):