diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index 9bf873b..6c5c0c4 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -133,13 +133,8 @@ class Shape: return "unsigned({})".format(self.width) def __eq__(self, other): - if not isinstance(other, Shape): - try: - other = self.__class__.cast(other) - except TypeError as e: - raise TypeError("Shapes may be compared with shape-castable objects, not {!r}" - .format(other)) from e - return self.width == other.width and self.signed == other.signed + return (isinstance(other, Shape) and + self.width == other.width and self.signed == other.signed) def unsigned(width): diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index 275512e..f58a1a5 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -38,7 +38,8 @@ class Field: def __eq__(self, other): return (isinstance(other, Field) and - self._shape == other.shape and self._offset == other.offset) + Shape.cast(self._shape) == Shape.cast(other.shape) and + self._offset == other.offset) def __repr__(self): return f"Field({self._shape!r}, {self._offset})" diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index a794f28..215bac2 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -47,15 +47,8 @@ class ShapeTestCase(FHDLTestCase): r"^Width must be a non-negative integer, not -1$"): Shape(-1) - def test_compare_wrong(self): - with self.assertRaisesRegex(TypeError, - r"^Shapes may be compared with shape-castable objects, not 'hi'$"): - Shape(1, True) == 'hi' - - def test_compare_tuple_wrong(self): - with self.assertRaisesRegex(TypeError, - r"^Shapes may be compared with shape-castable objects, not \(2, 3\)$"): - Shape(1, True) == (2, 3) + def test_compare_non_shape(self): + self.assertNotEqual(Shape(1, True), "hi") def test_repr(self): self.assertEqual(repr(Shape()), "unsigned(1)") diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index b8e71bd..9e32ac2 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -534,7 +534,7 @@ class StructTestCase(FHDLTestCase): v = S() self.assertEqual(Layout.of(v), S) - self.assertEqual(Value.cast(v).shape(), S) + self.assertEqual(Value.cast(v).shape(), Shape.cast(S)) self.assertEqual(Value.cast(v).name, "v") self.assertRepr(v.a, "(slice (sig v) 0:1)") self.assertRepr(v.b, "(s (slice (sig v) 1:4))") @@ -550,7 +550,7 @@ class StructTestCase(FHDLTestCase): a: unsigned(1) b: R - self.assertEqual(S, unsigned(9)) + self.assertEqual(Shape.cast(S), unsigned(9)) v = S() self.assertIs(Layout.of(v), S) @@ -654,7 +654,7 @@ class UnionTestCase(FHDLTestCase): v = U() self.assertEqual(Layout.of(v), U) - self.assertEqual(Value.cast(v).shape(), U) + self.assertEqual(Value.cast(v).shape(), Shape.cast(U)) self.assertRepr(v.a, "(slice (sig v) 0:1)") self.assertRepr(v.b, "(s (slice (sig v) 0:3))") @@ -803,7 +803,7 @@ class RFCExamplesTestCase(TestCase): kind: Kind value: Value - self.assertEqual(SomeVariant, unsigned(3)) + self.assertEqual(Shape.cast(SomeVariant), unsigned(3)) view3 = SomeVariant() self.assertIsInstance(Value.cast(view3), Signal)