hdl.ast: add Signal.range(...), to replace Signal(min=..., max=...).

Fixes #196.
This commit is contained in:
whitequark 2019-09-08 12:10:31 +00:00
parent 5e9587bbbd
commit ccfbccc044
2 changed files with 75 additions and 27 deletions

View file

@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import builtins import builtins
import traceback import traceback
import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence
from enum import Enum from enum import Enum
@ -627,6 +628,13 @@ class Signal(Value, DUID):
attrs=None, decoder=None, src_loc_at=0): attrs=None, decoder=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at) super().__init__(src_loc_at=src_loc_at)
# TODO(nmigen-0.2): move this to nmigen.compat and make it a deprecated extension
if min is not None or max is not None:
warnings.warn("instead of `Signal(min={min}, max={max})`, "
"use `Signal.range({min}, {max})`"
.format(min=min or 0, max=max or 2),
DeprecationWarning, stacklevel=2 + src_loc_at)
if name is not None and not isinstance(name, str): if name is not None and not isinstance(name, str):
raise TypeError("Name must be a string, not '{!r}'".format(name)) raise TypeError("Name must be a string, not '{!r}'".format(name))
self.name = name or tracer.get_var_name(depth=2 + src_loc_at, default="$signal") self.name = name or tracer.get_var_name(depth=2 + src_loc_at, default="$signal")
@ -671,6 +679,23 @@ class Signal(Value, DUID):
else: else:
self.decoder = decoder self.decoder = decoder
@classmethod
def range(cls, *args, src_loc_at=0, **kwargs):
"""Create Signal that can represent a given range.
The parameters to ``Signal.range`` are the same as for the built-in ``range`` function.
That is, for any given ``range(*args)``, ``Signal.range(*args)`` can represent any
``x for x in range(*args)``.
"""
value_range = range(*args)
if len(value_range) > 0:
signed = value_range.start < 0 or (value_range.stop - value_range.step) < 0
else:
signed = value_range.start < 0
nbits = max(bits_for(value_range.start, signed),
bits_for(value_range.stop - value_range.step, signed))
return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs)
@classmethod @classmethod
def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs): def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
"""Create Signal based on another. """Create Signal based on another.

View file

