fhdl.ast: refactor Operator.shape(). NFC.
This commit is contained in:
		
							parent
							
								
									3a8685c352
								
							
						
					
					
						commit
						46f5addf05
					
				|  | @ -259,61 +259,71 @@ class Operator(Value): | ||||||
|         self.operands = [Value.wrap(o) for o in operands] |         self.operands = [Value.wrap(o) for o in operands] | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _bitwise_binary_shape(a, b): |     def _bitwise_binary_shape(a_shape, b_shape): | ||||||
|         if not a[1] and not b[1]: |         a_bits, a_sign = a_shape | ||||||
|  |         b_bits, b_sign = b_shape | ||||||
|  |         if not a_sign and not b_sign: | ||||||
|             # both operands unsigned |             # both operands unsigned | ||||||
|             return max(a[0], b[0]), False |             return max(a_bits, b_bits), False | ||||||
|         elif a[1] and b[1]: |         elif a_sign and b_sign: | ||||||
|             # both operands signed |             # both operands signed | ||||||
|             return max(a[0], b[0]), True |             return max(a_bits, b_bits), True | ||||||
|         elif not a[1] and b[1]: |         elif not a_sign and b_sign: | ||||||
|             # first operand unsigned (add sign bit), second operand signed |             # first operand unsigned (add sign bit), second operand signed | ||||||
|             return max(a[0] + 1, b[0]), True |             return max(a_bits + 1, b_bits), True | ||||||
|         else: |         else: | ||||||
|             # first signed, second operand unsigned (add sign bit) |             # first signed, second operand unsigned (add sign bit) | ||||||
|             return max(a[0], b[0] + 1), True |             return max(a_bits, b_bits + 1), True | ||||||
| 
 | 
 | ||||||
|     def shape(self): |     def shape(self): | ||||||
|         obs = list(map(lambda x: x.shape(), self.operands)) |         op_shapes = list(map(lambda x: x.shape(), self.operands)) | ||||||
|         if self.op == "+" or self.op == "-": |         if len(op_shapes) == 1: | ||||||
|             if len(obs) == 1: |             (a_bits, a_sign), = op_shapes | ||||||
|                 if self.op == "-" and not obs[0][1]: |             if self.op in ("+", "~"): | ||||||
|                     return obs[0][0] + 1, True |                 return a_bits, a_sign | ||||||
|  |             if self.op == "-": | ||||||
|  |                 if not a_sign: | ||||||
|  |                     return a_bits + 1, True | ||||||
|                 else: |                 else: | ||||||
|                     return obs[0] |                     return a_bits, a_sign | ||||||
|             n, s = self._bitwise_binary_shape(*obs) |             if self.op == "b": | ||||||
|             return n + 1, s |                 return 1, False | ||||||
|         elif self.op == "*": |         elif len(op_shapes) == 2: | ||||||
|             if not obs[0][1] and not obs[1][1]: |             (a_bits, a_sign), (b_bits, b_sign) = op_shapes | ||||||
|                 # both operands unsigned |             if self.op == "+" or self.op == "-": | ||||||
|                 return obs[0][0] + obs[1][0], False |                 bits, sign = self._bitwise_binary_shape(*op_shapes) | ||||||
|             elif obs[0][1] and obs[1][1]: |                 return bits + 1, sign | ||||||
|                 # both operands signed |             if self.op == "*": | ||||||
|                 return obs[0][0] + obs[1][0] - 1, True |                 if not a_sign and not b_sign: | ||||||
|             else: |                     # both operands unsigned | ||||||
|  |                     return a_bits + b_bits, False | ||||||
|  |                 if a_sign and b_sign: | ||||||
|  |                     # both operands signed | ||||||
|  |                     return a_bits + b_bits - 1, True | ||||||
|                 # one operand signed, the other unsigned (add sign bit) |                 # one operand signed, the other unsigned (add sign bit) | ||||||
|                 return obs[0][0] + obs[1][0] + 1 - 1, True |                 return a_bits + b_bits + 1 - 1, True | ||||||
|         elif self.op == "<<<": |             if self.op in ("<", "<=", "==", "!=", ">", ">=", "b"): | ||||||
|             if obs[1][1]: |                 return 1, False | ||||||
|                 extra = 2**(obs[1][0] - 1) - 1 |             if self.op in ("&", "^", "|"): | ||||||
|             else: |                 return self._bitwise_binary_shape(*op_shapes) | ||||||
|                 extra = 2**obs[1][0] - 1 |             if self.op == "<<<": | ||||||
|             return obs[0][0] + extra, obs[0][1] |                 if b_sign: | ||||||
|         elif self.op == ">>>": |                     extra = 2**(b_bits - 1) - 1 | ||||||
|             if obs[1][1]: |                 else: | ||||||
|                 extra = 2**(obs[1][0] - 1) |                     extra = 2**b_bits - 1 | ||||||
|             else: |                 return a_bits + extra, a_sign | ||||||
|                 extra = 0 |             if self.op == ">>>": | ||||||
|             return obs[0][0] + extra, obs[0][1] |                 if b_sign: | ||||||
|         elif self.op in ("&", "^", "|"): |                     extra = 2**(b_bits - 1) | ||||||
|             return self._bitwise_binary_shape(*obs) |                 else: | ||||||
|         elif self.op in ("<", "<=", "==", "!=", ">", ">=", "b"): |                     extra = 0 | ||||||
|             return 1, False |                 return a_bits + extra, a_sign | ||||||
|         elif self.op == "~": |         elif len(op_shapes) == 3: | ||||||
|             return obs[0] |             if self.op == "m": | ||||||
|         elif self.op == "m": |                 s_shape, a_shape, b_shape = op_shapes | ||||||
|             return self._bitwise_binary_shape(obs[1], obs[2]) |                 return self._bitwise_binary_shape(a_shape, b_shape) | ||||||
|         raise NotImplementedError("Operator '{}' not implemented".format(self.op)) # :nocov: |         raise NotImplementedError("Operator {}/{} not implemented" | ||||||
|  |                                   .format(self.op, len(op_shapes))) # :nocov: | ||||||
| 
 | 
 | ||||||
|     def _rhs_signals(self): |     def _rhs_signals(self): | ||||||
|         return union(op._rhs_signals() for op in self.operands) |         return union(op._rhs_signals() for op in self.operands) | ||||||
|  |  | ||||||
|  | @ -88,6 +88,11 @@ class ConstTestCase(FHDLTestCase): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class OperatorTestCase(FHDLTestCase): | class OperatorTestCase(FHDLTestCase): | ||||||
|  |     def test_bool(self): | ||||||
|  |         v = Const(0, 4).bool() | ||||||
|  |         self.assertEqual(repr(v), "(b (const 4'd0))") | ||||||
|  |         self.assertEqual(v.shape(), (1, False)) | ||||||
|  | 
 | ||||||
|     def test_invert(self): |     def test_invert(self): | ||||||
|         v = ~Const(0, 4) |         v = ~Const(0, 4) | ||||||
|         self.assertEqual(repr(v), "(~ (const 4'd0))") |         self.assertEqual(repr(v), "(~ (const 4'd0))") | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue
	
	 whitequark
						whitequark