hdl.ast: Test *Castable subclasses on definition.

The __init_subclass__ method fires on class definition rather than use.
It also has the bonus impact that no __new__ method is defined, so the
classes can be correctly detected as mix-in classes by modules such as
enum.
This commit is contained in:
Arusekk 2023-03-22 00:17:21 +01:00 committed by Catherine
parent a0307c343d
commit 5f094a23eb
2 changed files with 17 additions and 30 deletions

View file

@ -41,12 +41,10 @@ class ShapeCastable:
a richer description of the shape than what is supported by the core Amaranth language, yet
still be transparently used with it.
"""
def __new__(cls, *args, **kwargs):
self = super().__new__(cls)
if not hasattr(self, "as_shape"):
def __init_subclass__(cls, **kwargs):
if not hasattr(cls, "as_shape"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
f"the `as_shape` method")
return self
class Shape:
@ -1319,15 +1317,13 @@ class ValueCastable:
from :class:`ValueCastable` is mutable, it is up to the user to ensure that it is not mutated
in a way that changes its representation after the first call to :meth:`as_value`.
"""
def __new__(cls, *args, **kwargs):
self = super().__new__(cls)
if not hasattr(self, "as_value"):
def __init_subclass__(cls, **kwargs):
if not hasattr(cls, "as_value"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override "
"the `as_value` method")
if not hasattr(self.as_value, "_ValueCastable__memoized"):
if not hasattr(cls.as_value, "_ValueCastable__memoized"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate "
"the `as_value` method with the `ValueCastable.lowermethod` decorator")
return self
@staticmethod
def lowermethod(func):

View file

@ -145,17 +145,14 @@ class MockShapeCastable(ShapeCastable):
return self.dest
class MockShapeCastableNoOverride(ShapeCastable):
def __init__(self):
pass
class ShapeCastableTestCase(FHDLTestCase):
def test_no_override(self):
with self.assertRaisesRegex(TypeError,
r"^Class 'MockShapeCastableNoOverride' deriving from `ShapeCastable` must "
r"override the `as_shape` method$"):
sc = MockShapeCastableNoOverride()
class MockShapeCastableNoOverride(ShapeCastable):
def __init__(self):
pass
def test_cast(self):
sc = MockShapeCastable(unsigned(2))
@ -1135,19 +1132,6 @@ class MockValueCastableChanges(ValueCastable):
return Signal(self.width)
class MockValueCastableNotDecorated(ValueCastable):
def __init__(self):
pass
def as_value(self):
return Signal()
class MockValueCastableNoOverride(ValueCastable):
def __init__(self):
pass
class MockValueCastableCustomGetattr(ValueCastable):
def __init__(self):
pass
@ -1165,13 +1149,20 @@ class ValueCastableTestCase(FHDLTestCase):
with self.assertRaisesRegex(TypeError,
r"^Class 'MockValueCastableNotDecorated' deriving from `ValueCastable` must "
r"decorate the `as_value` method with the `ValueCastable.lowermethod` decorator$"):
vc = MockValueCastableNotDecorated()
class MockValueCastableNotDecorated(ValueCastable):
def __init__(self):
pass
def as_value(self):
return Signal()
def test_no_override(self):
with self.assertRaisesRegex(TypeError,
r"^Class 'MockValueCastableNoOverride' deriving from `ValueCastable` must "
r"override the `as_value` method$"):
vc = MockValueCastableNoOverride()
class MockValueCastableNoOverride(ValueCastable):
def __init__(self):
pass
def test_memoized(self):
vc = MockValueCastableChanges(1)