Implement RFC 51: Add ShapeCastable.from_bits and amaranth.lib.data.Const.

Co-authored-by: Catherine <whitequark@whitequark.org>
This commit is contained in:
Wanda 2024-03-14 20:07:53 +01:00 committed by Catherine
parent 598cf8db28
commit d6bf47d549
8 changed files with 705 additions and 109 deletions

View file

@ -178,6 +178,9 @@ class MockShapeCastable(ShapeCastable):
def const(self, init):
return Const(init, self.dest)
def from_bits(self, bits):
return bits
class ShapeCastableTestCase(FHDLTestCase):
def test_no_override(self):
@ -208,6 +211,25 @@ class ShapeCastableTestCase(FHDLTestCase):
r"^Can't instantiate abstract class ShapeCastable$"):
ShapeCastable()
def test_no_from_bits(self):
with self.assertWarnsRegex(DeprecationWarning,
r"^Class 'MockShapeCastableNoFromBits' deriving from 'ShapeCastable' does "
r"not override the 'from_bits' method, which will be required in Amaranth 0.6$"):
class MockShapeCastableNoFromBits(ShapeCastable):
def __init__(self, dest):
self.dest = dest
def as_shape(self):
return self.dest
def __call__(self, value):
return value
def const(self, init):
return Const(init, self.dest)
self.assertEqual(MockShapeCastableNoFromBits(unsigned(2)).from_bits(123), 123)
class ShapeLikeTestCase(FHDLTestCase):
def test_construct(self):
@ -514,6 +536,9 @@ class ConstTestCase(FHDLTestCase):
def const(self, init):
return MockConstValue(init)
def from_bits(self, bits):
return bits
s = Const(10, MockConstShape())
self.assertIsInstance(s, MockConstValue)
self.assertEqual(s.value, 10)
@ -1186,6 +1211,9 @@ class SignalTestCase(FHDLTestCase):
def const(self, init):
return int(init, 16)
def from_bits(self, bits):
return bits
s1 = Signal(CastableFromHex(), init="aa")
self.assertEqual(s1.init, 0xaa)

View file

