hdl.ast: Test *Castable subclasses on definition.

The __init_subclass__ method fires on class definition rather than use.
It also has the bonus impact that no __new__ method is defined, so the
classes can be correctly detected as mix-in classes by modules such as
enum.
This commit is contained in:
Arusekk 2023-03-22 00:17:21 +01:00 committed by Catherine
parent a0307c343d
commit 5f094a23eb
2 changed files with 17 additions and 30 deletions

View file

@ -41,12 +41,10 @@ class ShapeCastable:
a richer description of the shape than what is supported by the core Amaranth language, yet a richer description of the shape than what is supported by the core Amaranth language, yet
still be transparently used with it. still be transparently used with it.
""" """
def __new__(cls, *args, **kwargs): def __init_subclass__(cls, **kwargs):
self = super().__new__(cls) if not hasattr(cls, "as_shape"):
if not hasattr(self, "as_shape"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override " raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
f"the `as_shape` method") f"the `as_shape` method")
return self
class Shape: class Shape:
@ -1319,15 +1317,13 @@ class ValueCastable:
from :class:`ValueCastable` is mutable, it is up to the user to ensure that it is not mutated from :class:`ValueCastable` is mutable, it is up to the user to ensure that it is not mutated
in a way that changes its representation after the first call to :meth:`as_value`. in a way that changes its representation after the first call to :meth:`as_value`.
""" """
def __new__(cls, *args, **kwargs): def __init_subclass__(cls, **kwargs):
self = super().__new__(cls) if not hasattr(cls, "as_value"):
if not hasattr(self, "as_value"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override " raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override "
"the `as_value` method") "the `as_value` method")
if not hasattr(self.as_value, "_ValueCastable__memoized"): if not hasattr(cls.as_value, "_ValueCastable__memoized"):
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate " raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate "
"the `as_value` method with the `ValueCastable.lowermethod` decorator") "the `as_value` method with the `ValueCastable.lowermethod` decorator")
return self
@staticmethod @staticmethod
def lowermethod(func): def lowermethod(func):

View file

@ -145,17 +145,14 @@ class MockShapeCastable(ShapeCastable):
return self.dest return self.dest
class MockShapeCastableNoOverride(ShapeCastable):
def __init__(self):
pass
class ShapeCastableTestCase(FHDLTestCase): class ShapeCastableTestCase(FHDLTestCase):
def test_no_override(self): def test_no_override(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Class 'MockShapeCastableNoOverride' deriving from `ShapeCastable` must " r"^Class 'MockShapeCastableNoOverride' deriving from `ShapeCastable` must "
r"override the `as_shape` method$"): r"override the `as_shape` method$"):
sc = MockShapeCastableNoOverride() class MockShapeCastableNoOverride(ShapeCastable):
def __init__(self):
pass
def test_cast(self): def test_cast(self):
sc = MockShapeCastable(unsigned(2)) sc = MockShapeCastable(unsigned(2))
@ -1135,19 +1132,6 @@ class MockValueCastableChanges(ValueCastable):
return Signal(self.width) return Signal(self.width)
class MockValueCastableNotDecorated(ValueCastable):
def __init__(self):
pass
def as_value(self):
return Signal()
class MockValueCastableNoOverride(ValueCastable):
def __init__(self):
pass
class MockValueCastableCustomGetattr(ValueCastable): class MockValueCastableCustomGetattr(ValueCastable):
def __init__(self): def __init__(self):
pass pass
@ -1165,13 +1149,20 @@ class ValueCastableTestCase(FHDLTestCase):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Class 'MockValueCastableNotDecorated' deriving from `ValueCastable` must " r"^Class 'MockValueCastableNotDecorated' deriving from `ValueCastable` must "
r"decorate the `as_value` method with the `ValueCastable.lowermethod` decorator$"): r"decorate the `as_value` method with the `ValueCastable.lowermethod` decorator$"):
vc = MockValueCastableNotDecorated() class MockValueCastableNotDecorated(ValueCastable):
def __init__(self):
pass
def as_value(self):
return Signal()
def test_no_override(self): def test_no_override(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"^Class 'MockValueCastableNoOverride' deriving from `ValueCastable` must " r"^Class 'MockValueCastableNoOverride' deriving from `ValueCastable` must "
r"override the `as_value` method$"): r"override the `as_value` method$"):
vc = MockValueCastableNoOverride() class MockValueCastableNoOverride(ValueCastable):
def __init__(self):
pass
def test_memoized(self): def test_memoized(self):
vc = MockValueCastableChanges(1) vc = MockValueCastableChanges(1)