hdl.ast: deprecate Value.part, add Value.{bit,word}_select.

Fixes #148.
This commit is contained in:
whitequark 2019-08-03 13:05:41 +00:00
parent bcdc280a87
commit 94e13effad
6 changed files with 105 additions and 29 deletions

View file

@ -177,11 +177,12 @@ class _RHSValueCompiler(_ValueCompiler):
return lambda state: normalize((arg(state) >> shift) & mask, shape) return lambda state: normalize((arg(state) >> shift) & mask, shape)
def on_Part(self, value): def on_Part(self, value):
shape = value.shape() shape = value.shape()
arg = self(value.value) arg = self(value.value)
shift = self(value.offset) shift = self(value.offset)
mask = (1 << value.width) - 1 mask = (1 << value.width) - 1
return lambda state: normalize((arg(state) >> shift(state)) & mask, shape) stride = value.stride
return lambda state: normalize((arg(state) >> shift(state) * stride) & mask, shape)
def on_Cat(self, value): def on_Cat(self, value):
shape = value.shape() shape = value.shape()
@ -260,13 +261,14 @@ class _LHSValueCompiler(_ValueCompiler):
return eval return eval
def on_Part(self, value): def on_Part(self, value):
lhs_r = self.rhs_compiler(value.value) lhs_r = self.rhs_compiler(value.value)
lhs_l = self(value.value) lhs_l = self(value.value)
shift = self.rhs_compiler(value.offset) shift = self.rhs_compiler(value.offset)
mask = (1 << value.width) - 1 mask = (1 << value.width) - 1
stride = value.stride
def eval(state, rhs): def eval(state, rhs):
lhs_value = lhs_r(state) lhs_value = lhs_r(state)
shift_value = shift(state) shift_value = shift(state) * stride
lhs_value &= ~(mask << shift_value) lhs_value &= ~(mask << shift_value)
lhs_value |= (rhs & mask) << shift_value lhs_value |= (rhs & mask) << shift_value
lhs_l(state, lhs_value) lhs_l(state, lhs_value)

View file

@ -532,6 +532,8 @@ class _RHSValueCompiler(_ValueCompiler):
def on_Part(self, value): def on_Part(self, value):
lhs, rhs = value.value, value.offset lhs, rhs = value.value, value.offset
if value.stride != 1:
rhs *= value.stride
lhs_bits, lhs_sign = lhs.shape() lhs_bits, lhs_sign = lhs.shape()
rhs_bits, rhs_sign = rhs.shape() rhs_bits, rhs_sign = rhs.shape()
res_bits, res_sign = value.shape() res_bits, res_sign = value.shape()

View file

