ast: allow overriding Value operators.
This commit is contained in:
parent
1c3227d956
commit
879601380d
|
@ -156,6 +156,25 @@ def signed(width):
|
|||
return Shape(width, signed=True)
|
||||
|
||||
|
||||
def _overridable_by_reflected(method_name):
|
||||
"""Allow overriding the decorated method.
|
||||
|
||||
Allows :class:`ValueCastable` to override the decorated method by implementing
|
||||
a reflected method named ``method_name``. Intended for operators, but
|
||||
also usable for other methods that have a reflected counterpart.
|
||||
"""
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(self, other):
|
||||
if isinstance(other, ValueCastable) and hasattr(other, method_name):
|
||||
res = getattr(other, method_name)(self)
|
||||
if res is not NotImplemented:
|
||||
return res
|
||||
return f(self, other)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class Value(metaclass=ABCMeta):
|
||||
@staticmethod
|
||||
def cast(obj):
|
||||
|
@ -195,26 +214,31 @@ class Value(metaclass=ABCMeta):
|
|||
def __neg__(self):
|
||||
return Operator("-", [self])
|
||||
|
||||
@_overridable_by_reflected("__radd__")
|
||||
def __add__(self, other):
|
||||
return Operator("+", [self, other])
|
||||
return Operator("+", [self, other], src_loc_at=1)
|
||||
def __radd__(self, other):
|
||||
return Operator("+", [other, self])
|
||||
@_overridable_by_reflected("__rsub__")
|
||||
def __sub__(self, other):
|
||||
return Operator("-", [self, other])
|
||||
return Operator("-", [self, other], src_loc_at=1)
|
||||
def __rsub__(self, other):
|
||||
return Operator("-", [other, self])
|
||||
|
||||
@_overridable_by_reflected("__rmul__")
|
||||
def __mul__(self, other):
|
||||
return Operator("*", [self, other])
|
||||
return Operator("*", [self, other], src_loc_at=1)
|
||||
def __rmul__(self, other):
|
||||
return Operator("*", [other, self])
|
||||
|
||||
@_overridable_by_reflected("__rmod__")
|
||||
def __mod__(self, other):
|
||||
return Operator("%", [self, other])
|
||||
return Operator("%", [self, other], src_loc_at=1)
|
||||
def __rmod__(self, other):
|
||||
return Operator("%", [other, self])
|
||||
@_overridable_by_reflected("__rfloordiv__")
|
||||
def __floordiv__(self, other):
|
||||
return Operator("//", [self, other])
|
||||
return Operator("//", [self, other], src_loc_at=1)
|
||||
def __rfloordiv__(self, other):
|
||||
return Operator("//", [other, self])
|
||||
|
||||
|
@ -224,46 +248,57 @@ class Value(metaclass=ABCMeta):
|
|||
# by a signed value to make sure the shift amount can always be interpreted as
|
||||
# an unsigned value.
|
||||
raise TypeError("Shift amount must be unsigned")
|
||||
@_overridable_by_reflected("__rlshift__")
|
||||
def __lshift__(self, other):
|
||||
other = Value.cast(other)
|
||||
other.__check_shamt()
|
||||
return Operator("<<", [self, other])
|
||||
return Operator("<<", [self, other], src_loc_at=1)
|
||||
def __rlshift__(self, other):
|
||||
self.__check_shamt()
|
||||
return Operator("<<", [other, self])
|
||||
@_overridable_by_reflected("__rrshift__")
|
||||
def __rshift__(self, other):
|
||||
other = Value.cast(other)
|
||||
other.__check_shamt()
|
||||
return Operator(">>", [self, other])
|
||||
return Operator(">>", [self, other], src_loc_at=1)
|
||||
def __rrshift__(self, other):
|
||||
self.__check_shamt()
|
||||
return Operator(">>", [other, self])
|
||||
|
||||
@_overridable_by_reflected("__rand__")
|
||||
def __and__(self, other):
|
||||
return Operator("&", [self, other])
|
||||
return Operator("&", [self, other], src_loc_at=1)
|
||||
def __rand__(self, other):
|
||||
return Operator("&", [other, self])
|
||||
@_overridable_by_reflected("__rxor__")
|
||||
def __xor__(self, other):
|
||||
return Operator("^", [self, other])
|
||||
return Operator("^", [self, other], src_loc_at=1)
|
||||
def __rxor__(self, other):
|
||||
return Operator("^", [other, self])
|
||||
@_overridable_by_reflected("__ror__")
|
||||
def __or__(self, other):
|
||||
return Operator("|", [self, other])
|
||||
return Operator("|", [self, other], src_loc_at=1)
|
||||
def __ror__(self, other):
|
||||
return Operator("|", [other, self])
|
||||
|
||||
@_overridable_by_reflected("__eq__")
|
||||
def __eq__(self, other):
|
||||
return Operator("==", [self, other])
|
||||
return Operator("==", [self, other], src_loc_at=1)
|
||||
@_overridable_by_reflected("__ne__")
|
||||
def __ne__(self, other):
|
||||
return Operator("!=", [self, other])
|
||||
return Operator("!=", [self, other], src_loc_at=1)
|
||||
@_overridable_by_reflected("__gt__")
|
||||
def __lt__(self, other):
|
||||
return Operator("<", [self, other])
|
||||
return Operator("<", [self, other], src_loc_at=1)
|
||||
@_overridable_by_reflected("__ge__")
|
||||
def __le__(self, other):
|
||||
return Operator("<=", [self, other])
|
||||
return Operator("<=", [self, other], src_loc_at=1)
|
||||
@_overridable_by_reflected("__lt__")
|
||||
def __gt__(self, other):
|
||||
return Operator(">", [self, other])
|
||||
return Operator(">", [self, other], src_loc_at=1)
|
||||
@_overridable_by_reflected("__le__")
|
||||
def __ge__(self, other):
|
||||
return Operator(">=", [self, other])
|
||||
return Operator(">=", [self, other], src_loc_at=1)
|
||||
|
||||
def __abs__(self):
|
||||
if self.shape().signed:
|
||||
|
|
|
@ -55,6 +55,7 @@ Implemented RFCs
|
|||
.. _RFC 19: https://amaranth-lang.org/rfcs/0019-remove-scheduler.html
|
||||
.. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html
|
||||
.. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html
|
||||
.. _RFC 28: https://amaranth-lang.org/rfcs/0028-override-value-operators.html
|
||||
|
||||
|
||||
* `RFC 1`_: Aggregate data structure library
|
||||
|
@ -71,6 +72,7 @@ Implemented RFCs
|
|||
* `RFC 15`_: Lifting shape-castable objects
|
||||
* `RFC 20`_: Deprecate non-FWFT FIFOs
|
||||
* `RFC 22`_: Define ``ValueCastable.shape()``
|
||||
* `RFC 28`_: Allow overriding ``Value`` operators
|
||||
|
||||
|
||||
Language changes
|
||||
|
|
|
@ -285,22 +285,22 @@ class RecordTestCase(FHDLTestCase):
|
|||
# __eq__, __ne__, __lt__, __le__, __gt__, __ge__
|
||||
self.assertEqual(repr(r1 == 1), "(== (cat (sig r1__a)) (const 1'd1))")
|
||||
self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))")
|
||||
self.assertEqual(repr(s1 == r1), "(== (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(r1 != 1), "(!= (cat (sig r1__a)) (const 1'd1))")
|
||||
self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))")
|
||||
self.assertEqual(repr(s1 != r1), "(!= (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(r1 < 1), "(< (cat (sig r1__a)) (const 1'd1))")
|
||||
self.assertEqual(repr(r1 < s1), "(< (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(s1 < r1), "(< (sig s1) (cat (sig r1__a)))")
|
||||
self.assertEqual(repr(s1 < r1), "(> (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(r1 <= 1), "(<= (cat (sig r1__a)) (const 1'd1))")
|
||||
self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))")
|
||||
self.assertEqual(repr(s1 <= r1), "(>= (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(r1 > 1), "(> (cat (sig r1__a)) (const 1'd1))")
|
||||
self.assertEqual(repr(r1 > s1), "(> (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(s1 > r1), "(> (sig s1) (cat (sig r1__a)))")
|
||||
self.assertEqual(repr(s1 > r1), "(< (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(r1 >= 1), "(>= (cat (sig r1__a)) (const 1'd1))")
|
||||
self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))")
|
||||
self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))")
|
||||
self.assertEqual(repr(s1 >= r1), "(<= (cat (sig r1__a)) (sig s1))")
|
||||
|
||||
# __abs__, __len__
|
||||
self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))")
|
||||
|
|
Loading…
Reference in a new issue