@ -22,6 +22,9 @@ class MockShapeCastable(ShapeCastable):
def const(self, init):
return Const(init, self.shape)
def from_bits(self, bits):
return bits
class FieldTestCase(TestCase):
def test_construct(self):
@ -417,6 +420,9 @@ class LayoutTestCase(FHDLTestCase):
def const(self, init):
return int(init, 16)
def from_bits(self, bits):
return bits
sl = data.StructLayout({"f": CastableFromHex()})
self.assertRepr(sl.const({"f": "aa"}).as_value(), "(const 8'd170)")
@ -467,13 +473,13 @@ class ViewTestCase(FHDLTestCase):
def test_layout_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^View layout must be a layout, not <.+?>$"):
r"^Layout of a view must be a Layout, not <.+?>$"):
data.View(object(), Signal(1))
def test_layout_conflict_with_attr(self):
with self.assertWarnsRegex(SyntaxWarning,
r"^View layout includes a field 'as_value' that will be shadowed by the view "
r"attribute 'amaranth\.lib\.data\.View\.as_value'$"):
r"^Layout of a view includes a field 'as_value' that will be shadowed by "
r"the attribute 'amaranth\.lib\.data\.View\.as_value'$"):
data.View(data.StructLayout({"as_value": unsigned(1)}), Signal(1))
def test_layout_conflict_with_attr_derived(self):
@ -481,20 +487,20 @@ class ViewTestCase(FHDLTestCase):
def foo(self):
pass
with self.assertWarnsRegex(SyntaxWarning,
r"^View layout includes a field 'foo' that will be shadowed by the view "
r"attribute 'tests\.test_lib_data\.ViewTestCase\."
r"^Layout of a view includes a field 'foo' that will be shadowed by "
r"the attribute 'tests\.test_lib_data\.ViewTestCase\."
r"test_layout_conflict_with_attr_derived\.<locals>.DerivedView\.foo'$"):
DerivedView(data.StructLayout({"foo": unsigned(1)}), Signal(1))
def test_target_wrong_type(self):
with self.assertRaisesRegex(TypeError,
r"^View target must be a value-castable object, not <.+?>$"):
r"^Target of a view must be a value-castable object, not <.+?>$"):
data.View(data.StructLayout({}), object())
def test_target_wrong_size(self):
with self.assertRaisesRegex(ValueError,
r"^View target is 2 bit\(s\) wide, which is not compatible with the 1 bit\(s\) "
r"wide view layout$"):
r"^Target of a view is 2 bit\(s\) wide, which is not compatible with its 1 bit\(s\) "
r"wide layout$"):
data.View(data.StructLayout({"a": unsigned(1)}), Signal(2))
def test_getitem(self):
@ -540,6 +546,9 @@ class ViewTestCase(FHDLTestCase):
def const(self, init):
return Const(init, 2)
def from_bits(self, bits):
return bits
v = Signal(data.StructLayout({
"f": Reverser()
}))
@ -557,6 +566,9 @@ class ViewTestCase(FHDLTestCase):
def const(self, init):
return Const(init, 2)
def from_bits(self, bits):
return bits
v = Signal(data.StructLayout({
"f": WrongCastable()
}))
@ -606,14 +618,13 @@ class ViewTestCase(FHDLTestCase):
def test_attr_wrong_missing(self):
with self.assertRaisesRegex(AttributeError,
r"^View of \(sig \$signal\) does not have a field 'a'; "
r"did you mean one of: 'b', 'c'\?$"):
r"^View with layout .* does not have a field 'a'; did you mean one of: 'b', 'c'\?$"):
Signal(data.StructLayout({"b": unsigned(1), "c": signed(1)})).a
def test_attr_wrong_reserved(self):
with self.assertRaisesRegex(AttributeError,
r"^View of \(sig \$signal\) field '_c' has a reserved name "
r"and may only be accessed by indexing$"):
r"^Field '_c' of view with layout .* has a reserved name and may only be accessed "
r"by indexing$"):
Signal(data.StructLayout({"_c": signed(1)}))._c
def test_signal_like(self):
@ -623,13 +634,13 @@ class ViewTestCase(FHDLTestCase):
def test_bug_837_array_layout_getitem_str(self):
with self.assertRaisesRegex(TypeError,
r"^Views with array layout may only be indexed with an integer or a value, "
r"^View with array layout may only be indexed with an integer or a value, "
r"not 'init'$"):
Signal(data.ArrayLayout(unsigned(1), 1), init=[0])["init"]
def test_bug_837_array_layout_getattr(self):
with self.assertRaisesRegex(AttributeError,
r"^View of \(sig \$signal\) with an array layout does not have fields$"):
r"^View with an array layout does not have fields$"):
Signal(data.ArrayLayout(unsigned(1), 1), init=[0]).init
def test_eq(self):
@ -639,16 +650,20 @@ class ViewTestCase(FHDLTestCase):
self.assertRepr(s1 == s2, "(== (sig s1) (sig s2))")
self.assertRepr(s1 != s2, "(!= (sig s1) (sig s2))")
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
r"^View with layout .* can only be compared to another view or constant "
r"with the same layout, not .*$"):
s1 == s3
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
r"^View with layout .* can only be compared to another view or constant "
r"with the same layout, not .*$"):
s1 != s3
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
r"^View with layout .* can only be compared to another view or constant "
r"with the same layout, not .*$"):
s1 == Const(0, 2)
with self.assertRaisesRegex(TypeError,
r"^View of .* can only be compared to another view of the same layout, not .*$"):
r"^View with layout .* can only be compared to another view or constant "
r"with the same layout, not .*$"):
s1 != Const(0, 2)
def test_operator(self):
@ -690,6 +705,251 @@ class ViewTestCase(FHDLTestCase):
self.assertRepr(s1, "View(StructLayout({'a': unsigned(2)}), (sig s1))")
class ConstTestCase(FHDLTestCase):
def test_construct(self):
c = data.Const(data.StructLayout({"a": unsigned(1), "b": unsigned(2)}), 5)
self.assertRepr(Value.cast(c), "(const 3'd5)")
self.assertEqual(c.shape(), data.StructLayout({"a": unsigned(1), "b": unsigned(2)}))
self.assertEqual(c.as_bits(), 5)
self.assertEqual(c["a"], 1)
self.assertEqual(c["b"], 2)
def test_construct_const(self):
c = Const({"a": 1, "b": 2}, data.StructLayout({"a": unsigned(1), "b": unsigned(2)}))
self.assertRepr(Const.cast(c), "(const 3'd5)")
self.assertEqual(c.a, 1)
self.assertEqual(c.b, 2)
def test_layout_wrong(self):
with self.assertRaisesRegex(TypeError,
r"^Layout of a constant must be a Layout, not <.+?>$"):
data.Const(object(), 1)
def test_layout_conflict_with_attr(self):
with self.assertWarnsRegex(SyntaxWarning,
r"^Layout of a constant includes a field 'as_value' that will be shadowed by "
r"the attribute 'amaranth\.lib\.data\.Const\.as_value'$"):
data.Const(data.StructLayout({"as_value": unsigned(1)}), 1)
def test_layout_conflict_with_attr_derived(self):
class DerivedConst(data.Const):
def foo(self):
pass
with self.assertWarnsRegex(SyntaxWarning,
r"^Layout of a constant includes a field 'foo' that will be shadowed by "
r"the attribute 'tests\.test_lib_data\.ConstTestCase\."
r"test_layout_conflict_with_attr_derived\.<locals>.DerivedConst\.foo'$"):
DerivedConst(data.StructLayout({"foo": unsigned(1)}), 1)
def test_target_wrong_type(self):
with self.assertRaisesRegex(TypeError,
r"^Target of a constant must be an int, not <.+?>$"):
data.Const(data.StructLayout({}), object())
def test_target_wrong_value(self):
with self.assertRaisesRegex(ValueError,
r"^Target of a constant does not fit in 1 bit\(s\)$"):
data.Const(data.StructLayout({"a": unsigned(1)}), 2)
def test_getitem(self):
l = data.StructLayout({
"u": unsigned(1),
"v": unsigned(1)
})
v = data.Const(data.StructLayout({
"a": unsigned(2),
"s": data.StructLayout({
"b": unsigned(1),
"c": unsigned(3)
}),
"p": 1,
"q": signed(1),
"r": data.ArrayLayout(unsigned(2), 2),
"t": data.ArrayLayout(data.StructLayout({
"u": unsigned(1),
"v": unsigned(1)
}), 2),
}), 0xabcd)
cv = Value.cast(v)
i = Signal(1)
self.assertEqual(cv.shape(), unsigned(16))
self.assertEqual(v["a"], 1)
self.assertEqual(v["s"]["b"], 1)
self.assertEqual(v["s"]["c"], 1)
self.assertEqual(v["p"], 1)
self.assertEqual(v["q"], -1)
self.assertEqual(v["r"][0], 3)
self.assertEqual(v["r"][1], 2)
self.assertRepr(v["r"][i], "(part (const 4'd11) (sig i) 2 2)")
self.assertEqual(v["t"][0], data.Const(l, 2))
self.assertEqual(v["t"][1], data.Const(l, 2))
self.assertEqual(v["t"][0]["u"], 0)
self.assertEqual(v["t"][1]["v"], 1)
def test_getitem_custom_call(self):
class Reverser(ShapeCastable):
def as_shape(self):
return unsigned(2)
def __call__(self, value):
raise NotImplementedError
def const(self, init):
raise NotImplementedError
def from_bits(self, bits):
return float(bits) / 2
v = data.Const(data.StructLayout({
"f": Reverser()
}), 3)
self.assertEqual(v.f, 1.5)
def test_index_wrong_missing(self):
with self.assertRaisesRegex(KeyError,
r"^'a'$"):
data.Const(data.StructLayout({}), 0)["a"]
def test_index_wrong_struct_dynamic(self):
with self.assertRaisesRegex(TypeError,
r"^Only constants with array layout, not StructLayout\(\{\}\), may be indexed "
r"with a value$"):
data.Const(data.StructLayout({}), 0)[Signal(1)]
def test_getattr(self):
v = data.Const(data.UnionLayout({
"a": unsigned(2),
"s": data.StructLayout({
"b": unsigned(1),
"c": unsigned(3)
}),
"p": 1,
"q": signed(1),
}), 13)
cv = Const.cast(v)
i = Signal(1)
self.assertEqual(cv.shape(), unsigned(4))
self.assertEqual(v.a, 1)
self.assertEqual(v.s.b, 1)
self.assertEqual(v.s.c, 6)
self.assertEqual(v.p, 1)
self.assertEqual(v.q, -1)
def test_getattr_reserved(self):
v = data.Const(data.UnionLayout({
"_a": unsigned(2)
}), 2)
self.assertEqual(v["_a"], 2)
def test_attr_wrong_missing(self):
with self.assertRaisesRegex(AttributeError,
r"^Constant with layout .* does not have a field 'a'; did you mean one of: "
r"'b', 'c'\?$"):
data.Const(data.StructLayout({"b": unsigned(1), "c": signed(1)}), 0).a
def test_attr_wrong_reserved(self):
with self.assertRaisesRegex(AttributeError,
r"^Field '_c' of constant with layout .* has a reserved name and may only be "
r"accessed by indexing$"):
data.Const(data.StructLayout({"_c": signed(1)}), 0)._c
def test_bug_837_array_layout_getitem_str(self):
with self.assertRaisesRegex(TypeError,
r"^Constant with array layout may only be indexed with an integer or a value, "
r"not 'init'$"):
data.Const(data.ArrayLayout(unsigned(1), 1), 0)["init"]
def test_bug_837_array_layout_getattr(self):
with self.assertRaisesRegex(AttributeError,
r"^Constant with an array layout does not have fields$"):
data.Const(data.ArrayLayout(unsigned(1), 1), 0).init
def test_eq(self):
c1 = data.Const(data.StructLayout({"a": unsigned(2)}), 1)
c2 = data.Const(data.StructLayout({"a": unsigned(2)}), 1)
c3 = data.Const(data.StructLayout({"a": unsigned(2)}), 2)
c4 = data.Const(data.StructLayout({"a": unsigned(1), "b": unsigned(1)}), 2)
s1 = Signal(data.StructLayout({"a": unsigned(2)}))
self.assertTrue(c1 == c2)
self.assertFalse(c1 != c2)
self.assertFalse(c1 == c3)
self.assertTrue(c1 != c3)
self.assertRepr(c1 == s1, "(== (const 2'd1) (sig s1))")
self.assertRepr(c1 != s1, "(!= (const 2'd1) (sig s1))")
self.assertRepr(s1 == c1, "(== (sig s1) (const 2'd1))")
self.assertRepr(s1 != c1, "(!= (sig s1) (const 2'd1))")
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
c1 == c4
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
c1 != c4
with self.assertRaisesRegex(TypeError,
r"^View with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
s1 == c4
with self.assertRaisesRegex(TypeError,
r"^View with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
s1 != c4
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
c4 == s1
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
c4 != s1
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
c1 == Const(0, 2)
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
c1 != Const(0, 2)
def test_operator(self):
s1 = data.Const(data.StructLayout({"a": unsigned(2)}), 2)
s2 = Signal(unsigned(2))
for op in [
operator.__add__,
operator.__sub__,
operator.__mul__,
operator.__floordiv__,
operator.__mod__,
operator.__lshift__,
operator.__rshift__,
operator.__lt__,
operator.__le__,
operator.__gt__,
operator.__ge__,
]:
with self.assertRaisesRegex(TypeError,
r"^Cannot perform arithmetic operations on a lib.data.Const$"):
op(s1, s2)
with self.assertRaisesRegex(TypeError,
r"^Cannot perform arithmetic operations on a lib.data.Const$"):
op(s2, s1)
for op in [
operator.__and__,
operator.__or__,
operator.__xor__,
]:
with self.assertRaisesRegex(TypeError,
r"^Cannot perform bitwise operations on a lib.data.Const$"):
op(s1, s2)
with self.assertRaisesRegex(TypeError,
r"^Cannot perform bitwise operations on a lib.data.Const$"):
op(s2, s1)
def test_repr(self):
s1 = data.Const(data.StructLayout({"a": unsigned(2)}), 2)
self.assertRepr(s1, "Const(StructLayout({'a': unsigned(2)}), 2)")
class StructTestCase(FHDLTestCase):
def test_construct(self):
class S(data.Struct):
@ -815,6 +1075,13 @@ class StructTestCase(FHDLTestCase):
s2 = Signal.like(s1)
self.assertEqual(s2.shape(), S)
def test_from_bits(self):
class S(data.Struct):
a: 1
c = S.from_bits(1)
self.assertIsInstance(c, data.Const)
self.assertEqual(c.a, 1)
class UnionTestCase(FHDLTestCase):
def test_construct(self):

View file

@ -110,6 +110,15 @@ class EnumTestCase(FHDLTestCase):
self.assertRepr(EnumA.const(10), "EnumView(EnumA, (const 8'd10))")
self.assertRepr(EnumA.const(EnumA.A), "EnumView(EnumA, (const 8'd10))")
def test_from_bits(self):
class EnumA(Enum, shape=2):
A = 0
B = 1
C = 2
self.assertIs(EnumA.from_bits(2), EnumA.C)
with self.assertRaises(ValueError):
EnumA.from_bits(3)
def test_shape_implicit_wrong_in_concat(self):
class EnumA(Enum):
A = 0