hdl.ast: add an explicit Shape class, included in prelude.

Shapes have long been a part of nMigen, but represented using tuples.
This commit adds a Shape class (using namedtuple for backwards
compatibility), and accepts anything castable to Shape (including
enums, ranges, etc) anywhere a tuple was accepted previously.

In addition, `signed(n)` and `unsigned(n)` are added as aliases for
`Shape(n, signed=True)` and `Shape(n, signed=False)`, transforming
code such as `Signal((8, True))` to `Signal(signed(8))`.
These aliases are also included in prelude.

Preparation for #225.
This commit is contained in:
whitequark 2019-10-11 12:52:41 +00:00
parent db960e7c30
commit 6aabdc0a73
5 changed files with 236 additions and 91 deletions

View file

@ -1,3 +1,4 @@
from .ast import Shape, unsigned, signed
from .ast import Value, Const, C, Mux, Cat, Repl, Array, Signal, ClockSignal, ResetSignal from .ast import Value, Const, C, Mux, Cat, Repl, Array, Signal, ClockSignal, ResetSignal
from .dsl import Module from .dsl import Module
from .cd import ClockDomain from .cd import ClockDomain

View file

@ -2,6 +2,7 @@ from abc import ABCMeta, abstractmethod
import builtins import builtins
import traceback import traceback
import warnings import warnings
import typing
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence
from enum import Enum from enum import Enum
@ -11,6 +12,7 @@ from ..tools import *
__all__ = [ __all__ = [
"Shape", "signed", "unsigned",
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
"Array", "ArrayProxy", "Array", "ArrayProxy",
"Signal", "ClockSignal", "ResetSignal", "Signal", "ClockSignal", "ResetSignal",
@ -23,21 +25,65 @@ __all__ = [
class DUID: class DUID:
"""Deterministic Unique IDentifier""" """Deterministic Unique IDentifier."""
__next_uid = 0 __next_uid = 0
def __init__(self): def __init__(self):
self.duid = DUID.__next_uid self.duid = DUID.__next_uid
DUID.__next_uid += 1 DUID.__next_uid += 1
def _enum_shape(enum_type): class Shape(typing.NamedTuple):
min_value = min(member.value for member in enum_type) """Bit width and signedness of a value.
max_value = max(member.value for member in enum_type)
if not isinstance(min_value, int) or not isinstance(max_value, int): Attributes
raise TypeError("Only enumerations with integer values can be converted to nMigen values") ----------
signed = min_value < 0 or max_value < 0 width : int
width = max(bits_for(min_value, signed), bits_for(max_value, signed)) The number of bits in the representation, including the sign bit (if any).
return (width, signed) signed : bool
If ``False``, the value is unsigned. If ``True``, the value is signed two's complement.
"""
width: int = 1
signed: bool = False
@staticmethod
def cast(obj):
if isinstance(obj, int):
return Shape(obj)
if isinstance(obj, tuple):
return Shape(*obj)
if isinstance(obj, range):
if len(obj) == 0:
return Shape(0, obj.start < 0)
signed = obj.start < 0 or (obj.stop - obj.step) < 0
width = max(bits_for(obj.start, signed),
bits_for(obj.stop - obj.step, signed))
return Shape(width, signed)
if isinstance(obj, type) and issubclass(obj, Enum):
min_value = min(member.value for member in obj)
max_value = max(member.value for member in obj)
if not isinstance(min_value, int) or not isinstance(max_value, int):
raise TypeError("Only enumerations with integer values can be used "
"as value shapes")
signed = min_value < 0 or max_value < 0
width = max(bits_for(min_value, signed), bits_for(max_value, signed))
return Shape(width, signed)
raise TypeError("Object {!r} cannot be used as value shape".format(obj))
# TODO: use dataclasses instead of this hack
def _Shape___init__(self, width=1, signed=False):
if not isinstance(width, int) or width < 0:
raise TypeError("Width must be a non-negative integer, not {!r}"
.format(width))
Shape.__init__ = _Shape___init__
def unsigned(width):
return Shape(width, signed=False)
def signed(width):
return Shape(width, signed=True)
class Value(metaclass=ABCMeta): class Value(metaclass=ABCMeta):
@ -50,12 +96,11 @@ class Value(metaclass=ABCMeta):
""" """
if isinstance(obj, Value): if isinstance(obj, Value):
return obj return obj
elif isinstance(obj, (bool, int)): if isinstance(obj, int):
return Const(obj) return Const(obj)
elif isinstance(obj, Enum): if isinstance(obj, Enum):
return Const(obj.value, _enum_shape(type(obj))) return Const(obj.value, Shape.cast(type(obj)))
else: raise TypeError("Object {!r} cannot be converted to an nMigen value".format(obj))
raise TypeError("Object {!r} is not an nMigen value".format(obj))
# TODO(nmigen-0.2): remove this # TODO(nmigen-0.2): remove this
@classmethod @classmethod
@ -146,7 +191,7 @@ class Value(metaclass=ABCMeta):
return Operator(">=", [self, other]) return Operator(">=", [self, other])
def __len__(self): def __len__(self):
return self.shape()[0] return self.shape().width
def __getitem__(self, key): def __getitem__(self, key):
n = len(self) n = len(self)
@ -329,20 +374,19 @@ class Value(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def shape(self): def shape(self):
"""Bit length and signedness of a value. """Bit width and signedness of a value.
Returns Returns
------- -------
int, bool Shape
Number of bits required to store `v` or available in `v`, followed by See :class:`Shape`.
whether `v` has a sign bit (included in the bit count).
Examples Examples
-------- --------
>>> Value.shape(Signal(8)) >>> Signal(8).shape()
8, False Shape(width=8, signed=False)
>>> Value.shape(C(0xaa)) >>> Const(0xaa).shape()
8, False Shape(width=8, signed=False)
""" """
pass # :nocov: pass # :nocov:
@ -391,13 +435,12 @@ class Const(Value):
# We deliberately do not call Value.__init__ here. # We deliberately do not call Value.__init__ here.
self.value = int(value) self.value = int(value)
if shape is None: if shape is None:
shape = bits_for(self.value), self.value < 0 shape = Shape(bits_for(self.value), signed=self.value < 0)
if isinstance(shape, int): elif isinstance(shape, int):
shape = shape, self.value < 0 shape = Shape(shape, signed=self.value < 0)
else:
shape = Shape.cast(shape)
self.width, self.signed = shape self.width, self.signed = shape
if not isinstance(self.width, int) or self.width < 0:
raise TypeError("Width must be a non-negative integer, not {!r}"
.format(self.width))
self.value = self.normalize(self.value, shape) self.value = self.normalize(self.value, shape)
# TODO(nmigen-0.2): move this to nmigen.compat and make it a deprecated extension # TODO(nmigen-0.2): move this to nmigen.compat and make it a deprecated extension
@ -407,7 +450,7 @@ class Const(Value):
return self.width return self.width
def shape(self): def shape(self):
return self.width, self.signed return Shape(self.width, self.signed)
def _rhs_signals(self): def _rhs_signals(self):
return ValueSet() return ValueSet()
@ -425,15 +468,13 @@ C = Const # shorthand
class AnyValue(Value, DUID): class AnyValue(Value, DUID):
def __init__(self, shape, *, src_loc_at=0): def __init__(self, shape, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at) super().__init__(src_loc_at=src_loc_at)
if isinstance(shape, int): self.width, self.signed = Shape.cast(shape)
shape = shape, False
self.width, self.signed = shape
if not isinstance(self.width, int) or self.width < 0: if not isinstance(self.width, int) or self.width < 0:
raise TypeError("Width must be a non-negative integer, not {!r}" raise TypeError("Width must be a non-negative integer, not {!r}"
.format(self.width)) .format(self.width))
def shape(self): def shape(self):
return self.width, self.signed return Shape(self.width, self.signed)
def _rhs_signals(self): def _rhs_signals(self):
return ValueSet() return ValueSet()
@ -470,41 +511,41 @@ class Operator(Value):
b_bits, b_sign = b_shape b_bits, b_sign = b_shape
if not a_sign and not b_sign: if not a_sign and not b_sign:
# both operands unsigned # both operands unsigned
return max(a_bits, b_bits), False return Shape(max(a_bits, b_bits), False)
elif a_sign and b_sign: elif a_sign and b_sign:
# both operands signed # both operands signed
return max(a_bits, b_bits), True return Shape(max(a_bits, b_bits), True)
elif not a_sign and b_sign: elif not a_sign and b_sign:
# first operand unsigned (add sign bit), second operand signed # first operand unsigned (add sign bit), second operand signed
return max(a_bits + 1, b_bits), True return Shape(max(a_bits + 1, b_bits), True)
else: else:
# first signed, second operand unsigned (add sign bit) # first signed, second operand unsigned (add sign bit)
return max(a_bits, b_bits + 1), True return Shape(max(a_bits, b_bits + 1), True)
op_shapes = list(map(lambda x: x.shape(), self.operands)) op_shapes = list(map(lambda x: x.shape(), self.operands))
if len(op_shapes) == 1: if len(op_shapes) == 1:
(a_width, a_signed), = op_shapes (a_width, a_signed), = op_shapes
if self.operator in ("+", "~"): if self.operator in ("+", "~"):
return a_width, a_signed return Shape(a_width, a_signed)
if self.operator == "-": if self.operator == "-":
if not a_signed: if not a_signed:
return a_width + 1, True return Shape(a_width + 1, True)
else: else:
return a_width, a_signed return Shape(a_width, a_signed)
if self.operator in ("b", "r|", "r&", "r^"): if self.operator in ("b", "r|", "r&", "r^"):
return 1, False return Shape(1, False)
elif len(op_shapes) == 2: elif len(op_shapes) == 2:
(a_width, a_signed), (b_width, b_signed) = op_shapes (a_width, a_signed), (b_width, b_signed) = op_shapes
if self.operator == "+" or self.operator == "-": if self.operator in ("+", "-"):
width, signed = _bitwise_binary_shape(*op_shapes) width, signed = _bitwise_binary_shape(*op_shapes)
return width + 1, signed return Shape(width + 1, signed)
if self.operator == "*": if self.operator == "*":
return a_width + b_width, a_signed or b_signed return Shape(a_width + b_width, a_signed or b_signed)
if self.operator in ("//", "%"): if self.operator in ("//", "%"):
assert not b_signed assert not b_signed
return a_width, a_signed return Shape(a_width, a_signed)
if self.operator in ("<", "<=", "==", "!=", ">", ">="): if self.operator in ("<", "<=", "==", "!=", ">", ">="):
return 1, False return Shape(1, False)
if self.operator in ("&", "^", "|"): if self.operator in ("&", "^", "|"):
return _bitwise_binary_shape(*op_shapes) return _bitwise_binary_shape(*op_shapes)
if self.operator == "<<": if self.operator == "<<":
@ -512,13 +553,13 @@ class Operator(Value):
extra = 2 ** (b_width - 1) - 1 extra = 2 ** (b_width - 1) - 1
else: else:
extra = 2 ** (b_width) - 1 extra = 2 ** (b_width) - 1
return a_width + extra, a_signed return Shape(a_width + extra, a_signed)
if self.operator == ">>": if self.operator == ">>":
if b_signed: if b_signed:
extra = 2 ** (b_width - 1) extra = 2 ** (b_width - 1)
else: else:
extra = 0 extra = 0
return a_width + extra, a_signed return Shape(a_width + extra, a_signed)
elif len(op_shapes) == 3: elif len(op_shapes) == 3:
if self.operator == "m": if self.operator == "m":
s_shape, a_shape, b_shape = op_shapes s_shape, a_shape, b_shape = op_shapes
@ -581,7 +622,7 @@ class Slice(Value):
self.end = end self.end = end
def shape(self): def shape(self):
return self.end - self.start, False return Shape(self.end - self.start)
def _lhs_signals(self): def _lhs_signals(self):
return self.value._lhs_signals() return self.value._lhs_signals()
@ -608,7 +649,7 @@ class Part(Value):
self.stride = stride self.stride = stride
def shape(self): def shape(self):
return self.width, False return Shape(self.width)
def _lhs_signals(self): def _lhs_signals(self):
return self.value._lhs_signals() return self.value._lhs_signals()
@ -651,7 +692,7 @@ class Cat(Value):
self.parts = [Value.cast(v) for v in flatten(args)] self.parts = [Value.cast(v) for v in flatten(args)]
def shape(self): def shape(self):
return sum(len(part) for part in self.parts), False return Shape(sum(len(part) for part in self.parts))
def _lhs_signals(self): def _lhs_signals(self):
return union((part._lhs_signals() for part in self.parts), start=ValueSet()) return union((part._lhs_signals() for part in self.parts), start=ValueSet())
@ -701,7 +742,7 @@ class Repl(Value):
self.count = count self.count = count
def shape(self): def shape(self):
return len(self.value) * self.count, False return Shape(len(self.value) * self.count)
def _rhs_signals(self): def _rhs_signals(self):
return self.value._rhs_signals() return self.value._rhs_signals()
@ -792,13 +833,7 @@ class Signal(Value, DUID):
else: else:
if not (min is None and max is None): if not (min is None and max is None):
raise ValueError("Only one of bits/signedness or bounds may be specified") raise ValueError("Only one of bits/signedness or bounds may be specified")
if isinstance(shape, int): self.width, self.signed = Shape.cast(shape)
self.width, self.signed = shape, False
else:
self.width, self.signed = shape
if not isinstance(self.width, int) or self.width < 0:
raise TypeError("Width must be a non-negative integer, not {!r}".format(self.width))
reset_width = bits_for(reset, self.signed) reset_width = bits_for(reset, self.signed)
if reset != 0 and reset_width > self.width: if reset != 0 and reset_width > self.width:
@ -829,14 +864,7 @@ class Signal(Value, DUID):
That is, for any given ``range(*args)``, ``Signal.range(*args)`` can represent any That is, for any given ``range(*args)``, ``Signal.range(*args)`` can represent any
``x for x in range(*args)``. ``x for x in range(*args)``.
""" """
value_range = range(*args) return cls(Shape.cast(range(*args)), src_loc_at=1 + src_loc_at, **kwargs)
if len(value_range) > 0:
signed = value_range.start < 0 or (value_range.stop - value_range.step) < 0
else:
signed = value_range.start < 0
width = max(bits_for(value_range.start, signed),
bits_for(value_range.stop - value_range.step, signed))
return cls((width, signed), src_loc_at=1 + src_loc_at, **kwargs)
@classmethod @classmethod
def enum(cls, enum_type, *, src_loc_at=0, **kwargs): def enum(cls, enum_type, *, src_loc_at=0, **kwargs):
@ -849,7 +877,7 @@ class Signal(Value, DUID):
""" """
if not issubclass(enum_type, Enum): if not issubclass(enum_type, Enum):
raise TypeError("Type {!r} is not an enumeration") raise TypeError("Type {!r} is not an enumeration")
return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs) return cls(Shape.cast(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
@classmethod @classmethod
def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs): def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
@ -885,7 +913,7 @@ class Signal(Value, DUID):
self.width = value self.width = value
def shape(self): def shape(self):
return self.width, self.signed return Shape(self.width, self.signed)
def _lhs_signals(self): def _lhs_signals(self):
return ValueSet((self,)) return ValueSet((self,))
@ -919,7 +947,7 @@ class ClockSignal(Value):
self.domain = domain self.domain = domain
def shape(self): def shape(self):
return 1, False return Shape(1)
def _lhs_signals(self): def _lhs_signals(self):
return ValueSet((self,)) return ValueSet((self,))
@ -956,7 +984,7 @@ class ResetSignal(Value):
self.allow_reset_less = allow_reset_less self.allow_reset_less = allow_reset_less
def shape(self): def shape(self):
return 1, False return Shape(1)
def _lhs_signals(self): def _lhs_signals(self):
return ValueSet((self,)) return ValueSet((self,))
@ -1077,7 +1105,7 @@ class ArrayProxy(Value):
for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()): for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()):
width = max(width, elem_width + elem_signed) width = max(width, elem_width + elem_signed)
signed = max(signed, elem_signed) signed = max(signed, elem_signed)
return width, signed return Shape(width, signed)
def _lhs_signals(self): def _lhs_signals(self):
signals = union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet()) signals = union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet())
@ -1195,7 +1223,7 @@ class Initial(Value):
super().__init__(src_loc_at=1 + src_loc_at) super().__init__(src_loc_at=1 + src_loc_at)
def shape(self): def shape(self):
return (1, False) return Shape(1)
def _rhs_signals(self): def _rhs_signals(self):
return ValueSet((self,)) return ValueSet((self,))