@ -146,15 +146,21 @@ class Value(metaclass=ABCMeta):
""" """
return ~premise | conclusion return ~premise | conclusion
# TODO(nmigen-0.2): move this to nmigen.compat and make it a deprecated extension
@deprecated("instead of `.part`, use `.bit_slip`")
def part(self, offset, width): def part(self, offset, width):
"""Indexed part-select. return Part(self, offset, width, src_loc_at=1)
Selects a constant width but variable offset part of a ``Value``. def bit_select(self, offset, width):
"""Part-select with bit granularity.
Selects a constant width but variable offset part of a ``Value``, such that successive
parts overlap by all but 1 bit.
Parameters Parameters
---------- ----------
offset : Value, in offset : Value, in
start point of the selected bits index of first selected bit
width : int width : int
number of selected bits number of selected bits
@ -163,7 +169,27 @@ class Value(metaclass=ABCMeta):
Part, out Part, out
Selected part of the ``Value`` Selected part of the ``Value``
""" """
return Part(self, offset, width, src_loc_at=1) return Part(self, offset, width, stride=1, src_loc_at=1)
def word_select(self, offset, width):
"""Part-select with word granularity.
Selects a constant width but variable offset part of a ``Value``, such that successive
parts do not overlap.
Parameters
----------
offset : Value, in
index of first selected word
width : int
number of selected bits
Returns
-------
Part, out
Selected part of the ``Value``
"""
return Part(self, offset, width, stride=width, src_loc_at=1)
def eq(self, value): def eq(self, value):
"""Assignment. """Assignment.
@ -434,14 +460,17 @@ class Slice(Value):
@final @final
class Part(Value): class Part(Value):
def __init__(self, value, offset, width, *, src_loc_at=0): def __init__(self, value, offset, width, stride=1, *, src_loc_at=0):
if not isinstance(width, int) or width < 0: if not isinstance(width, int) or width < 0:
raise TypeError("Part width must be a non-negative integer, not '{!r}'".format(width)) raise TypeError("Part width must be a non-negative integer, not '{!r}'".format(width))
if not isinstance(stride, int) or stride <= 0:
raise TypeError("Part stride must be a positive integer, not '{!r}'".format(stride))
super().__init__(src_loc_at=src_loc_at) super().__init__(src_loc_at=src_loc_at)
self.value = value self.value = value
self.offset = Value.wrap(offset) self.offset = Value.wrap(offset)
self.width = width self.width = width
self.stride = stride
def shape(self): def shape(self):
return self.width, False return self.width, False
@ -453,7 +482,8 @@ class Part(Value):
return self.value._rhs_signals() | self.offset._rhs_signals() return self.value._rhs_signals() | self.offset._rhs_signals()
def __repr__(self): def __repr__(self):
return "(part {} {} {})".format(repr(self.value), repr(self.offset), self.width) return "(part {} {} {} {})".format(repr(self.value), repr(self.offset),
self.width, self.stride)
@final @final
@ -1240,7 +1270,7 @@ class ValueKey:
return hash((ValueKey(self.value.value), self.value.start, self.value.end)) return hash((ValueKey(self.value.value), self.value.start, self.value.end))
elif isinstance(self.value, Part): elif isinstance(self.value, Part):
return hash((ValueKey(self.value.value), ValueKey(self.value.offset), return hash((ValueKey(self.value.value), ValueKey(self.value.offset),
self.value.width)) self.value.width, self.value.stride))
elif isinstance(self.value, Cat): elif isinstance(self.value, Cat):
return hash(tuple(ValueKey(o) for o in self.value.parts)) return hash(tuple(ValueKey(o) for o in self.value.parts))
elif isinstance(self.value, ArrayProxy): elif isinstance(self.value, ArrayProxy):
@ -1276,7 +1306,8 @@ class ValueKey:
elif isinstance(self.value, Part): elif isinstance(self.value, Part):
return (ValueKey(self.value.value) == ValueKey(other.value.value) and return (ValueKey(self.value.value) == ValueKey(other.value.value) and
ValueKey(self.value.offset) == ValueKey(other.value.offset) and ValueKey(self.value.offset) == ValueKey(other.value.offset) and
self.value.width == other.value.width) self.value.width == other.value.width and
self.value.stride == other.value.stride)
elif isinstance(self.value, Cat): elif isinstance(self.value, Cat):
return all(ValueKey(a) == ValueKey(b) return all(ValueKey(a) == ValueKey(b)
for a, b in zip(self.value.parts, other.value.parts)) for a, b in zip(self.value.parts, other.value.parts))

View file

@ -157,7 +157,8 @@ class ValueTransformer(ValueVisitor):
return Slice(self.on_value(value.value), value.start, value.end) return Slice(self.on_value(value.value), value.start, value.end)
def on_Part(self, value): def on_Part(self, value):
return Part(self.on_value(value.value), self.on_value(value.offset), value.width) return Part(self.on_value(value.value), self.on_value(value.offset),
value.width, value.stride)
def on_Cat(self, value): def on_Cat(self, value):
return Cat(self.on_value(o) for o in value.parts) return Cat(self.on_value(o) for o in value.parts)

View file

