diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index eb29620..bf3f526 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -41,12 +41,10 @@ class ShapeCastable: a richer description of the shape than what is supported by the core Amaranth language, yet still be transparently used with it. """ - def __new__(cls, *args, **kwargs): - self = super().__new__(cls) - if not hasattr(self, "as_shape"): + def __init_subclass__(cls, **kwargs): + if not hasattr(cls, "as_shape"): raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override " f"the `as_shape` method") - return self class Shape: @@ -1319,15 +1317,13 @@ class ValueCastable: from :class:`ValueCastable` is mutable, it is up to the user to ensure that it is not mutated in a way that changes its representation after the first call to :meth:`as_value`. """ - def __new__(cls, *args, **kwargs): - self = super().__new__(cls) - if not hasattr(self, "as_value"): + def __init_subclass__(cls, **kwargs): + if not hasattr(cls, "as_value"): raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override " "the `as_value` method") - if not hasattr(self.as_value, "_ValueCastable__memoized"): + if not hasattr(cls.as_value, "_ValueCastable__memoized"): raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate " "the `as_value` method with the `ValueCastable.lowermethod` decorator") - return self @staticmethod def lowermethod(func): diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 3d18091..7d95e20 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -145,17 +145,14 @@ class MockShapeCastable(ShapeCastable): return self.dest -class MockShapeCastableNoOverride(ShapeCastable): - def __init__(self): - pass - - class ShapeCastableTestCase(FHDLTestCase): def test_no_override(self): with self.assertRaisesRegex(TypeError, r"^Class 'MockShapeCastableNoOverride' deriving from `ShapeCastable` must " r"override the `as_shape` method$"): - sc = MockShapeCastableNoOverride() + class MockShapeCastableNoOverride(ShapeCastable): + def __init__(self): + pass def test_cast(self): sc = MockShapeCastable(unsigned(2)) @@ -1135,19 +1132,6 @@ class MockValueCastableChanges(ValueCastable): return Signal(self.width) -class MockValueCastableNotDecorated(ValueCastable): - def __init__(self): - pass - - def as_value(self): - return Signal() - - -class MockValueCastableNoOverride(ValueCastable): - def __init__(self): - pass - - class MockValueCastableCustomGetattr(ValueCastable): def __init__(self): pass @@ -1165,13 +1149,20 @@ class ValueCastableTestCase(FHDLTestCase): with self.assertRaisesRegex(TypeError, r"^Class 'MockValueCastableNotDecorated' deriving from `ValueCastable` must " r"decorate the `as_value` method with the `ValueCastable.lowermethod` decorator$"): - vc = MockValueCastableNotDecorated() + class MockValueCastableNotDecorated(ValueCastable): + def __init__(self): + pass + + def as_value(self): + return Signal() def test_no_override(self): with self.assertRaisesRegex(TypeError, r"^Class 'MockValueCastableNoOverride' deriving from `ValueCastable` must " r"override the `as_value` method$"): - vc = MockValueCastableNoOverride() + class MockValueCastableNoOverride(ValueCastable): + def __init__(self): + pass def test_memoized(self): vc = MockValueCastableChanges(1)