ast: allow overriding Value operators.
This commit is contained in:
		
							parent
							
								
									1c3227d956
								
							
						
					
					
						commit
						879601380d
					
				|  | @ -156,6 +156,25 @@ def signed(width): | ||||||
|     return Shape(width, signed=True) |     return Shape(width, signed=True) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def _overridable_by_reflected(method_name): | ||||||
|  |     """Allow overriding the decorated method. | ||||||
|  | 
 | ||||||
|  |     Allows :class:`ValueCastable` to override the decorated method by implementing | ||||||
|  |     a reflected method named ``method_name``. Intended for operators, but | ||||||
|  |     also usable for other methods that have a reflected counterpart. | ||||||
|  |     """ | ||||||
|  |     def decorator(f): | ||||||
|  |         @functools.wraps(f) | ||||||
|  |         def wrapper(self, other): | ||||||
|  |             if isinstance(other, ValueCastable) and hasattr(other, method_name): | ||||||
|  |                 res = getattr(other, method_name)(self) | ||||||
|  |                 if res is not NotImplemented: | ||||||
|  |                     return res | ||||||
|  |             return f(self, other) | ||||||
|  |         return wrapper | ||||||
|  |     return decorator | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class Value(metaclass=ABCMeta): | class Value(metaclass=ABCMeta): | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def cast(obj): |     def cast(obj): | ||||||
|  | @ -195,26 +214,31 @@ class Value(metaclass=ABCMeta): | ||||||
|     def __neg__(self): |     def __neg__(self): | ||||||
|         return Operator("-", [self]) |         return Operator("-", [self]) | ||||||
| 
 | 
 | ||||||
|  |     @_overridable_by_reflected("__radd__") | ||||||
|     def __add__(self, other): |     def __add__(self, other): | ||||||
|         return Operator("+", [self, other]) |         return Operator("+", [self, other], src_loc_at=1) | ||||||
|     def __radd__(self, other): |     def __radd__(self, other): | ||||||
|         return Operator("+", [other, self]) |         return Operator("+", [other, self]) | ||||||
|  |     @_overridable_by_reflected("__rsub__") | ||||||
|     def __sub__(self, other): |     def __sub__(self, other): | ||||||
|         return Operator("-", [self, other]) |         return Operator("-", [self, other], src_loc_at=1) | ||||||
|     def __rsub__(self, other): |     def __rsub__(self, other): | ||||||
|         return Operator("-", [other, self]) |         return Operator("-", [other, self]) | ||||||
| 
 | 
 | ||||||
|  |     @_overridable_by_reflected("__rmul__") | ||||||
|     def __mul__(self, other): |     def __mul__(self, other): | ||||||
|         return Operator("*", [self, other]) |         return Operator("*", [self, other], src_loc_at=1) | ||||||
|     def __rmul__(self, other): |     def __rmul__(self, other): | ||||||
|         return Operator("*", [other, self]) |         return Operator("*", [other, self]) | ||||||
| 
 | 
 | ||||||
|  |     @_overridable_by_reflected("__rmod__") | ||||||
|     def __mod__(self, other): |     def __mod__(self, other): | ||||||
|         return Operator("%", [self, other]) |         return Operator("%", [self, other], src_loc_at=1) | ||||||
|     def __rmod__(self, other): |     def __rmod__(self, other): | ||||||
|         return Operator("%", [other, self]) |         return Operator("%", [other, self]) | ||||||
|  |     @_overridable_by_reflected("__rfloordiv__") | ||||||
|     def __floordiv__(self, other): |     def __floordiv__(self, other): | ||||||
|         return Operator("//", [self, other]) |         return Operator("//", [self, other], src_loc_at=1) | ||||||
|     def __rfloordiv__(self, other): |     def __rfloordiv__(self, other): | ||||||
|         return Operator("//", [other, self]) |         return Operator("//", [other, self]) | ||||||
| 
 | 
 | ||||||