@ -294,24 +294,52 @@ class SliceTestCase(FHDLTestCase):
self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)") self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)")
class PartTestCase(FHDLTestCase): class BitSelectTestCase(FHDLTestCase):
def setUp(self): def setUp(self):
self.c = Const(0, 8) self.c = Const(0, 8)
self.s = Signal(max=self.c.nbits) self.s = Signal(max=self.c.nbits)
def test_shape(self): def test_shape(self):
s1 = self.c.part(self.s, 2) s1 = self.c.bit_select(self.s, 2)
self.assertEqual(s1.shape(), (2, False)) self.assertEqual(s1.shape(), (2, False))
s2 = self.c.part(self.s, 0) s2 = self.c.bit_select(self.s, 0)
self.assertEqual(s2.shape(), (0, False)) self.assertEqual(s2.shape(), (0, False))
def test_stride(self):
s1 = self.c.bit_select(self.s, 2)
self.assertEqual(s1.stride, 1)
def test_width_bad(self): def test_width_bad(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.c.part(self.s, -1) self.c.bit_select(self.s, -1)
def test_repr(self): def test_repr(self):
s = self.c.part(self.s, 2) s = self.c.bit_select(self.s, 2)
self.assertEqual(repr(s), "(part (const 8'd0) (sig s) 2)") self.assertEqual(repr(s), "(part (const 8'd0) (sig s) 2 1)")
class WordSelectTestCase(FHDLTestCase):
def setUp(self):
self.c = Const(0, 8)
self.s = Signal(max=self.c.nbits)
def test_shape(self):
s1 = self.c.word_select(self.s, 2)
self.assertEqual(s1.shape(), (2, False))
def test_stride(self):
s1 = self.c.word_select(self.s, 2)
self.assertEqual(s1.stride, 2)
def test_width_bad(self):
with self.assertRaises(TypeError):
self.c.word_select(self.s, 0)
with self.assertRaises(TypeError):
self.c.word_select(self.s, -1)
def test_repr(self):
s = self.c.word_select(self.s, 2)
self.assertEqual(repr(s), "(part (const 8'd0) (sig s) 2 2)")
class CatTestCase(FHDLTestCase): class CatTestCase(FHDLTestCase):

View file

@ -151,18 +151,30 @@ class SimulatorUnitTestCase(FHDLTestCase):
stmt2 = lambda y, a: y[2:4].eq(a) stmt2 = lambda y, a: y[2:4].eq(a)
self.assertStatement(stmt2, [C(0b01, 2)], C(0b11110111, 8), reset=0b11111011) self.assertStatement(stmt2, [C(0b01, 2)], C(0b11110111, 8), reset=0b11111011)
def test_part(self): def test_bit_select(self):
stmt = lambda y, a, b: y.eq(a.part(b, 3)) stmt = lambda y, a, b: y.eq(a.bit_select(b, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3))
def test_part_lhs(self): def test_bit_select_lhs(self):
stmt = lambda y, a, b: y.part(a, 3).eq(b) stmt = lambda y, a, b: y.bit_select(a, 3).eq(b)
self.assertStatement(stmt, [C(0), C(0b100, 3)], C(0b11111100, 8), reset=0b11111111) self.assertStatement(stmt, [C(0), C(0b100, 3)], C(0b11111100, 8), reset=0b11111111)
self.assertStatement(stmt, [C(2), C(0b101, 3)], C(0b11110111, 8), reset=0b11111111) self.assertStatement(stmt, [C(2), C(0b101, 3)], C(0b11110111, 8), reset=0b11111111)
self.assertStatement(stmt, [C(3), C(0b110, 3)], C(0b11110111, 8), reset=0b11111111) self.assertStatement(stmt, [C(3), C(0b110, 3)], C(0b11110111, 8), reset=0b11111111)
def test_word_select(self):
stmt = lambda y, a, b: y.eq(a.word_select(b, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(1)], C(0b110, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b010, 3))
def test_word_select_lhs(self):
stmt = lambda y, a, b: y.word_select(a, 3).eq(b)
self.assertStatement(stmt, [C(0), C(0b100, 3)], C(0b11111100, 8), reset=0b11111111)
self.assertStatement(stmt, [C(1), C(0b101, 3)], C(0b11101111, 8), reset=0b11111111)
self.assertStatement(stmt, [C(2), C(0b110, 3)], C(0b10111111, 8), reset=0b11111111)
def test_cat(self): def test_cat(self):
stmt = lambda y, *xs: y.eq(Cat(*xs)) stmt = lambda y, *xs: y.eq(Cat(*xs))
self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4)) self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4))