View file

@ -5,7 +5,6 @@ from functools import reduce
from .. import tracer from .. import tracer
from ..tools import union from ..tools import union
from .ast import * from .ast import *
from .ast import _enum_shape
__all__ = ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"] __all__ = ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"]
@ -46,17 +45,16 @@ class Layout:
if not isinstance(name, str): if not isinstance(name, str):
raise TypeError("Field {!r} has invalid name: should be a string" raise TypeError("Field {!r} has invalid name: should be a string"
.format(field)) .format(field))
if isinstance(shape, type) and issubclass(shape, Enum): if not isinstance(shape, Layout):
shape = _enum_shape(shape) try:
if not isinstance(shape, (int, tuple, Layout)): shape = Shape.cast(shape)
raise TypeError("Field {!r} has invalid shape: should be an int, tuple, Enum, or " except Exception as error:
"list of fields of a nested record" raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
.format(field)) "or a list of fields of a nested record"
.format(field))
if name in self.fields: if name in self.fields:
raise NameError("Field {!r} has a name that is already present in the layout" raise NameError("Field {!r} has a name that is already present in the layout"
.format(field)) .format(field))
if isinstance(shape, int):
shape = (shape, False)
self.fields[name] = (shape, direction) self.fields[name] = (shape, direction)
def __getitem__(self, item): def __getitem__(self, item):
@ -159,7 +157,7 @@ class Record(Value):
return super().__getitem__(item) return super().__getitem__(item)
def shape(self): def shape(self):
return sum(len(f) for f in self.fields.values()), False return Shape(sum(len(f) for f in self.fields.values()))
def _lhs_signals(self): def _lhs_signals(self):
return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet()) return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