|  | @ -224,46 +248,57 @@ class Value(metaclass=ABCMeta): | ||||||
|             # by a signed value to make sure the shift amount can always be interpreted as |             # by a signed value to make sure the shift amount can always be interpreted as | ||||||
|             # an unsigned value. |             # an unsigned value. | ||||||
|             raise TypeError("Shift amount must be unsigned") |             raise TypeError("Shift amount must be unsigned") | ||||||
|  |     @_overridable_by_reflected("__rlshift__") | ||||||
|     def __lshift__(self, other): |     def __lshift__(self, other): | ||||||
|         other = Value.cast(other) |         other = Value.cast(other) | ||||||
|         other.__check_shamt() |         other.__check_shamt() | ||||||
|         return Operator("<<", [self, other]) |         return Operator("<<", [self, other], src_loc_at=1) | ||||||
|     def __rlshift__(self, other): |     def __rlshift__(self, other): | ||||||
|         self.__check_shamt() |         self.__check_shamt() | ||||||
|         return Operator("<<", [other, self]) |         return Operator("<<", [other, self]) | ||||||
|  |     @_overridable_by_reflected("__rrshift__") | ||||||
|     def __rshift__(self, other): |     def __rshift__(self, other): | ||||||
|         other = Value.cast(other) |         other = Value.cast(other) | ||||||
|         other.__check_shamt() |         other.__check_shamt() | ||||||
|         return Operator(">>", [self, other]) |         return Operator(">>", [self, other], src_loc_at=1) | ||||||
|     def __rrshift__(self, other): |     def __rrshift__(self, other): | ||||||
|         self.__check_shamt() |         self.__check_shamt() | ||||||
|         return Operator(">>", [other, self]) |         return Operator(">>", [other, self]) | ||||||
| 
 | 
 | ||||||
|  |     @_overridable_by_reflected("__rand__") | ||||||
|     def __and__(self, other): |     def __and__(self, other): | ||||||
|         return Operator("&", [self, other]) |         return Operator("&", [self, other], src_loc_at=1) | ||||||
|     def __rand__(self, other): |     def __rand__(self, other): | ||||||
|         return Operator("&", [other, self]) |         return Operator("&", [other, self]) | ||||||
|  |     @_overridable_by_reflected("__rxor__") | ||||||
|     def __xor__(self, other): |     def __xor__(self, other): | ||||||
|         return Operator("^", [self, other]) |         return Operator("^", [self, other], src_loc_at=1) | ||||||
|     def __rxor__(self, other): |     def __rxor__(self, other): | ||||||
|         return Operator("^", [other, self]) |         return Operator("^", [other, self]) | ||||||
|  |     @_overridable_by_reflected("__ror__") | ||||||
|     def __or__(self, other): |     def __or__(self, other): | ||||||
|         return Operator("|", [self, other]) |         return Operator("|", [self, other], src_loc_at=1) | ||||||
|     def __ror__(self, other): |     def __ror__(self, other): | ||||||
|         return Operator("|", [other, self]) |         return Operator("|", [other, self]) | ||||||
| 
 | 
 | ||||||
|  |     @_overridable_by_reflected("__eq__") | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         return Operator("==", [self, other]) |         return Operator("==", [self, other], src_loc_at=1) | ||||||
|  |     @_overridable_by_reflected("__ne__") | ||||||
|     def __ne__(self, other): |     def __ne__(self, other): | ||||||
|         return Operator("!=", [self, other]) |         return Operator("!=", [self, other], src_loc_at=1) | ||||||
|  |     @_overridable_by_reflected("__gt__") | ||||||
|     def __lt__(self, other): |     def __lt__(self, other): | ||||||
|         return Operator("<", [self, other]) |         return Operator("<", [self, other], src_loc_at=1) | ||||||
|  |     @_overridable_by_reflected("__ge__") | ||||||
|     def __le__(self, other): |     def __le__(self, other): | ||||||
|         return Operator("<=", [self, other]) |         return Operator("<=", [self, other], src_loc_at=1) | ||||||
|  |     @_overridable_by_reflected("__lt__") | ||||||
|     def __gt__(self, other): |     def __gt__(self, other): | ||||||
|         return Operator(">", [self, other]) |         return Operator(">", [self, other], src_loc_at=1) | ||||||
|  |     @_overridable_by_reflected("__le__") | ||||||
|     def __ge__(self, other): |     def __ge__(self, other): | ||||||
|         return Operator(">=", [self, other]) |         return Operator(">=", [self, other], src_loc_at=1) | ||||||
| 
 | 
 | ||||||