@ -1,3 +1,4 @@
import warnings
from enum import Enum from enum import Enum
from ..hdl.ast import * from ..hdl.ast import *
@ -297,7 +298,7 @@ class SliceTestCase(FHDLTestCase):
class BitSelectTestCase(FHDLTestCase): class BitSelectTestCase(FHDLTestCase):
def setUp(self): def setUp(self):
self.c = Const(0, 8) self.c = Const(0, 8)
self.s = Signal(max=self.c.nbits) self.s = Signal.range(self.c.nbits)
def test_shape(self): def test_shape(self):
s1 = self.c.bit_select(self.s, 2) s1 = self.c.bit_select(self.s, 2)
@ -321,7 +322,7 @@ class BitSelectTestCase(FHDLTestCase):
class WordSelectTestCase(FHDLTestCase): class WordSelectTestCase(FHDLTestCase):
def setUp(self): def setUp(self):
self.c = Const(0, 8) self.c = Const(0, 8)
self.s = Signal(max=self.c.nbits) self.s = Signal.range(self.c.nbits)
def test_shape(self): def test_shape(self):
s1 = self.c.word_select(self.s, 2) s1 = self.c.word_select(self.s, 2)
@ -390,8 +391,8 @@ class ArrayTestCase(FHDLTestCase):
def test_becomes_immutable(self): def test_becomes_immutable(self):
a = Array([1,2,3]) a = Array([1,2,3])
s1 = Signal(max=len(a)) s1 = Signal.range(len(a))
s2 = Signal(max=len(a)) s2 = Signal.range(len(a))
v1 = a[s1] v1 = a[s1]
v2 = a[s2] v2 = a[s2]
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
@ -407,7 +408,7 @@ class ArrayTestCase(FHDLTestCase):
def test_repr(self): def test_repr(self):
a = Array([1,2,3]) a = Array([1,2,3])
self.assertEqual(repr(a), "(array mutable [1, 2, 3])") self.assertEqual(repr(a), "(array mutable [1, 2, 3])")
s = Signal(max=len(a)) s = Signal.range(len(a))
v = a[s] v = a[s]
self.assertEqual(repr(a), "(array [1, 2, 3])") self.assertEqual(repr(a), "(array [1, 2, 3])")
@ -415,8 +416,8 @@ class ArrayTestCase(FHDLTestCase):
class ArrayProxyTestCase(FHDLTestCase): class ArrayProxyTestCase(FHDLTestCase):
def test_index_shape(self): def test_index_shape(self):
m = Array(Array(x * y for y in range(1, 4)) for x in range(1, 4)) m = Array(Array(x * y for y in range(1, 4)) for x in range(1, 4))
a = Signal(max=3) a = Signal.range(3)
b = Signal(max=3) b = Signal.range(3)
v = m[a][b] v = m[a][b]
self.assertEqual(v.shape(), (4, False)) self.assertEqual(v.shape(), (4, False))
@ -424,14 +425,14 @@ class ArrayProxyTestCase(FHDLTestCase):
from collections import namedtuple from collections import namedtuple
pair = namedtuple("pair", ("p", "n")) pair = namedtuple("pair", ("p", "n"))
a = Array(pair(i, -i) for i in range(10)) a = Array(pair(i, -i) for i in range(10))
s = Signal(max=len(a)) s = Signal.range(len(a))
v = a[s] v = a[s]
self.assertEqual(v.p.shape(), (4, False)) self.assertEqual(v.p.shape(), (4, False))
self.assertEqual(v.n.shape(), (6, True)) self.assertEqual(v.n.shape(), (6, True))
def test_repr(self): def test_repr(self):
a = Array([1, 2, 3]) a = Array([1, 2, 3])
s = Signal(max=3) s = Signal.range(3)
v = a[s] v = a[s]
self.assertEqual(repr(v), "(proxy (array [1, 2, 3]) (sig s))") self.assertEqual(repr(v), "(proxy (array [1, 2, 3]) (sig s))")
@ -446,30 +447,52 @@ class SignalTestCase(FHDLTestCase):
self.assertEqual(s3.shape(), (2, False)) self.assertEqual(s3.shape(), (2, False))
s4 = Signal((2, True)) s4 = Signal((2, True))
self.assertEqual(s4.shape(), (2, True)) self.assertEqual(s4.shape(), (2, True))
s5 = Signal(max=16) s5 = Signal(0)
self.assertEqual(s5.shape(), (4, False)) self.assertEqual(s5.shape(), (0, False))
s6 = Signal(min=4, max=16) s6 = Signal.range(16)
self.assertEqual(s6.shape(), (4, False)) self.assertEqual(s6.shape(), (4, False))
s7 = Signal(min=-4, max=16) s7 = Signal.range(4, 16)
self.assertEqual(s7.shape(), (5, True)) self.assertEqual(s7.shape(), (4, False))
s8 = Signal(min=-20, max=16) s8 = Signal.range(-4, 16)
self.assertEqual(s8.shape(), (6, True)) self.assertEqual(s8.shape(), (5, True))
s9 = Signal(0) s9 = Signal.range(-20, 16)
self.assertEqual(s9.shape(), (0, False)) self.assertEqual(s9.shape(), (6, True))
s10 = Signal(max=1) s10 = Signal.range(0)
self.assertEqual(s10.shape(), (0, False)) self.assertEqual(s10.shape(), (1, False))
s11 = Signal.range(1)
self.assertEqual(s11.shape(), (1, False))
# deprecated
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
d6 = Signal(max=16)
self.assertEqual(d6.shape(), (4, False))
d7 = Signal(min=4, max=16)
self.assertEqual(d7.shape(), (4, False))
d8 = Signal(min=-4, max=16)
self.assertEqual(d8.shape(), (5, True))
d9 = Signal(min=-20, max=16)
self.assertEqual(d9.shape(), (6, True))
d10 = Signal(max=1)
self.assertEqual(d10.shape(), (0, False))
def test_shape_bad(self): def test_shape_bad(self):
with self.assertRaises(ValueError,
msg="Lower bound 10 should be less or equal to higher bound 4"):
Signal(min=10, max=4)
with self.assertRaises(ValueError,
msg="Only one of bits/signedness or bounds may be specified"):
Signal(2, min=10)
with self.assertRaises(TypeError, with self.assertRaises(TypeError,
msg="Width must be a non-negative integer, not '-10'"): msg="Width must be a non-negative integer, not '-10'"):
Signal(-10) Signal(-10)
def test_min_max_deprecated(self):
with self.assertWarns(DeprecationWarning,
msg="instead of `Signal(min=0, max=10)`, use `Signal.range(0, 10)`"):
Signal(max=10)
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
with self.assertRaises(ValueError,
msg="Lower bound 10 should be less or equal to higher bound 4"):
Signal(min=10, max=4)
with self.assertRaises(ValueError,
msg="Only one of bits/signedness or bounds may be specified"):
Signal(2, min=10)
def test_name(self): def test_name(self):
s1 = Signal() s1 = Signal()
self.assertEqual(s1.name, "s1") self.assertEqual(s1.name, "s1")
@ -500,7 +523,7 @@ class SignalTestCase(FHDLTestCase):
def test_like(self): def test_like(self):
s1 = Signal.like(Signal(4)) s1 = Signal.like(Signal(4))
self.assertEqual(s1.shape(), (4, False)) self.assertEqual(s1.shape(), (4, False))
s2 = Signal.like(Signal(min=-15)) s2 = Signal.like(Signal.range(-15, 1))
self.assertEqual(s2.shape(), (5, True)) self.assertEqual(s2.shape(), (5, True))
s3 = Signal.like(Signal(4, reset=0b111, reset_less=True)) s3 = Signal.like(Signal(4, reset=0b111, reset_less=True))
self.assertEqual(s3.reset, 0b111) self.assertEqual(s3.reset, 0b111)