View file

@ -22,6 +22,105 @@ class StringEnum(Enum):
BAR = "b" BAR = "b"
class ShapeTestCase(FHDLTestCase):
def test_make(self):
s1 = Shape()
self.assertEqual(s1.width, 1)
self.assertEqual(s1.signed, False)
s2 = Shape(signed=True)
self.assertEqual(s2.width, 1)
self.assertEqual(s2.signed, True)
s3 = Shape(3, True)
self.assertEqual(s3.width, 3)
self.assertEqual(s3.signed, True)
def test_make_wrong(self):
with self.assertRaises(TypeError,
msg="Width must be a non-negative integer, not -1"):
Shape(-1)
def test_tuple(self):
width, signed = Shape()
self.assertEqual(width, 1)
self.assertEqual(signed, False)
def test_unsigned(self):
s1 = unsigned(2)
self.assertIsInstance(s1, Shape)
self.assertEqual(s1.width, 2)
self.assertEqual(s1.signed, False)
def test_signed(self):
s1 = signed(2)
self.assertIsInstance(s1, Shape)
self.assertEqual(s1.width, 2)
self.assertEqual(s1.signed, True)
def test_cast_int(self):
s1 = Shape.cast(2)
self.assertEqual(s1.width, 2)
self.assertEqual(s1.signed, False)
def test_cast_int_wrong(self):
with self.assertRaises(TypeError,
msg="Width must be a non-negative integer, not -1"):
Shape.cast(-1)
def test_cast_tuple(self):
s1 = Shape.cast((1, False))
self.assertEqual(s1.width, 1)
self.assertEqual(s1.signed, False)
s2 = Shape.cast((3, True))
self.assertEqual(s2.width, 3)
self.assertEqual(s2.signed, True)
def test_cast_tuple_wrong(self):
with self.assertRaises(TypeError,
msg="Width must be a non-negative integer, not -1"):
Shape.cast((-1, True))
def test_cast_range(self):
s1 = Shape.cast(range(0, 8))
self.assertEqual(s1.width, 3)
self.assertEqual(s1.signed, False)
s2 = Shape.cast(range(0, 9))
self.assertEqual(s2.width, 4)
self.assertEqual(s2.signed, False)
s3 = Shape.cast(range(-7, 8))
self.assertEqual(s3.width, 4)
self.assertEqual(s3.signed, True)
s4 = Shape.cast(range(0, 1))
self.assertEqual(s4.width, 1)
self.assertEqual(s4.signed, False)
s5 = Shape.cast(range(-1, 0))
self.assertEqual(s5.width, 1)
self.assertEqual(s5.signed, True)
s6 = Shape.cast(range(0, 0))
self.assertEqual(s6.width, 0)
self.assertEqual(s6.signed, False)
s7 = Shape.cast(range(-1, -1))
self.assertEqual(s7.width, 0)
self.assertEqual(s7.signed, True)
def test_cast_enum(self):
s1 = Shape.cast(UnsignedEnum)
self.assertEqual(s1.width, 2)
self.assertEqual(s1.signed, False)
s2 = Shape.cast(SignedEnum)
self.assertEqual(s2.width, 2)
self.assertEqual(s2.signed, True)
def test_cast_enum_bad(self):
with self.assertRaises(TypeError,
msg="Only enumerations with integer values can be used as value shapes"):
Shape.cast(StringEnum)
def test_cast_bad(self):
with self.assertRaises(TypeError,
msg="Object 'foo' cannot be used as value shape"):
Shape.cast("foo")
class ValueTestCase(FHDLTestCase): class ValueTestCase(FHDLTestCase):
def test_cast(self): def test_cast(self):
self.assertIsInstance(Value.cast(0), Const) self.assertIsInstance(Value.cast(0), Const)
@ -29,7 +128,7 @@ class ValueTestCase(FHDLTestCase):
c = Const(0) c = Const(0)
self.assertIs(Value.cast(c), c) self.assertIs(Value.cast(c), c)
with self.assertRaises(TypeError, with self.assertRaises(TypeError,
msg="Object 'str' is not an nMigen value"): msg="Object 'str' cannot be converted to an nMigen value"):
Value.cast("str") Value.cast("str")
def test_cast_enum(self): def test_cast_enum(self):
@ -42,7 +141,7 @@ class ValueTestCase(FHDLTestCase):
def test_cast_enum_wrong(self): def test_cast_enum_wrong(self):
with self.assertRaises(TypeError, with self.assertRaises(TypeError,
msg="Only enumerations with integer values can be converted to nMigen values"): msg="Only enumerations with integer values can be used as value shapes"):
Value.cast(StringEnum.FOO) Value.cast(StringEnum.FOO)
def test_bool(self): def test_bool(self):
@ -97,11 +196,13 @@ class ValueTestCase(FHDLTestCase):
class ConstTestCase(FHDLTestCase): class ConstTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
self.assertEqual(Const(0).shape(), (1, False)) self.assertEqual(Const(0).shape(), (1, False))
self.assertIsInstance(Const(0).shape(), Shape)
self.assertEqual(Const(1).shape(), (1, False)) self.assertEqual(Const(1).shape(), (1, False))
self.assertEqual(Const(10).shape(), (4, False)) self.assertEqual(Const(10).shape(), (4, False))
self.assertEqual(Const(-10).shape(), (5, True)) self.assertEqual(Const(-10).shape(), (5, True))
self.assertEqual(Const(1, 4).shape(), (4, False)) self.assertEqual(Const(1, 4).shape(), (4, False))
self.assertEqual(Const(-1, 4).shape(), (4, True))
self.assertEqual(Const(1, (4, True)).shape(), (4, True)) self.assertEqual(Const(1, (4, True)).shape(), (4, True))
self.assertEqual(Const(0, (0, False)).shape(), (0, False)) self.assertEqual(Const(0, (0, False)).shape(), (0, False))
@ -380,6 +481,7 @@ class SliceTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
s1 = Const(10)[2] s1 = Const(10)[2]
self.assertEqual(s1.shape(), (1, False)) self.assertEqual(s1.shape(), (1, False))
self.assertIsInstance(s1.shape(), Shape)
s2 = Const(-10)[0:2] s2 = Const(-10)[0:2]
self.assertEqual(s2.shape(), (2, False)) self.assertEqual(s2.shape(), (2, False))
@ -423,6 +525,7 @@ class BitSelectTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
s1 = self.c.bit_select(self.s, 2) s1 = self.c.bit_select(self.s, 2)
self.assertEqual(s1.shape(), (2, False)) self.assertEqual(s1.shape(), (2, False))
self.assertIsInstance(s1.shape(), Shape)
s2 = self.c.bit_select(self.s, 0) s2 = self.c.bit_select(self.s, 0)
self.assertEqual(s2.shape(), (0, False)) self.assertEqual(s2.shape(), (0, False))
@ -447,6 +550,7 @@ class WordSelectTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
s1 = self.c.word_select(self.s, 2) s1 = self.c.word_select(self.s, 2)
self.assertEqual(s1.shape(), (2, False)) self.assertEqual(s1.shape(), (2, False))
self.assertIsInstance(s1.shape(), Shape)
def test_stride(self): def test_stride(self):
s1 = self.c.word_select(self.s, 2) s1 = self.c.word_select(self.s, 2)
@ -467,6 +571,7 @@ class CatTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
c0 = Cat() c0 = Cat()
self.assertEqual(c0.shape(), (0, False)) self.assertEqual(c0.shape(), (0, False))
self.assertIsInstance(c0.shape(), Shape)
c1 = Cat(Const(10)) c1 = Cat(Const(10))
self.assertEqual(c1.shape(), (4, False)) self.assertEqual(c1.shape(), (4, False))
c2 = Cat(Const(10), Const(1)) c2 = Cat(Const(10), Const(1))
@ -483,6 +588,7 @@ class ReplTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
s1 = Repl(Const(10), 3) s1 = Repl(Const(10), 3)
self.assertEqual(s1.shape(), (12, False)) self.assertEqual(s1.shape(), (12, False))
self.assertIsInstance(s1.shape(), Shape)
s2 = Repl(Const(10), 0) s2 = Repl(Const(10), 0)
self.assertEqual(s2.shape(), (0, False)) self.assertEqual(s2.shape(), (0, False))
@ -561,6 +667,7 @@ class SignalTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
s1 = Signal() s1 = Signal()
self.assertEqual(s1.shape(), (1, False)) self.assertEqual(s1.shape(), (1, False))
self.assertIsInstance(s1.shape(), Shape)
s2 = Signal(2) s2 = Signal(2)
self.assertEqual(s2.shape(), (2, False)) self.assertEqual(s2.shape(), (2, False))
s3 = Signal((2, False)) s3 = Signal((2, False))
@ -578,7 +685,7 @@ class SignalTestCase(FHDLTestCase):
s9 = Signal.range(-20, 16) s9 = Signal.range(-20, 16)
self.assertEqual(s9.shape(), (6, True)) self.assertEqual(s9.shape(), (6, True))
s10 = Signal.range(0) s10 = Signal.range(0)
self.assertEqual(s10.shape(), (1, False)) self.assertEqual(s10.shape(), (0, False))
s11 = Signal.range(1) s11 = Signal.range(1)
self.assertEqual(s11.shape(), (1, False)) self.assertEqual(s11.shape(), (1, False))
# deprecated # deprecated
@ -692,7 +799,9 @@ class ClockSignalTestCase(FHDLTestCase):
ClockSignal(1) ClockSignal(1)
def test_shape(self): def test_shape(self):
self.assertEqual(ClockSignal().shape(), (1, False)) s1 = ClockSignal()
self.assertEqual(s1.shape(), (1, False))
self.assertIsInstance(s1.shape(), Shape)
def test_repr(self): def test_repr(self):
s1 = ClockSignal() s1 = ClockSignal()
@ -716,7 +825,9 @@ class ResetSignalTestCase(FHDLTestCase):
ResetSignal(1) ResetSignal(1)
def test_shape(self): def test_shape(self):
self.assertEqual(ResetSignal().shape(), (1, False)) s1 = ResetSignal()
self.assertEqual(s1.shape(), (1, False))
self.assertIsInstance(s1.shape(), Shape)
def test_repr(self): def test_repr(self):
s1 = ResetSignal() s1 = ResetSignal()
@ -743,6 +854,7 @@ class UserValueTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
uv = MockUserValue(1) uv = MockUserValue(1)
self.assertEqual(uv.shape(), (1, False)) self.assertEqual(uv.shape(), (1, False))
self.assertIsInstance(uv.shape(), Shape)
uv.lowered = 2 uv.lowered = 2
self.assertEqual(uv.shape(), (1, False)) self.assertEqual(uv.shape(), (1, False))
self.assertEqual(uv.lower_count, 1) self.assertEqual(uv.lower_count, 1)

View file

@ -41,6 +41,12 @@ class LayoutTestCase(FHDLTestCase):
self.assertEqual(layout["enum"], ((2, False), DIR_NONE)) self.assertEqual(layout["enum"], ((2, False), DIR_NONE))
self.assertEqual(layout["enum_dir"], ((2, False), DIR_FANOUT)) self.assertEqual(layout["enum_dir"], ((2, False), DIR_FANOUT))
def test_range_field(self):
layout = Layout.wrap([
("range", range(0, 7)),
])
self.assertEqual(layout["range"], ((3, False), DIR_NONE))
def test_slice_tuple(self): def test_slice_tuple(self):
layout = Layout.wrap([ layout = Layout.wrap([
("a", 1), ("a", 1),
@ -77,8 +83,8 @@ class LayoutTestCase(FHDLTestCase):
def test_wrong_shape(self): def test_wrong_shape(self):
with self.assertRaises(TypeError, with self.assertRaises(TypeError,
msg="Field ('a', 'x') has invalid shape: should be an int, tuple, Enum, or " msg="Field ('a', 'x') has invalid shape: should be castable to Shape or "
"list of fields of a nested record"): "a list of fields of a nested record"):
Layout.wrap([("a", "x")]) Layout.wrap([("a", "x")])