From 04f906965a9ebc9173951c96fa0c8b67ef090a79 Mon Sep 17 00:00:00 2001 From: Catherine Date: Mon, 27 Nov 2023 16:47:34 +0000 Subject: [PATCH] lib.wiring: in `is_compliant(sig, obj)`, check that `obj` is an interface object with that signature. Fixes #935. --- amaranth/lib/wiring.py | 15 ++++++++++++ tests/test_lib_wiring.py | 50 ++++++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/amaranth/lib/wiring.py b/amaranth/lib/wiring.py index bb8b4cb..624d98a 100644 --- a/amaranth/lib/wiring.py +++ b/amaranth/lib/wiring.py @@ -405,6 +405,21 @@ class Signature(metaclass=SignatureMeta): yield from iter_dimensions(value, dimensions=member.dimensions, path=path) def is_compliant(self, obj, *, reasons=None, path=("obj",)): + if not hasattr(obj, "signature"): + if reasons is not None: + reasons.append(f"{_format_path(path)} does not have an attribute 'signature'") + return False + if not isinstance(obj.signature, Signature): + if reasons is not None: + reasons.append(f"{_format_path(path + ('signature',))} is expected to be " + f"a signature, but it is a {obj.signature!r}") + return False + if self != obj.signature: + if reasons is not None: + reasons.append(f"{_format_path(path + ('signature',))} is expected to be equal " + f"to this signature, {self!r}, but it is a {obj.signature!r}") + return False + def check_attr_value(member, attr_value, *, path): if member.is_port: try: diff --git a/tests/test_lib_wiring.py b/tests/test_lib_wiring.py index c08f5dd..2ee96c2 100644 --- a/tests/test_lib_wiring.py +++ b/tests/test_lib_wiring.py @@ -416,7 +416,32 @@ class SignatureTestCase(unittest.TestCase): (("d", "s"), In (1), intf.d.s), ]) + def test_is_compliant_signature(self): + sig = Signature({}) + + obj1 = NS() + self.assertFalse(sig.is_compliant(obj1)) + reasons = [] + self.assertFalse(sig.is_compliant(obj1, reasons=reasons)) + self.assertEqual(reasons, ["'obj' does not have an attribute 'signature'"]) + + obj = NS(signature=1) + self.assertFalse(sig.is_compliant(obj)) + reasons = [] + self.assertFalse(sig.is_compliant(obj, reasons=reasons)) + self.assertEqual(reasons, ["'obj.signature' is expected to be a signature, but it is a 1"]) + + obj = NS(signature=Signature({"a": In(1)})) + self.assertFalse(sig.is_compliant(obj)) + reasons = [] + self.assertFalse(sig.is_compliant(obj, reasons=reasons)) + self.assertEqual(reasons, [ + "'obj.signature' is expected to be equal to this signature, " + "Signature({}), but it is a Signature({'a': In(1)})" + ]) + def assertNotCompliant(self, reason_regex, sig, obj): + obj.signature = sig self.assertFalse(sig.is_compliant(obj)) reasons = [] self.assertFalse(sig.is_compliant(obj, reasons=reasons)) @@ -474,25 +499,32 @@ class SignatureTestCase(unittest.TestCase): self.assertNotCompliant( r"^'obj\.a' does not have an attribute 'b'$", sig=Signature({"a": Out(Signature({"b": In(1)}))}), - obj=NS(a=Signal())) + obj=NS(a=NS(signature=Signature({"b": In(1)})))) self.assertTrue( Signature({"a": In(1)}).is_compliant( - NS(a=Signal()))) + NS(signature=Signature({"a": In(1)}), + a=Signal()))) self.assertTrue( Signature({"a": In(1)}).is_compliant( - NS(a=Const(1)))) + NS(signature=Signature({"a": In(1)}), + a=Const(1)))) self.assertTrue( # list Signature({"a": In(1).array(2, 2)}).is_compliant( - NS(a=[[Const(1), Const(1)], [Signal(), Signal()]]))) + NS(signature=Signature({"a": In(1).array(2, 2)}), + a=[[Const(1), Const(1)], [Signal(), Signal()]]))) self.assertTrue( # tuple Signature({"a": In(1).array(2, 2)}).is_compliant( - NS(a=((Const(1), Const(1)), (Signal(), Signal()))))) + NS(signature=Signature({"a": In(1).array(2, 2)}), + a=((Const(1), Const(1)), (Signal(), Signal()))))) self.assertTrue( # mixed list and tuple Signature({"a": In(1).array(2, 2)}).is_compliant( - NS(a=[[Const(1), Const(1)], (Signal(), Signal())]))) + NS(signature=Signature({"a": In(1).array(2, 2)}), + a=[[Const(1), Const(1)], (Signal(), Signal())]))) self.assertTrue( Signature({"a": Out(Signature({"b": In(1)}))}).is_compliant( - NS(a=NS(b=Signal())))) + NS(signature=Signature({"a": Out(Signature({"b": In(1)}))}), + a=NS(signature=Signature({"b": In(1)}), + b=Signal())))) def test_repr(self): sig = Signature({"a": In(1)}) @@ -933,9 +965,9 @@ class ConnectTestCase(unittest.TestCase): m = Module() connect(m, p=NS(signature=Signature({"a": Out(Signature({"f": Out(1)}))}), - a=NS(f=Signal(name='p__a'))), + a=NS(signature=Signature({"f": Out(1)}), f=Signal(name='p__a'))), q=NS(signature=Signature({"a": In(Signature({"f": Out(1)}))}), - a=NS(f=Signal(name='q__a')))) + a=NS(signature=Signature({"f": Out(1)}).flip(), f=Signal(name='q__a')))) self.assertEqual([repr(stmt) for stmt in m._statements], [ '(eq (sig q__a) (sig p__a))' ])