From 4ffadff20dfa2438eb0ea2205782655270309d9e Mon Sep 17 00:00:00 2001 From: Catherine Date: Wed, 28 Jun 2023 14:56:53 +0000 Subject: [PATCH] lib.wiring: implement amaranth-lang/rfcs#2. Co-authored-by: Charlotte --- amaranth/lib/wiring.py | 783 ++++++++++++++++++++++++++++++++++++ tests/test_lib_wiring.py | 834 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 1617 insertions(+) create mode 100644 amaranth/lib/wiring.py create mode 100644 tests/test_lib_wiring.py diff --git a/amaranth/lib/wiring.py b/amaranth/lib/wiring.py new file mode 100644 index 0000000..182bb06 --- /dev/null +++ b/amaranth/lib/wiring.py @@ -0,0 +1,783 @@ +from collections.abc import Mapping +import enum +import types +import inspect +import re +import warnings + +from ..hdl.ast import Shape, ShapeCastable, Const, Signal, Value, ValueCastable +from ..hdl.ir import Elaboratable +from .._utils import final + + +__all__ = ["In", "Out", "Signature", "connect", "flipped", "Component"] + + +class Flow(enum.Enum): + Out = 0 + In = 1 + + def flip(self): + if self == Out: + return In + if self == In: + return Out + assert False # :nocov: + + def __call__(self, description, *, reset=None): + return Member(self, description, reset=reset) + + def __repr__(self): + return self.name + + def __str__(self): + return self.name + + +In = Flow.In +Out = Flow.Out + + +@final +class Member: + def __init__(self, flow, description, *, reset=None, _dimensions=()): + self._flow = flow + self._description = description + self._reset = reset + self._dimensions = _dimensions + + # Check that the description is valid, and populate derived properties. + if self.is_port: + # Cast the description to a shape for typechecking, but keep the original + # shape-castable so that it can be provided + try: + shape = Shape.cast(self._description) + except TypeError as e: + raise TypeError(f"Port member description must be a shape-castable object or " + f"a signature, not {description!r}") from e + # This mirrors the logic that handles Signal(reset=). + # TODO: We need a simpler way to check for "is this a valid constant initializer" + if issubclass(type(self._description), ShapeCastable): + try: + self._reset_as_const = Const.cast(self._description.const(self._reset)) + except Exception as e: + raise TypeError(f"Port member reset value {self._reset!r} is not a valid " + f"constant initializer for {self._description}") from e + else: + try: + self._reset_as_const = Const.cast(reset or 0) + except TypeError: + raise TypeError(f"Port member reset value {self._reset!r} is not a valid " + f"constant initializer for {shape}") + if self.is_signature: + if self._reset is not None: + raise ValueError(f"A signature member cannot have a reset value") + + def flip(self): + return Member(self._flow.flip(), self._description, reset=self._reset, + _dimensions=self._dimensions) + + def array(self, *dimensions): + for dimension in dimensions: + if not (isinstance(dimension, int) and dimension >= 0): + raise TypeError(f"Member array dimensions must be non-negative integers, " + f"not {dimension!r}") + return Member(self._flow, self._description, reset=self._reset, + _dimensions=(*dimensions, *self._dimensions)) + + @property + def flow(self): + return self._flow + + @property + def is_port(self): + return not isinstance(self._description, Signature) + + @property + def is_signature(self): + return isinstance(self._description, Signature) + + @property + def shape(self): + if self.is_signature: + raise AttributeError(f"A signature member does not have a shape") + return self._description + + @property + def reset(self): + if self.is_signature: + raise AttributeError(f"A signature member does not have a reset value") + return self._reset + + @property + def signature(self): + if self.is_port: + raise AttributeError(f"A port member does not have a signature") + if self.flow == Out: + return self._description + if self.flow == In: + return self._description.flip() + assert False # :nocov: + + @property + def dimensions(self): + return self._dimensions + + def __eq__(self, other): + return (type(other) is Member and + self._flow == other._flow and + self._description == other._description and + self._reset == other._reset and + self._dimensions == other._dimensions) + + def __repr__(self): + reset_repr = dimensions_repr = "" + if self._reset: + reset_repr = f", reset={self._reset!r}" + if self._dimensions: + dimensions_repr = f".array({', '.join(map(str, self._dimensions))})" + return f"{self._flow!r}({self._description!r}{reset_repr}){dimensions_repr}" + + +@final +class SignatureError(Exception): + pass + + +# Inherits from Mapping and not MutableMapping because it's only mutable in a very limited way +# and most of the methods (except for `update`) added by MutableMapping are useless. +@final +class SignatureMembers(Mapping): + def __init__(self, members=()): + self._dict = dict() + self._frozen = False + self += members + + def flip(self): + return FlippedSignatureMembers(self) + + def __eq__(self, other): + return (isinstance(other, (SignatureMembers, FlippedSignatureMembers)) and + list(self.flatten()) == list(other.flatten())) + + def __contains__(self, name): + return name in self._dict + + def _check_name(self, name): + if not isinstance(name, str): + raise TypeError(f"Member name must be a string, not {name!r}") + if not re.match(r"^[A-Za-z][0-9A-Za-z_]*$", name): + raise NameError(f"Member name '{name}' must be a valid, public Python attribute name") + if name == "signature": + raise NameError(f"Member name cannot be '{name}'") + + def __getitem__(self, name): + self._check_name(name) + if name not in self._dict: + raise SignatureError(f"Member '{name}' is not a part of the signature") + return self._dict[name] + + def __setitem__(self, name, member): + self._check_name(name) + if name in self._dict: + raise SignatureError(f"Member '{name}' already exists in the signature and cannot " + f"be replaced") + if type(member) is not Member: + raise TypeError(f"Assigned value {member!r} must be a member; " + f"did you mean In({member!r}) or Out({member!r})?") + if self._frozen: + raise SignatureError("Cannot add members to a frozen signature") + self._dict[name] = member + + def __delitem__(self, name): + raise SignatureError("Members cannot be removed from a signature") + + def __iter__(self): + return iter(sorted(self._dict)) + + def __len__(self): + return len(self._dict) + + def __iadd__(self, members): + for name, member in dict(members).items(): + self[name] = member + return self + + @property + def frozen(self): + return self._frozen + + def freeze(self): + self._frozen = True + for member in self.values(): + if member.is_signature: + member.signature.freeze() + + def flatten(self, *, path=()): + for name, member in self.items(): + yield ((*path, name), member) + if member.is_signature: + yield from member.signature.members.flatten(path=(*path, name)) + + def create(self, *, path=()): + attrs = {} + for name, member in self.items(): + def create_value(path): + if member.is_port: + return Signal(member.shape, reset=member.reset, + name="__".join(str(item) for item in path)) + if member.is_signature: + return member.signature.create(path=path) + assert False # :nocov: + def create_dimensions(dimensions, *, path): + if not dimensions: + return create_value(path) + dimension, *rest_of_dimensions = dimensions + return [create_dimensions(rest_of_dimensions, path=(*path, index)) + for index in range(dimension)] + attrs[name] = create_dimensions(member.dimensions, path=(*path, name)) + return attrs + + def __repr__(self): + frozen_repr = ".freeze()" if self._frozen else "" + return f"SignatureMembers({self._dict}){frozen_repr}" + + +@final +class FlippedSignatureMembers(Mapping): + def __init__(self, unflipped): + self.__unflipped = unflipped + + def flip(self): + return self.__unflipped + + # See the note below. + __eq__ = SignatureMembers.__eq__ + + def __contains__(self, name): + return name in self.__unflipped + + def __getitem__(self, name): + return self.__unflipped.__getitem__(name).flip() + + def __setitem__(self, name, member): + self.__unflipped.__setitem__(name, member.flip()) + + def __delitem__(self, name): + self.__unflipped.__delitem__(name) + + def __iter__(self): + return self.__unflipped.__iter__() + + def __len__(self): + return self.__unflipped.__len__() + + def __iadd__(self, members): + self.__unflipped.__iadd__({name: member.flip() for name, member in members.items()}) + return self + + @property + def frozen(self): + return self.__unflipped.frozen + + def freeze(self): + self.__unflipped.freeze() + + # These methods do not access instance variables and so their implementation can be shared + # between the normal and the flipped member collections. + flatten = SignatureMembers.flatten + create = SignatureMembers.create + + def __repr__(self): + return f"{self.__unflipped!r}.flip()" + + +def _format_path(path): + first, *rest = path + if isinstance(first, int): + # only happens in connect() + chunks = [f"arg{first}"] + else: + chunks = [first] + for item in rest: + if isinstance(item, int): + chunks.append(f"[{item}]") + else: + chunks.append(f".{item}") + return f"'{''.join(chunks)}'" + + +def _traverse_path(path, obj): + first, *rest = path + obj = obj[first] + for item in rest: + if isinstance(item, int): + obj = obj[item] + else: + obj = getattr(obj, item) + return obj + + +def _format_shape(shape): + if type(shape) is Shape: + return f"{shape}" + if isinstance(shape, int): + return f"{Shape.cast(shape)}" + return f"{Shape.cast(shape)} ({shape!r})" + + +class SignatureMeta(type): + def __subclasscheck__(cls, subclass): + # `FlippedSignature` is a subclass of `Signature` or any of its subclasses because all of + # them may return a Liskov-compatible instance of it from `self.flip()`. + if subclass is FlippedSignature: + return True + return super().__subclasscheck__(subclass) + + def __instancecheck__(cls, instance): + # `FlippedSignature` is an instance of a `Signature` or its subclass if the unflipped + # object is. + if type(instance) is FlippedSignature: + return super().__instancecheck__(instance.flip()) + return super().__instancecheck__(instance) + + +class Signature(metaclass=SignatureMeta): + def __init__(self, members): + self.__members = SignatureMembers(members) + + def flip(self): + return FlippedSignature(self) + + @property + def members(self): + return self.__members + + def __eq__(self, other): + other_unflipped = other.flip() if type(other) is FlippedSignature else other + if type(self) is type(other_unflipped) is Signature: + # If both `self` and `other` are anonymous signatures, compare structurally. + return self.members == other.members + else: + # Otherwise (if `self` refers to a derived class) compare by identity. This will + # usually be overridden in a derived class. + return self is other + + @property + def frozen(self): + return self.members.frozen + + def freeze(self): + self.members.freeze() + return self + + def is_compliant(self, obj, *, reasons=None, path=("obj",)): + def check_attr_value(member, attr_value, *, path): + if member.is_port: + try: + attr_value_cast = Value.cast(attr_value) + except: + if reasons is not None: + reasons.append(f"{_format_path(path)} is not a value-castable object, " + f"but {attr_value!r}") + return False + if not isinstance(attr_value_cast, (Signal, Const)): + if reasons is not None: + reasons.append(f"{_format_path(path)} is neither a signal nor a constant, " + f"but {attr_value_cast!r}") + return False + attr_shape = attr_value_cast.shape() + if Shape.cast(attr_shape) != Shape.cast(member.shape): + if reasons is not None: + reasons.append(f"{_format_path(path)} is expected to have " + f"the shape {_format_shape(member.shape)}, but it has " + f"the shape {_format_shape(attr_shape)}") + return False + if isinstance(attr_value_cast, Signal): + if attr_value_cast.reset != member._reset_as_const.value: + if reasons is not None: + reasons.append(f"{_format_path(path)} is expected to have " + f"the reset value {member.reset!r}, but it has " + f"the reset value {attr_value_cast.reset!r}") + return False + if attr_value_cast.reset_less: + if reasons is not None: + reasons.append(f"{_format_path(path)} is expected to not be reset-less") + return False + return True + if member.is_signature: + return member.signature.is_compliant(attr_value, reasons=reasons, path=path) + assert False # :nocov: + + def check_dimensions(member, attr_value, dimensions, *, path): + if not dimensions: + return check_attr_value(member, attr_value, path=path) + + dimension, *rest_of_dimensions = dimensions + if not isinstance(attr_value, (tuple, list)): + if reasons is not None: + reasons.append(f"{_format_path(path)} is expected to be a tuple or a list, " + f"but it is a {attr_value!r}") + return False + if len(attr_value) != dimension: + if reasons is not None: + reasons.append(f"{_format_path(path)} is expected to have dimension " + f"{dimension}, but its length is {len(attr_value)}") + return False + + result = True + for index in range(dimension): + if not check_dimensions(member, attr_value[index], rest_of_dimensions, + path=(*path, index)): + result = False + if reasons is None: + break # short cicruit if detailed error message isn't required + return result + + result = True + for attr_name, member in self.members.items(): + try: + attr_value = getattr(obj, attr_name) + except AttributeError: + if reasons is None: + return False + else: + reasons.append(f"{_format_path(path)} does not have an attribute " + f"{attr_name!r}") + result = False + continue + if not check_dimensions(member, attr_value, member.dimensions, path=(*path, attr_name)): + if reasons is None: + return False + else: + # `reasons` was mutated by check_dimensions() + result = False + continue + return result + + def create(self, *, path=()): + return Interface(self, path=path) + + def __repr__(self): + if type(self) is Signature: + return f"Signature({dict(self.members.items())})" + return super().__repr__() + + +# To simplify implementation and reduce API surface area `FlippedSignature` is made final. This +# restriction could be lifted if there is a compelling use case. +@final +class FlippedSignature: + def __init__(self, signature): + object.__setattr__(self, "_FlippedSignature__unflipped", signature) + + def flip(self): + return self.__unflipped + + @property + def members(self): + return FlippedSignatureMembers(self.__unflipped.members) + + def __eq__(self, other): + if type(other) is FlippedSignature: + # Trivial case. + return self.flip() == other.flip() + else: + # Delegate comparisons back to Signature (or its descendant) by flipping the arguments; + # equality must be reflexive but the implementation of __eq__ need not be, and we can + # take advantage of it here. + return other == self + + # These methods do not access instance variables and so their implementation can be shared + # between the normal and the flipped member collections. + frozen = Signature.frozen + freeze = Signature.freeze + is_compliant = Signature.is_compliant + create = Signature.create + + # FIXME: document this logic + def __getattr__(self, name): + value = getattr(self.__unflipped, name) + if inspect.ismethod(value): + return types.MethodType(value.__func__, self) + else: + return value + + def __setattr__(self, name, value): + return setattr(self.__unflipped, name, value) + + def __repr__(self): + return f"{self.__unflipped!r}.flip()" + + +class Interface: + def __init__(self, signature, *, path): + self.__dict__.update({ + "signature": signature, + **signature.members.create(path=path) + }) + + +# To reduce API surface area `FlippedInterface` is made final. This restriction could be lifted +# if there is a compelling use case. +@final +class FlippedInterface: + def __init__(self, interface): + if not (hasattr(interface, "signature") and isinstance(interface.signature, Signature)): + raise TypeError(f"flipped() can only flip an interface object, not {interface!r}") + object.__setattr__(self, "_FlippedInterface__unflipped", interface) + + @property + def signature(self): + return self.__unflipped.signature.flip() + + def __eq__(self, other): + return type(self) is type(other) and self.__unflipped == other.__unflipped + + # FIXME: document this logic + def __getattr__(self, name): + value = getattr(self.__unflipped, name) + if inspect.ismethod(value): + return types.MethodType(value.__func__, self) + else: + return value + + def __setattr__(self, name, value): + return setattr(self.__unflipped, name, value) + + def __repr__(self): + return f"flipped({self.__unflipped!r})" + + +def flipped(interface): + return FlippedInterface(interface) + + +@final +class ConnectionError(Exception): + pass + + +def connect(m, *args, **kwargs): + objects = { + **{index: arg for index, arg in enumerate(args)}, + **{keyword: arg for keyword, arg in kwargs.items()} + } + + # Extract signatures from arguments. + signatures = {} + for handle, obj in objects.items(): + if not hasattr(obj, "signature"): + raise AttributeError(f"Argument {handle!r} must have a 'signature' attribute") + if not isinstance(obj.signature, Signature): + raise TypeError(f"Signature of argument {handle!r} must be a signature, " + f"not {obj.signature!r}") + if not obj.signature.is_compliant(obj): + reasons = [] + obj.signature.is_compliant(obj, reasons=reasons, path=(handle,)) + reasons_as_string = "".join("\n- " + reason for reason in reasons) + raise ConnectionError(f"Argument {handle!r} does not match its signature:" + + reasons_as_string) + signatures[handle] = obj.signature.freeze() + + # Collate signatures and build connections. + flattens = {handle: signature.members.flatten() + for handle, signature in signatures.items()} + connections = [] + # Each iteration of the outer loop is intended to connect several (usually a pair) members + # to each other, e.g. an out member `[0].a` to an in member `[1].a`. However, because we + # do not just check signatures for equality (in order to improve diagnostics), it is possible + # that we will find that in `[0]`, the first member is `a`, and in `[1]`, the first member + # is completely unrelated `[b]`. Since the assumption that all signatures are equal, or even + # of equal length, cannot be made, it is necessary to simultaneously iterate (like with `zip`) + # the signature of every object being connected, making sure each set of next members is + # is_compliant with each other. + while True: + # Classify the members by kind and flow: signature, In, Out. Flow of signature members is + # implied in the flow of each port member, so the signature members are only classified + # here to ensure they are not connected to port members. + is_first = True + sig_kind, out_kind, in_kind = [], [], [] + for handle, flattened_members in flattens.items(): + path_for_handle, member = next(flattened_members, (None, None)) + # First, ensure that the paths are equal (i.e. that the hierarchy matches for all of + # the objects up to this point). + if is_first: + is_first = False + first_path = path_for_handle + else: + first_handle = next(iter(flattens)) + if first_path != path_for_handle: + # The paths are inequal. It is ambiguous how exactly the diagnostic should be + # displayed, and the choices of which other member to use below is arbitrary. + # Signature members are iterated in ascending lexicographical order, so the path + # that sorts greater corresponds to the handle that's missing a member. + if (path_for_handle is None or + (first_path is not None and path_for_handle > first_path)): + first_path_as_string = _format_path(first_path) + raise ConnectionError(f"Member {first_path_as_string} is present in " + f"{first_handle!r}, but not in {handle!r}") + if (first_path is None or + (path_for_handle is not None and path_for_handle < first_path)): + path_for_handle_as_string = _format_path(path_for_handle) + raise ConnectionError(f"Member {path_for_handle_as_string} is present in " + f"{handle!r}, but not in {first_handle!r}") + assert False # :nocov: + # If there is no actual member, the signature has been fully iterated through. + # Other signatures may still have extraneous members, so continue iterating until + # a diagnostic is returned. + if member is None: + continue + # At this point we know the paths are equal, but the members can still have + # inis_compliant flow, kind (signature or port), signature, or shape. Collect all of + # these for later evaluation. + if member.is_port: + if member.flow == Out: + out_kind.append(((handle, *path_for_handle), member)) + if member.flow == In: + in_kind.append(((handle, *path_for_handle), member)) + if member.is_signature: + sig_kind.append(((handle, *path_for_handle), member)) + # If there's no path and an error wasn't raised above, we're done! + if first_path is None: + break + # At this point, valid possibilities are: + # - All of the members are signature members. In this case, we move on to their contents, + # and ignore the signatures themselves. + # - There are no signature members, and there is exactly one Out flow member. In this case, + # this member is connected to the remaining In members, of which there may be any amount. + # All other cases must be rejected with a diagnostic. + if sig_kind and (out_kind or in_kind): + sig_member_paths_as_string = \ + ", ".join(_format_path(h) for h, m in sig_kind) + port_member_paths_as_string = \ + ", ".join(_format_path(h) for h, m in out_kind + in_kind) + raise ConnectionError( + f"Cannot connect signature member(s) {sig_member_paths_as_string} with " + f"port member(s) {port_member_paths_as_string}") + if sig_kind: + # There are no port members at this point; we're done with this path. + continue + # There are only port members after this point. + is_first = True + for (path, member) in in_kind + out_kind: + member_shape = member.shape + if is_first: + is_first = False + first_path = path + first_member_shape = member.shape + first_member_reset = member.reset + first_member_reset_as_const = member._reset_as_const + continue + if Shape.cast(first_member_shape).width != Shape.cast(member_shape).width: + raise ConnectionError( + f"Cannot connect the member {_format_path(first_path)} with shape " + f"{_format_shape(first_member_shape)} to the member {_format_path(path)} with " + f"shape {_format_shape(member_shape)} because the shape widths " + f"({Shape.cast(first_member_shape).width} and " + f"{Shape.cast(member_shape).width}) do not match") + if first_member_reset_as_const.value != member._reset_as_const.value: + raise ConnectionError( + f"Cannot connect together the member {_format_path(first_path)} with reset " + f"value {first_member_reset!r} and the member {_format_path(path)} with reset " + f"value {member.reset} because the reset values do not match") + # If there are no Out members, there is nothing to connect. The In members, while not + # explicitly connected, will stay at the same value since we ensured their reset values + # are all identical. + if len(out_kind) == 0: + continue + # Check that there is only one Out member. In the future we could extend connection to + # handle wired-OR and wired-AND, and this check may go away. + if len(out_kind) != 1: + out_member_paths_as_string = \ + ", ".join(_format_path(h) for h, m in out_kind) + raise ConnectionError( + f"Cannot connect several output members {out_member_paths_as_string} together") + # There is exactly one Out member after this point, and any amount of In members. + # Traversing the paths to all of them should always succeed, since the signature check + # at the beginning of `connect()` passed, and so should casting the result to a Value. + (out_path, out_member), = out_kind + for (in_path, in_member) in in_kind: + def connect_value(*, out_path, in_path): + in_value = Value.cast(_traverse_path(in_path, objects)) + out_value = Value.cast(_traverse_path(out_path, objects)) + assert type(in_value) in (Const, Signal) + # If the input is a constant, only a constant may be connected to it. Ensure that + # this is the case. + if type(in_value) is Const: + # If the output is not a constant, the connection is illegal. + if type(out_value) is not Const: + raise ConnectionError( + f"Cannot connect to the input member {_format_path(in_path)} that has " + f"a constant value {in_value.value!r}") + # If the output is a constant, the connection is legal only if the value is + # the same for both the input and the output. + if type(out_value) is Const and in_value.value != out_value.value: + raise ConnectionError( + f"Cannot connect input member {_format_path(in_path)} that has " + f"a constant value {in_value.value!r} to an output member " + f"{_format_path(out_path)} that has a differing constant value " + f"{out_value.value!r}") + # We never actually connect anything to the constant input; we only ensure its + # value (which is constant) is consistent with a connection that would have + # been made. + return + # A connection that is made at this point is guaranteed to be valid. + connections.append(in_value.eq(out_value)) + def connect_dimensions(dimensions, *, out_path, in_path): + if not dimensions: + return connect_value(out_path=out_path, in_path=in_path) + dimension, *rest_of_dimensions = dimensions + for index in range(dimension): + connect_dimensions(rest_of_dimensions, + out_path=(*out_path, index), in_path=(*in_path, index)) + assert out_member.dimensions == in_member.dimensions + connect_dimensions(out_member.dimensions, out_path=out_path, in_path=in_path) + # Now that we know all of the connections are legal, add them to the module. This is done + # instead of returning them because adding them to a non-comb domain would subtly violate + # assumptions that `connect()` is intended to provide. + m.d.comb += connections + + +class Component(Elaboratable): + def __init__(self): + for name in self.signature.members: + if hasattr(self, name): + raise NameError(f"Cannot initialize attribute for signature member {name!r} " + f"because an attribute with the same name already exists") + self.__dict__.update(self.signature.members.create()) + + # TODO(py3.9): This should be a class method, but descriptors don't stack this way + # in Python 3.8 and below. + # @classmethod + @property + def signature(self): + cls = type(self) + signature = Signature({}) + for base in cls.mro()[:cls.mro().index(Component)]: + for name, annot in getattr(base, "__annotations__", {}).items(): + if name.startswith("_"): + continue + if (annot in (Value, Signal, Const) or + (isinstance(annot, type) and issubclass(annot, ValueCastable)) or + isinstance(annot, Signature)): + if isinstance(annot, type): + annot_repr = annot.__name__ + else: + annot_repr = repr(annot) + # To suppress this warning in the rare cases where it is necessary (and naming + # the field with a leading underscore is infeasible), override the property. + warnings.warn( + message=f"Component '{cls.__module__}.{cls.__qualname__}' has " + f"an annotation '{name}: {annot_repr}', which is not " + f"a signature member; did you mean '{name}: In({annot_repr})' " + f"or '{name}: Out({annot_repr})'?", + category=SyntaxWarning, + stacklevel=2) + elif type(annot) is Member: + signature.members[name] = annot + if not signature.members: + raise NotImplementedError( + f"Component '{cls.__module__}.{cls.__qualname__}' does not have signature member " + f"annotations") + return signature diff --git a/tests/test_lib_wiring.py b/tests/test_lib_wiring.py new file mode 100644 index 0000000..b1bf74e --- /dev/null +++ b/tests/test_lib_wiring.py @@ -0,0 +1,834 @@ +import unittest +from types import SimpleNamespace as NS + +from amaranth import * +from amaranth.hdl.ast import ValueCastable +from amaranth.lib import data, enum +from amaranth.lib.wiring import Flow, In, Out, Member +from amaranth.lib.wiring import SignatureError, SignatureMembers, FlippedSignatureMembers +from amaranth.lib.wiring import Signature, FlippedSignature, Interface, FlippedInterface +from amaranth.lib.wiring import Component +from amaranth.lib.wiring import ConnectionError, connect, flipped + + +class FlowTestCase(unittest.TestCase): + def test_flow_call(self): + self.assertEqual(In(unsigned(1)), Member(Flow.In, unsigned(1))) + self.assertEqual(Out(5), Member(Flow.Out, 5)) + + def test_flow_repr(self): + self.assertEqual(repr(Flow.In), "In") + self.assertEqual(repr(Flow.Out), "Out") + + def test_flow_str(self): + self.assertEqual(str(Flow.In), "In") + self.assertEqual(str(Flow.Out), "Out") + + +class MemberTestCase(unittest.TestCase): + def test_port_member(self): + member = Member(In, unsigned(1)) + self.assertEqual(member.flow, In) + self.assertEqual(member.is_port, True) + self.assertEqual(member.shape, unsigned(1)) + self.assertEqual(member.reset, None) + self.assertEqual(member.is_signature, False) + with self.assertRaisesRegex(AttributeError, + r"^A port member does not have a signature$"): + member.signature + self.assertEqual(member.dimensions, ()) + self.assertEqual(repr(member), "In(unsigned(1))") + + def test_port_member_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Port member description must be a shape-castable object or a signature, " + r"not 'whatever'$"): + Member(In, "whatever") + + def test_port_member_reset(self): + member = Member(Out, unsigned(1), reset=1) + self.assertEqual(member.flow, Out) + self.assertEqual(member.shape, unsigned(1)) + self.assertEqual(member.reset, 1) + self.assertEqual(repr(member._reset_as_const), repr(Const(1, 1))) + self.assertEqual(repr(member), "Out(unsigned(1), reset=1)") + + def test_port_member_reset_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Port member reset value 'no' is not a valid constant initializer " + r"for unsigned\(1\)$"): + Member(In, 1, reset="no") + + def test_port_member_reset_shape_castable(self): + layout = data.StructLayout({"a": 32}) + member = Member(In, layout, reset={"a": 1}) + self.assertEqual(member.flow, In) + self.assertEqual(member.shape, layout) + self.assertEqual(member.reset, {"a": 1}) + self.assertEqual(repr(member), "In(StructLayout({'a': 32}), reset={'a': 1})") + + def test_port_member_reset_shape_castable_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Port member reset value 'no' is not a valid constant initializer " + r"for StructLayout\({'a': 32}\)$"): + Member(In, data.StructLayout({"a": 32}), reset="no") + + def test_signature_member_out(self): + sig = Signature({"data": Out(unsigned(32))}) + member = Member(Out, sig) + self.assertEqual(member.flow, Out) + self.assertEqual(member.is_port, False) + with self.assertRaisesRegex(AttributeError, + r"^A signature member does not have a shape$"): + member.shape + with self.assertRaisesRegex(AttributeError, + r"^A signature member does not have a reset value$"): + member.reset + self.assertEqual(member.is_signature, True) + self.assertEqual(member.signature, sig) + self.assertEqual(member.dimensions, ()) + self.assertEqual(repr(member), "Out(Signature({'data': Out(unsigned(32))}))") + + def test_signature_member_in(self): + sig = Signature({"data": In(unsigned(32))}) + member = Member(In, sig) + self.assertEqual(member.flow, In) + self.assertEqual(member.is_port, False) + with self.assertRaisesRegex(AttributeError, + r"^A signature member does not have a shape$"): + member.shape + with self.assertRaisesRegex(AttributeError, + r"^A signature member does not have a reset value$"): + member.reset + self.assertEqual(member.is_signature, True) + self.assertEqual(member.signature, sig.flip()) + self.assertEqual(member.dimensions, ()) + self.assertEqual(repr(member), "In(Signature({'data': In(unsigned(32))}))") + + def test_signature_member_wrong(self): + with self.assertRaisesRegex(ValueError, + r"^A signature member cannot have a reset value$"): + Member(In, Signature({}), reset=1) + + def test_array(self): + array_2 = Member(In, unsigned(1)).array(2) + self.assertEqual(array_2.dimensions, (2,)) + self.assertEqual(repr(array_2), "In(unsigned(1)).array(2)") + + array_2_3 = Member(In, unsigned(1)).array(2, 3) + self.assertEqual(array_2_3.dimensions, (2, 3)) + self.assertEqual(repr(array_2_3), "In(unsigned(1)).array(2, 3)") + + array_2_3_chained = Member(In, unsigned(1)).array(3).array(2) + self.assertEqual(array_2_3_chained.dimensions, (2, 3)) + self.assertEqual(repr(array_2_3_chained), "In(unsigned(1)).array(2, 3)") + + def test_array_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^Member array dimensions must be non-negative integers, not -1$"): + Member(In, unsigned(1)).array(-1) + with self.assertRaisesRegex(TypeError, + r"^Member array dimensions must be non-negative integers, not 'what'$"): + Member(In, unsigned(1)).array("what") + + def test_flip(self): + self.assertEqual(In(1).flip(), Out(1)) + self.assertEqual(Out(1).flip(), In(1)) + + def test_equality(self): + self.assertEqual(In(1), In(1)) + self.assertNotEqual(In(1), Out(1)) + self.assertNotEqual(In(1), In(1, reset=1)) + self.assertNotEqual(In(1), In(1, reset=0)) + self.assertEqual(In(1), In(1).array()) + self.assertNotEqual(In(1), In(1).array(1)) + sig = Signature({}) + self.assertEqual(In(sig), In(sig)) + self.assertNotEqual(In(1), In(Signature({}))) + + +class SignatureMembersTestCase(unittest.TestCase): + def test_contains(self): + self.assertNotIn("a", SignatureMembers()) + self.assertIn("a", SignatureMembers({"a": In(1)})) + + def test_getitem(self): + members = SignatureMembers({"a": In(1)}) + self.assertEqual(members["a"], In(1)) + + def test_getitem_missing(self): + members = SignatureMembers({"a": In(1)}) + with self.assertRaisesRegex(SignatureError, + r"^Member 'b' is not a part of the signature$"): + members["b"] + + def test_getitem_wrong(self): + members = SignatureMembers({"a": In(1)}) + with self.assertRaisesRegex(TypeError, + r"^Member name must be a string, not 1$"): + members[1] + with self.assertRaisesRegex(NameError, + r"^Member name '_a' must be a valid, public Python attribute name$"): + members["_a"] + with self.assertRaisesRegex(NameError, + r"^Member name cannot be 'signature'$"): + members["signature"] + + def test_setitem(self): + members = SignatureMembers() + members["a"] = In(1) + self.assertEqual(members["a"], In(1)) + + def test_setitem_existing(self): + members = SignatureMembers({"a": In(1)}) + with self.assertRaisesRegex(SignatureError, + r"^Member 'a' already exists in the signature and cannot be replaced$"): + members["a"] = Out(2) + + def test_setitem_wrong(self): + members = SignatureMembers() + with self.assertRaisesRegex(TypeError, + r"^Member name must be a string, not 1$"): + members[1] = Out(1) + with self.assertRaisesRegex(TypeError, + r"^Assigned value 1 must be a member; did you mean In\(1\) or Out\(1\)\?$"): + members["a"] = 1 + with self.assertRaisesRegex(NameError, + r"^Member name '_a' must be a valid, public Python attribute name$"): + members["_a"] = Out(1) + with self.assertRaisesRegex(NameError, + r"^Member name cannot be 'signature'$"): + members["signature"] = Out(1) + + def test_delitem(self): + members = SignatureMembers() + with self.assertRaisesRegex(SignatureError, + r"^Members cannot be removed from a signature$"): + del members["a"] + + def test_iter_len(self): + members = SignatureMembers() + self.assertEqual(list(iter(members)), []) + self.assertEqual(len(members), 0) + members["a"] = In(1) + self.assertEqual(list(iter(members)), ["a"]) + self.assertEqual(len(members), 1) + + def test_iter_sorted(self): + self.assertEqual(list(iter(SignatureMembers({"a": In(1), "b": Out(1)}))), + ["a", "b"]) + self.assertEqual(list(iter(SignatureMembers({"b": In(1), "a": Out(1)}))), + ["a", "b"]) + + def test_iadd(self): + members = SignatureMembers() + members += {"a": In(1)} + members += [("b", Out(1))] + self.assertEqual(members, SignatureMembers({"a": In(1), "b": Out(1)})) + + def test_freeze(self): + members = SignatureMembers({"a": In(1)}) + self.assertEqual(members.frozen, False) + members.freeze() + self.assertEqual(members.frozen, True) + with self.assertRaisesRegex(SignatureError, + r"^Cannot add members to a frozen signature$"): + members += {"b": Out(1)} + + def test_freeze_rec(self): + sig = Signature({}) + members = SignatureMembers({ + "a": In(1), + "s": Out(sig) + }) + self.assertEqual(members.frozen, False) + self.assertEqual(sig.frozen, False) + self.assertEqual(sig.members.frozen, False) + members.freeze() + self.assertEqual(members.frozen, True) + self.assertEqual(sig.frozen, True) + self.assertEqual(sig.members.frozen, True) + + def test_flatten(self): + sig = Signature({ + "b": Out(1), + "c": In(2) + }) + members = SignatureMembers({ + "a": In(1), + "s": Out(sig) + }) + self.assertEqual(list(members.flatten()), [ + (("a",), In(1)), + (("s",), Out(sig)), + (("s", "b"), Out(1)), + (("s", "c"), In(2)), + ]) + + def test_create(self): + sig = Signature({ + "b": Out(2) + }) + members = SignatureMembers({ + "a": In(1), + "s": Out(sig) + }) + attrs = members.create() + self.assertEqual(list(attrs.keys()), ["a", "s"]) + self.assertIsInstance(attrs["a"], Signal) + self.assertEqual(attrs["a"].shape(), unsigned(1)) + self.assertEqual(attrs["a"].name, "a") + self.assertEqual(attrs["s"].b.shape(), unsigned(2)) + self.assertEqual(attrs["s"].b.name, "s__b") + + def test_create_reset(self): + members = SignatureMembers({ + "a": In(1, reset=1), + }) + attrs = members.create() + self.assertEqual(attrs["a"].reset, 1) + + def test_create_tuple(self): + sig = SignatureMembers({ + "a": Out(1).array(2, 3) + }) + members = sig.create() + self.assertEqual(len(members["a"]), 2) + self.assertEqual(len(members["a"][0]), 3) + self.assertEqual(len(members["a"][1]), 3) + for x in members["a"]: + for y in x: + self.assertIsInstance(y, Signal) + self.assertEqual(members["a"][1][2].name, "a__1__2") + + def test_repr(self): + self.assertEqual(repr(SignatureMembers({})), + "SignatureMembers({})") + self.assertEqual(repr(SignatureMembers({"a": In(1)})), + "SignatureMembers({'a': In(1)})") + members = SignatureMembers({"b": Out(2)}) + members.freeze() + self.assertEqual(repr(members), + "SignatureMembers({'b': Out(2)}).freeze()") + + +class FlippedSignatureMembersTestCase(unittest.TestCase): + def test_basic(self): + members = SignatureMembers({"a": In(1)}) + fmembers = members.flip() + self.assertIsInstance(fmembers, FlippedSignatureMembers) + self.assertIn("a", fmembers) + self.assertEqual(fmembers["a"], Out(1)) + fmembers["b"] = Out(2) + self.assertEqual(len(fmembers), 2) + self.assertEqual(members["b"], In(2)) + self.assertEqual(list(fmembers), ["a", "b"]) + fmembers += {"c": In(2)} + self.assertEqual(members["c"], Out(2)) + self.assertIs(fmembers.flip(), members) + + def test_eq(self): + self.assertEqual(SignatureMembers({"a": In(1)}).flip(), + SignatureMembers({"a": In(1)}).flip()) + self.assertEqual(SignatureMembers({"a": In(1)}).flip(), + SignatureMembers({"a": Out(1)})) + + def test_delitem(self): + fmembers = SignatureMembers().flip() + with self.assertRaisesRegex(SignatureError, + r"^Members cannot be removed from a signature$"): + del fmembers["a"] + + def test_freeze(self): + members = SignatureMembers({"a": In(1)}) + fmembers = members.flip() + self.assertEqual(fmembers.frozen, False) + fmembers.freeze() + self.assertEqual(members.frozen, True) + self.assertEqual(fmembers.frozen, True) + + def test_repr(self): + fmembers = SignatureMembers({"a": In(1)}).flip() + self.assertEqual(repr(fmembers), "SignatureMembers({'a': In(1)}).flip()") + + +class SignatureTestCase(unittest.TestCase): + def test_create(self): + sig = Signature({"a": In(1)}) + self.assertEqual(sig.members, SignatureMembers({"a": In(1)})) + + def test_eq(self): + self.assertEqual(Signature({"a": In(1)}), + Signature({"a": In(1)})) + self.assertNotEqual(Signature({"a": In(1)}), + Signature({"a": Out(1)})) + + def test_freeze(self): + sig = Signature({"a": In(1)}) + self.assertEqual(sig.frozen, False) + sig.freeze() + self.assertEqual(sig.frozen, True) + with self.assertRaisesRegex(SignatureError, + r"^Cannot add members to a frozen signature$"): + sig.members += {"b": Out(1)} + + def assertNotCompliant(self, reason_regex, sig, obj): + self.assertFalse(sig.is_compliant(obj)) + reasons = [] + self.assertFalse(sig.is_compliant(obj, reasons=reasons)) + self.assertEqual(len(reasons), 1) + self.assertRegex(reasons[0], reason_regex) + + def test_is_compliant(self): + self.assertNotCompliant( + r"^'obj' does not have an attribute 'a'$", + sig=Signature({"a": In(1)}), + obj=NS()) + self.assertNotCompliant( + r"^'obj\.a' is expected to be a tuple or a list, but it is a \(sig \$signal\)$", + sig=Signature({"a": In(1).array(2)}), + obj=NS(a=Signal())) + self.assertNotCompliant( + r"^'obj\.a' is expected to have dimension 2, but its length is 1$", + sig=Signature({"a": In(1).array(2)}), + obj=NS(a=[Signal()])) + self.assertNotCompliant( + r"^'obj\.a\[0\]' is expected to have dimension 2, but its length is 1$", + sig=Signature({"a": In(1).array(1, 2)}), + obj=NS(a=[[Signal()]])) + self.assertNotCompliant( + r"^'obj\.a' is not a value-castable object, but 'foo'$", + sig=Signature({"a": In(1)}), + obj=NS(a="foo")) + self.assertNotCompliant( + r"^'obj\.a' is neither a signal nor a constant, but " + r"\(\+ \(const 1'd1\) \(const 1'd1\)\)$", + sig=Signature({"a": In(1)}), + obj=NS(a=Const(1)+1)) + self.assertNotCompliant( + r"^'obj\.a' is expected to have the shape unsigned\(1\), but " + r"it has the shape unsigned\(2\)$", + sig=Signature({"a": In(1)}), + obj=NS(a=Signal(2))) + self.assertNotCompliant( + r"^'obj\.a' is expected to have the shape unsigned\(1\), but " + r"it has the shape signed\(1\)$", + sig=Signature({"a": In(unsigned(1))}), + obj=NS(a=Signal(signed(1)))) + self.assertNotCompliant( + r"^'obj\.a' is expected to have the reset value None, but it has the reset value 1$", + sig=Signature({"a": In(1)}), + obj=NS(a=Signal(reset=1))) + self.assertNotCompliant( + r"^'obj\.a' is expected to have the reset value 1, but it has the reset value 0$", + sig=Signature({"a": In(1, reset=1)}), + obj=NS(a=Signal(1))) + self.assertNotCompliant( + r"^'obj\.a' is expected to not be reset-less$", + sig=Signature({"a": In(1)}), + obj=NS(a=Signal(1, reset_less=True))) + self.assertNotCompliant( + r"^'obj\.a' does not have an attribute 'b'$", + sig=Signature({"a": Out(Signature({"b": In(1)}))}), + obj=NS(a=Signal())) + self.assertTrue( + Signature({"a": In(1)}).is_compliant( + NS(a=Signal()))) + self.assertTrue( + Signature({"a": In(1)}).is_compliant( + NS(a=Const(1)))) + self.assertTrue( # list + Signature({"a": In(1).array(2, 2)}).is_compliant( + NS(a=[[Const(1), Const(1)], [Signal(), Signal()]]))) + self.assertTrue( # tuple + Signature({"a": In(1).array(2, 2)}).is_compliant( + NS(a=((Const(1), Const(1)), (Signal(), Signal()))))) + self.assertTrue( # mixed list and tuple + Signature({"a": In(1).array(2, 2)}).is_compliant( + NS(a=[[Const(1), Const(1)], (Signal(), Signal())]))) + self.assertTrue( + Signature({"a": Out(Signature({"b": In(1)}))}).is_compliant( + NS(a=NS(b=Signal())))) + + def test_repr(self): + sig = Signature({"a": In(1)}) + self.assertEqual(repr(sig), "Signature({'a': In(1)})") + + def test_repr_subclass(self): + class S(Signature): + def __init__(self): + super().__init__({"a": In(1)}) + sig = S() + self.assertRegex(repr(sig), r"^<.+\.S object at .+?>$") + + def test_subclasscheck(self): + class S(Signature): + pass + self.assertTrue(issubclass(FlippedSignature, Signature)) + self.assertTrue(issubclass(Signature, Signature)) + self.assertTrue(issubclass(FlippedSignature, S)) + self.assertTrue(not issubclass(Signature, S)) + + def test_instancecheck(self): + class S(Signature): + pass + sig = Signature({}) + sig2 = S({}) + self.assertTrue(isinstance(sig.flip(), Signature)) + self.assertTrue(isinstance(sig2.flip(), Signature)) + self.assertTrue(not isinstance(sig.flip(), S)) + self.assertTrue(isinstance(sig2.flip(), S)) + + +class FlippedSignatureTestCase(unittest.TestCase): + def test_create(self): + sig = Signature({"a": In(1)}) + fsig = sig.flip() + self.assertIsInstance(fsig, FlippedSignature) + self.assertIsInstance(fsig.members, FlippedSignatureMembers) + self.assertIs(fsig.flip(), sig) + + def test_eq(self): + self.assertEqual(Signature({"a": In(1)}).flip(), + Signature({"a": In(1)}).flip()) + self.assertEqual(Signature({"a": In(1)}).flip(), + Signature({"a": Out(1)})) + + def test_repr(self): + sig = Signature({"a": In(1)}).flip() + self.assertEqual(repr(sig), "Signature({'a': In(1)}).flip()") + + def test_getattr_setattr(self): + class S(Signature): + def __init__(self): + super().__init__({}) + self.x = 1 + + def f(self2): + self.assertIsInstance(self2, FlippedSignature) + return "f()" + sig = S() + fsig = sig.flip() + self.assertEqual(fsig.x, 1) + self.assertEqual(fsig.f(), "f()") + fsig.y = 2 + self.assertEqual(sig.y, 2) + + +class InterfaceTestCase(unittest.TestCase): + pass + + +class FlippedInterfaceTestCase(unittest.TestCase): + def test_basic(self): + sig = Signature({"a": In(1)}) + intf = sig.create() + self.assertTrue(sig.is_compliant(intf)) + self.assertIs(intf.signature, sig) + tintf = flipped(intf) + self.assertEqual(tintf.signature, intf.signature.flip()) + self.assertEqual(tintf, flipped(intf)) + self.assertRegex(repr(tintf), + r"^flipped\(<.+?\.Interface object at .+>\)$") + + def test_getattr_setattr(self): + class I(Interface): + signature = Signature({}) + + def __init__(self): + self.x = 1 + + def f(self2): + self.assertIsInstance(self2, FlippedInterface) + return "f()" + intf = I() + tintf = flipped(intf) + self.assertEqual(tintf.x, 1) + self.assertEqual(tintf.f(), "f()") + tintf.y = 2 + self.assertEqual(intf.y, 2) + + def test_flipped_wrong(self): + with self.assertRaisesRegex(TypeError, + r"^flipped\(\) can only flip an interface object, not Signature\({}\)$"): + flipped(Signature({})) + + +class ConnectTestCase(unittest.TestCase): + def test_arg_handles_and_signature_attr(self): + m = Module() + with self.assertRaisesRegex(AttributeError, + r"^Argument 0 must have a 'signature' attribute$"): + connect(m, object()) + with self.assertRaisesRegex(AttributeError, + r"^Argument 'x' must have a 'signature' attribute$"): + connect(m, x=object()) + + def test_signature_type(self): + m = Module() + with self.assertRaisesRegex(TypeError, + r"^Signature of argument 0 must be a signature, not 1$"): + connect(m, NS(signature=1)) + + def test_signature_compliant(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Argument 0 does not match its signature:\n" + r"- 'arg0' does not have an attribute 'a'$"): + connect(m, NS(signature=Signature({"a": In(1)}))) + + def test_signature_freeze(self): + m = Module() + intf = NS(signature=Signature({})) + connect(m, intf) + self.assertTrue(intf.signature.frozen) + + def test_member_missing(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Member 'b' is present in 'q', but not in 'p'$"): + connect(m, + p=NS(signature=Signature({"a": In(1)}), + a=Signal()), + q=NS(signature=Signature({"a": In(1), "b": Out(1)}), + a=Signal(), b=Signal())) + with self.assertRaisesRegex(ConnectionError, + r"^Member 'b' is present in 'p', but not in 'q'$"): + connect(m, + p=NS(signature=Signature({"a": In(1), "b": Out(1)}), + a=Signal(), b=Signal()), + q=NS(signature=Signature({"a": In(1)}), + a=Signal())) + + def test_signature_to_port(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect signature member\(s\) 'p\.a' with port member\(s\) 'q\.a'$"): + connect(m, + p=NS(signature=Signature({"a": Out(Signature({}))}), + a=NS(signature=Signature({}))), + q=NS(signature=Signature({"a": In(1)}), + a=Signal())) + + def test_shape_mismatch(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect the member 'q\.a' with shape unsigned\(2\) to the member 'p\.a' " + r"with shape unsigned\(1\) because the shape widths \(2 and 1\) do not match$"): + connect(m, + p=NS(signature=Signature({"a": Out(1)}), + a=Signal()), + q=NS(signature=Signature({"a": In(2)}), + a=Signal(2))) + + def test_shape_mismatch_enum(self): + class Cycle(enum.Enum, shape=2): + READ = 0 + WRITE = 1 + + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect the member 'q\.a' with shape unsigned\(2\) \(\) " + r"to the member 'p\.a' with shape unsigned\(1\) because the shape widths " + r"\(2 and 1\) do not match$"): + connect(m, + p=NS(signature=Signature({"a": Out(1)}), + a=Signal()), + q=NS(signature=Signature({"a": In(Cycle)}), + a=Signal(Cycle))) + + def test_reset_mismatch(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect together the member 'q\.a' with reset value 1 and the member " + r"'p\.a' with reset value 0 because the reset values do not match$"): + connect(m, + p=NS(signature=Signature({"a": Out(1, reset=0)}), + a=Signal()), + q=NS(signature=Signature({"a": In(1, reset=1)}), + a=Signal(reset=1))) + + def test_reset_none_match(self): + m = Module() + connect(m, + p=NS(signature=Signature({"a": Out(1, reset=0)}), + a=Signal()), + q=NS(signature=Signature({"a": In(1)}), + a=Signal())) + + def test_out_to_out(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect several output members 'p\.a', 'q\.a' together$"): + connect(m, + p=NS(signature=Signature({"a": Out(1)}), + a=Signal()), + q=NS(signature=Signature({"a": Out(1)}), + a=Signal())) + + def test_out_to_const_in(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect to the input member 'q\.a' that has a constant value 0$"): + connect(m, + p=NS(signature=Signature({"a": Out(1)}), + a=Signal()), + q=NS(signature=Signature({"a": In(1)}), + a=Const(0))) + + def test_const_out_to_const_in_value_mismatch(self): + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect input member 'q\.a' that has a constant value 0 to an output " + r"member 'p\.a' that has a differing constant value 1$"): + connect(m, + p=NS(signature=Signature({"a": Out(1)}), + a=Const(1)), + q=NS(signature=Signature({"a": In(1)}), + a=Const(0))) + + def test_simple_bus(self): + class Cycle(enum.Enum): + IDLE = 0 + READ = 1 + WRITE = 2 + sig = Signature({ + "cycle": Out(Cycle), + "addr": Out(16), + "r_data": In(32), + "w_data": Out(32), + }) + + src = sig.create(path=('src',)) + snk = sig.flip().create(path=('snk',)) + + m = Module() + connect(m, src=src, snk=snk) + self.assertEqual([repr(stmt) for stmt in m._statements], [ + '(eq (sig snk__addr) (sig src__addr))', + '(eq (sig snk__cycle) (sig src__cycle))', + '(eq (sig src__r_data) (sig snk__r_data))', + '(eq (sig snk__w_data) (sig src__w_data))' + ]) + + def test_const_in_out(self): + m = Module() + connect(m, + p=NS(signature=Signature({"a": Out(1)}), + a=Const(1)), + q=NS(signature=Signature({"a": In(1)}), + a=Const(1))) + self.assertEqual(m._statements, []) + + def test_nested(self): + m = Module() + connect(m, + p=NS(signature=Signature({"a": Out(Signature({"f": Out(1)}))}), + a=NS(f=Signal(name='p__a'))), + q=NS(signature=Signature({"a": In(Signature({"f": Out(1)}))}), + a=NS(f=Signal(name='q__a')))) + self.assertEqual([repr(stmt) for stmt in m._statements], [ + '(eq (sig q__a) (sig p__a))' + ]) + + def test_dimension(self): + sig = Signature({"a": Out(1).array(2)}) + + m = Module() + connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',))) + self.assertEqual([repr(stmt) for stmt in m._statements], [ + '(eq (sig q__a__0) (sig p__a__0))', + '(eq (sig q__a__1) (sig p__a__1))' + ]) + + def test_dimension_multi(self): + sig = Signature({"a": Out(1).array(1).array(1)}) + + m = Module() + connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',))) + self.assertEqual([repr(stmt) for stmt in m._statements], [ + '(eq (sig q__a__0__0) (sig p__a__0__0))', + ]) + + +class ComponentTestCase(unittest.TestCase): + def test_basic(self): + class C(Component): + sig : Out(2) + + c = C() + self.assertEqual(c.signature, Signature({"sig": Out(2)})) + self.assertIsInstance(c.sig, Signal) + self.assertEqual(c.sig.shape(), unsigned(2)) + + def test_non_member_annotations(self): + class C(Component): + sig : Out(2) + foo : int + + c = C() + self.assertEqual(c.signature, Signature({"sig": Out(2)})) + + def test_private_member_annotations(self): + class C(Component): + sig_pub : Out(2) + _sig_priv : Out(2) + + c = C() + self.assertEqual(c.signature, Signature({"sig_pub": Out(2)})) + + def test_no_annotations(self): + class C(Component): + pass + + with self.assertRaisesRegex(NotImplementedError, + r"^Component '.+?\.C' does not have signature member annotations$"): + C() + + def test_would_overwrite_field(self): + class C(Component): + sig : Out(2) + + def __init__(self): + self.sig = 1 + super().__init__() + + with self.assertRaisesRegex(NameError, + r"^Cannot initialize attribute for signature member 'sig' because an attribute " + r"with the same name already exists$"): + C() + + def test_missing_in_out_warning(self): + class C1(Component): + prt1 : In(1) + sig2 : Signal + + with self.assertWarnsRegex(SyntaxWarning, + r"^Component '.+\.C1' has an annotation 'sig2: Signal', which is not a signature " + r"member; did you mean 'sig2: In\(Signal\)' or 'sig2: Out\(Signal\)'\?$"): + C1().signature + + class C2(Component): + prt1 : In(1) + sig2 : Signature({}) + + with self.assertWarnsRegex(SyntaxWarning, + r"^Component '.+\.C2' has an annotation 'sig2: Signature\({}\)', which is not " + r"a signature member; did you mean 'sig2: In\(Signature\({}\)\)' or " + r"'sig2: Out\(Signature\({}\)\)'\?$"): + C2().signature + + class MockValueCastable(ValueCastable): + def shape(self): pass + @ValueCastable.lowermethod + def as_value(self): pass + + class C3(Component): + prt1 : In(1) + val2 : MockValueCastable + + with self.assertWarnsRegex(SyntaxWarning, + r"^Component '.+\.C3' has an annotation 'val2: MockValueCastable', which is not " + r"a signature member; did you mean 'val2: In\(MockValueCastable\)' or " + r"'val2: Out\(MockValueCastable\)'\?$"): + C3().signature