diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index 7655d8a..f183302 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -156,6 +156,25 @@ def signed(width): return Shape(width, signed=True) +def _overridable_by_reflected(method_name): + """Allow overriding the decorated method. + + Allows :class:`ValueCastable` to override the decorated method by implementing + a reflected method named ``method_name``. Intended for operators, but + also usable for other methods that have a reflected counterpart. + """ + def decorator(f): + @functools.wraps(f) + def wrapper(self, other): + if isinstance(other, ValueCastable) and hasattr(other, method_name): + res = getattr(other, method_name)(self) + if res is not NotImplemented: + return res + return f(self, other) + return wrapper + return decorator + + class Value(metaclass=ABCMeta): @staticmethod def cast(obj): @@ -195,26 +214,31 @@ class Value(metaclass=ABCMeta): def __neg__(self): return Operator("-", [self]) + @_overridable_by_reflected("__radd__") def __add__(self, other): - return Operator("+", [self, other]) + return Operator("+", [self, other], src_loc_at=1) def __radd__(self, other): return Operator("+", [other, self]) + @_overridable_by_reflected("__rsub__") def __sub__(self, other): - return Operator("-", [self, other]) + return Operator("-", [self, other], src_loc_at=1) def __rsub__(self, other): return Operator("-", [other, self]) + @_overridable_by_reflected("__rmul__") def __mul__(self, other): - return Operator("*", [self, other]) + return Operator("*", [self, other], src_loc_at=1) def __rmul__(self, other): return Operator("*", [other, self]) + @_overridable_by_reflected("__rmod__") def __mod__(self, other): - return Operator("%", [self, other]) + return Operator("%", [self, other], src_loc_at=1) def __rmod__(self, other): return Operator("%", [other, self]) + @_overridable_by_reflected("__rfloordiv__") def __floordiv__(self, other): - return Operator("//", [self, other]) + return Operator("//", [self, other], src_loc_at=1) def __rfloordiv__(self, other): return Operator("//", [other, self]) @@ -224,46 +248,57 @@ class Value(metaclass=ABCMeta): # by a signed value to make sure the shift amount can always be interpreted as # an unsigned value. raise TypeError("Shift amount must be unsigned") + @_overridable_by_reflected("__rlshift__") def __lshift__(self, other): other = Value.cast(other) other.__check_shamt() - return Operator("<<", [self, other]) + return Operator("<<", [self, other], src_loc_at=1) def __rlshift__(self, other): self.__check_shamt() return Operator("<<", [other, self]) + @_overridable_by_reflected("__rrshift__") def __rshift__(self, other): other = Value.cast(other) other.__check_shamt() - return Operator(">>", [self, other]) + return Operator(">>", [self, other], src_loc_at=1) def __rrshift__(self, other): self.__check_shamt() return Operator(">>", [other, self]) + @_overridable_by_reflected("__rand__") def __and__(self, other): - return Operator("&", [self, other]) + return Operator("&", [self, other], src_loc_at=1) def __rand__(self, other): return Operator("&", [other, self]) + @_overridable_by_reflected("__rxor__") def __xor__(self, other): - return Operator("^", [self, other]) + return Operator("^", [self, other], src_loc_at=1) def __rxor__(self, other): return Operator("^", [other, self]) + @_overridable_by_reflected("__ror__") def __or__(self, other): - return Operator("|", [self, other]) + return Operator("|", [self, other], src_loc_at=1) def __ror__(self, other): return Operator("|", [other, self]) + @_overridable_by_reflected("__eq__") def __eq__(self, other): - return Operator("==", [self, other]) + return Operator("==", [self, other], src_loc_at=1) + @_overridable_by_reflected("__ne__") def __ne__(self, other): - return Operator("!=", [self, other]) + return Operator("!=", [self, other], src_loc_at=1) + @_overridable_by_reflected("__gt__") def __lt__(self, other): - return Operator("<", [self, other]) + return Operator("<", [self, other], src_loc_at=1) + @_overridable_by_reflected("__ge__") def __le__(self, other): - return Operator("<=", [self, other]) + return Operator("<=", [self, other], src_loc_at=1) + @_overridable_by_reflected("__lt__") def __gt__(self, other): - return Operator(">", [self, other]) + return Operator(">", [self, other], src_loc_at=1) + @_overridable_by_reflected("__le__") def __ge__(self, other): - return Operator(">=", [self, other]) + return Operator(">=", [self, other], src_loc_at=1) def __abs__(self): if self.shape().signed: diff --git a/docs/changes.rst b/docs/changes.rst index c982a5f..204918d 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -55,6 +55,7 @@ Implemented RFCs .. _RFC 19: https://amaranth-lang.org/rfcs/0019-remove-scheduler.html .. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html .. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html +.. _RFC 28: https://amaranth-lang.org/rfcs/0028-override-value-operators.html * `RFC 1`_: Aggregate data structure library @@ -71,6 +72,7 @@ Implemented RFCs * `RFC 15`_: Lifting shape-castable objects * `RFC 20`_: Deprecate non-FWFT FIFOs * `RFC 22`_: Define ``ValueCastable.shape()`` +* `RFC 28`_: Allow overriding ``Value`` operators Language changes diff --git a/tests/test_hdl_rec.py b/tests/test_hdl_rec.py index 12e1af8..4bf3592 100644 --- a/tests/test_hdl_rec.py +++ b/tests/test_hdl_rec.py @@ -285,22 +285,22 @@ class RecordTestCase(FHDLTestCase): # __eq__, __ne__, __lt__, __le__, __gt__, __ge__ self.assertEqual(repr(r1 == 1), "(== (cat (sig r1__a)) (const 1'd1))") self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))") - self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 == r1), "(== (cat (sig r1__a)) (sig s1))") self.assertEqual(repr(r1 != 1), "(!= (cat (sig r1__a)) (const 1'd1))") self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))") - self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 != r1), "(!= (cat (sig r1__a)) (sig s1))") self.assertEqual(repr(r1 < 1), "(< (cat (sig r1__a)) (const 1'd1))") self.assertEqual(repr(r1 < s1), "(< (cat (sig r1__a)) (sig s1))") - self.assertEqual(repr(s1 < r1), "(< (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 < r1), "(> (cat (sig r1__a)) (sig s1))") self.assertEqual(repr(r1 <= 1), "(<= (cat (sig r1__a)) (const 1'd1))") self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))") - self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 <= r1), "(>= (cat (sig r1__a)) (sig s1))") self.assertEqual(repr(r1 > 1), "(> (cat (sig r1__a)) (const 1'd1))") self.assertEqual(repr(r1 > s1), "(> (cat (sig r1__a)) (sig s1))") - self.assertEqual(repr(s1 > r1), "(> (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 > r1), "(< (cat (sig r1__a)) (sig s1))") self.assertEqual(repr(r1 >= 1), "(>= (cat (sig r1__a)) (const 1'd1))") self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))") - self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 >= r1), "(<= (cat (sig r1__a)) (sig s1))") # __abs__, __len__ self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))")