|     def __abs__(self): |     def __abs__(self): | ||||||
|         if self.shape().signed: |         if self.shape().signed: | ||||||
|  |  | ||||||
|  | @ -55,6 +55,7 @@ Implemented RFCs | ||||||
| .. _RFC 19: https://amaranth-lang.org/rfcs/0019-remove-scheduler.html | .. _RFC 19: https://amaranth-lang.org/rfcs/0019-remove-scheduler.html | ||||||
| .. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html | .. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html | ||||||
| .. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html | .. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html | ||||||
|  | .. _RFC 28: https://amaranth-lang.org/rfcs/0028-override-value-operators.html | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| * `RFC 1`_: Aggregate data structure library | * `RFC 1`_: Aggregate data structure library | ||||||
|  | @ -71,6 +72,7 @@ Implemented RFCs | ||||||
| * `RFC 15`_: Lifting shape-castable objects | * `RFC 15`_: Lifting shape-castable objects | ||||||
| * `RFC 20`_: Deprecate non-FWFT FIFOs | * `RFC 20`_: Deprecate non-FWFT FIFOs | ||||||
| * `RFC 22`_: Define ``ValueCastable.shape()`` | * `RFC 22`_: Define ``ValueCastable.shape()`` | ||||||
|  | * `RFC 28`_: Allow overriding ``Value`` operators | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| Language changes | Language changes | ||||||
|  |  | ||||||
|  | @ -285,22 +285,22 @@ class RecordTestCase(FHDLTestCase): | ||||||
|         # __eq__, __ne__, __lt__, __le__, __gt__, __ge__ |         # __eq__, __ne__, __lt__, __le__, __gt__, __ge__ | ||||||
|         self.assertEqual(repr(r1 == 1),  "(== (cat (sig r1__a)) (const 1'd1))") |         self.assertEqual(repr(r1 == 1),  "(== (cat (sig r1__a)) (const 1'd1))") | ||||||
|         self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))") |         self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))") |         self.assertEqual(repr(s1 == r1), "(== (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(r1 != 1),  "(!= (cat (sig r1__a)) (const 1'd1))") |         self.assertEqual(repr(r1 != 1),  "(!= (cat (sig r1__a)) (const 1'd1))") | ||||||
|         self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))") |         self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))") |         self.assertEqual(repr(s1 != r1), "(!= (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(r1 < 1),   "(< (cat (sig r1__a)) (const 1'd1))") |         self.assertEqual(repr(r1 < 1),   "(< (cat (sig r1__a)) (const 1'd1))") | ||||||
|         self.assertEqual(repr(r1 < s1),  "(< (cat (sig r1__a)) (sig s1))") |         self.assertEqual(repr(r1 < s1),  "(< (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(s1 < r1),  "(< (sig s1) (cat (sig r1__a)))") |         self.assertEqual(repr(s1 < r1),  "(> (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(r1 <= 1),  "(<= (cat (sig r1__a)) (const 1'd1))") |         self.assertEqual(repr(r1 <= 1),  "(<= (cat (sig r1__a)) (const 1'd1))") | ||||||
|         self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))") |         self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))") |         self.assertEqual(repr(s1 <= r1), "(>= (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(r1 > 1),   "(> (cat (sig r1__a)) (const 1'd1))") |         self.assertEqual(repr(r1 > 1),   "(> (cat (sig r1__a)) (const 1'd1))") | ||||||
|         self.assertEqual(repr(r1 > s1),  "(> (cat (sig r1__a)) (sig s1))") |         self.assertEqual(repr(r1 > s1),  "(> (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(s1 > r1),  "(> (sig s1) (cat (sig r1__a)))") |         self.assertEqual(repr(s1 > r1),  "(< (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(r1 >= 1),  "(>= (cat (sig r1__a)) (const 1'd1))") |         self.assertEqual(repr(r1 >= 1),  "(>= (cat (sig r1__a)) (const 1'd1))") | ||||||
|         self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))") |         self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))") | ||||||
|         self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))") |         self.assertEqual(repr(s1 >= r1), "(<= (cat (sig r1__a)) (sig s1))") | ||||||
| 
 | 
 | ||||||
|         # __abs__, __len__ |         # __abs__, __len__ | ||||||
|         self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))") |         self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))") | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue
	
	 Vegard Storheil Eriksen
						Vegard Storheil Eriksen