hdl._ast: Make AST nodes immutable.

Fixes #1067.
This commit is contained in:
Wanda 2024-02-28 09:26:19 +01:00 committed by Catherine
parent 3271f85650
commit 2bf1b4dafc
3 changed files with 186 additions and 82 deletions

View file

@ -1462,26 +1462,38 @@ class Const(Value, metaclass=_ConstMeta):
def __init__(self, value, shape=None, *, src_loc_at=0):
# We deliberately do not call Value.__init__ here.
self.value = int(operator.index(value))
value = int(operator.index(value))
if shape is None:
shape = Shape(bits_for(self.value), signed=self.value < 0)
shape = Shape(bits_for(value), signed=value < 0)
elif isinstance(shape, int):
shape = Shape(shape, signed=self.value < 0)
shape = Shape(shape, signed=value < 0)
else:
if isinstance(shape, range) and self.value == shape.stop:
if isinstance(shape, range) and value == shape.stop:
warnings.warn(
message="Value {!r} equals the non-inclusive end of the constant "
"shape {!r}; this is likely an off-by-one error"
.format(self.value, shape),
message=f"Value {value!r} equals the non-inclusive end of the constant "
f"shape {shape!r}; this is likely an off-by-one error",
category=SyntaxWarning,
stacklevel=3)
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
if self.signed and self.value >> (self.width - 1) & 1:
self.value |= -(1 << self.width)
self._width = shape.width
self._signed = shape.signed
if shape.signed and value >> (shape.width - 1) & 1:
value |= -(1 << shape.width)
else:
self.value &= (1 << self.width) - 1
value &= (1 << shape.width) - 1
self._value = value
@property
def value(self):
return self._value
@property
def width(self):
return self._width
@property
def signed(self):
return self._signed
def shape(self):
return Shape(self.width, self.signed)
@ -1500,8 +1512,16 @@ C = Const # shorthand
class Operator(Value):
def __init__(self, operator, operands, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at)
self.operator = operator
self.operands = [Value.cast(op) for op in operands]
self._operator = operator
self._operands = tuple(Value.cast(op) for op in operands)
@property
def operator(self):
return self._operator
@property
def operands(self):
return self._operands
def shape(self):
def _bitwise_binary_shape(a_shape, b_shape):
@ -1614,9 +1634,21 @@ class Slice(Value):
raise IndexError(f"Slice start {start} must be less than slice stop {stop}")
super().__init__(src_loc_at=src_loc_at)
self.value = value
self.start = int(operator.index(start))
self.stop = int(operator.index(stop))
self._value = value
self._start = int(operator.index(start))
self._stop = int(operator.index(stop))
@property
def value(self):
return self._value
@property
def start(self):
return self._start
@property
def stop(self):
return self._stop
def shape(self):
return Shape(self.stop - self.start)
@ -1645,10 +1677,26 @@ class Part(Value):
raise TypeError("Part offset must be unsigned")
super().__init__(src_loc_at=src_loc_at)
self.value = value
self.offset = offset
self.width = width
self.stride = stride
self._value = value
self._offset = offset
self._width = width
self._stride = stride
@property
def value(self):
return self._value
@property
def offset(self):
return self._offset
@property
def width(self):
return self._width
@property
def stride(self):
return self._stride
def shape(self):
return Shape(self.width)
@ -1691,7 +1739,7 @@ class Cat(Value):
"""
def __init__(self, *args, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.parts = []
parts = []
for index, arg in enumerate(flatten(args)):
if isinstance(arg, Enum) and (not isinstance(type(arg), ShapeCastable) or
not hasattr(arg, "_amaranth_shape_")):
@ -1706,7 +1754,12 @@ class Cat(Value):
"context; specify the width explicitly using C({}, {})"
.format(index + 1, arg, arg, bits_for(arg)),
SyntaxWarning, stacklevel=2 + src_loc_at)
self.parts.append(Value.cast(arg))
parts.append(Value.cast(arg))
self._parts = tuple(parts)
@property
def parts(self):
return self._parts
def shape(self):
return Shape(sum(len(part) for part in self.parts))
@ -1784,8 +1837,8 @@ class Signal(Value, DUID, metaclass=_SignalMeta):
shape = unsigned(1)
else:
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
self._width = shape.width
self._signed = shape.signed
# TODO(amaranth-0.7): remove
if reset is not None:
@ -1831,8 +1884,8 @@ class Signal(Value, DUID, metaclass=_SignalMeta):
.format(orig_init, shape),
category=SyntaxWarning,
stacklevel=2)
self.init = init.value
self.reset_less = bool(reset_less)
self._init = init.value
self._reset_less = bool(reset_less)
if isinstance(orig_shape, range) and orig_init is not None and orig_init not in orig_shape:
if orig_init == orig_shape.stop:
@ -1843,21 +1896,21 @@ class Signal(Value, DUID, metaclass=_SignalMeta):
raise SyntaxError(
f"Initial value {orig_init!r} is not within the signal shape {orig_shape!r}")
self.attrs = OrderedDict(() if attrs is None else attrs)
self._attrs = OrderedDict(() if attrs is None else attrs)
if decoder is not None:
# The value representation is specified explicitly. Since we do not expose `hdl._repr`,
# this is the only way to add a custom filter to the signal right now. The setter sets
# `self._value_repr` as well as the compatibility `self.decoder`.
self.decoder = decoder
pass
else:
# If it's an enum, expose it via `self.decoder` for compatibility, whether it's a Python
# enum or an Amaranth enum. This also sets the value representation, even for custom
# shape-castables that implement their own `_value_repr`.
if isinstance(orig_shape, type) and issubclass(orig_shape, Enum):
self.decoder = orig_shape
decoder = orig_shape
else:
self.decoder = None
decoder = None
# The value representation is specified implicitly in the shape of the signal.
if isinstance(orig_shape, ShapeCastable):
# A custom shape-castable always has a `_value_repr`, at least the default one.
@ -1869,24 +1922,6 @@ class Signal(Value, DUID, metaclass=_SignalMeta):
# Any other case is formatted as a plain integer.
self._value_repr = (Repr(FormatInt(), self),)
@property
def reset(self):
warnings.warn("`Signal.reset` is deprecated, use `Signal.init` instead",
DeprecationWarning, stacklevel=2)
return self.init
@reset.setter
def reset(self, value):
warnings.warn("`Signal.reset` is deprecated, use `Signal.init` instead",
DeprecationWarning, stacklevel=2)
self.init = value
@property
def decoder(self):
return self._decoder
@decoder.setter
def decoder(self, decoder):
# Compute the value representation that will be used by Amaranth.
if decoder is None:
self._value_repr = (Repr(FormatInt(), self),)
@ -1903,6 +1938,37 @@ class Signal(Value, DUID, metaclass=_SignalMeta):
return str(value)
self._decoder = enum_decoder
@property
def width(self):
return self._width
@property
def signed(self):
return self._signed
@property
def init(self):
return self._init
@property
def reset(self):
warnings.warn("`Signal.reset` is deprecated, use `Signal.init` instead",
DeprecationWarning, stacklevel=2)
return self._init
@property
def reset_less(self):
return self._reset_less
@property
def attrs(self):
# Would ideally be frozendict...
return self._attrs
@property
def decoder(self):
return self._decoder
@classmethod
def like(cls, other, *, name=None, name_suffix=None, init=None, reset=None, src_loc_at=0, **kwargs):
"""Create Signal based on another.
@ -1970,7 +2036,11 @@ class ClockSignal(Value):
raise TypeError(f"Clock domain name must be a string, not {domain!r}")
if domain == "comb":
raise ValueError(f"Domain '{domain}' does not have a clock")
self.domain = domain
self._domain = domain
@property
def domain(self):
return self._domain
def shape(self):
return Shape(1)
@ -2006,8 +2076,16 @@ class ResetSignal(Value):
raise TypeError(f"Clock domain name must be a string, not {domain!r}")
if domain == "comb":
raise ValueError(f"Domain '{domain}' does not have a reset")
self.domain = domain
self.allow_reset_less = allow_reset_less
self._domain = domain
self._allow_reset_less = allow_reset_less
@property
def domain(self):
return self._domain
@property
def allow_reset_less(self):
return self._allow_reset_less
def shape(self):
return Shape(1)
@ -2032,8 +2110,16 @@ class AnyValue(Value, DUID):
super().__init__(src_loc_at=src_loc_at)
self.kind = self.Kind(kind)
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
self._width = shape.width
self._signed = shape.signed
@property
def width(self):
return self._width
@property
def signed(self):
return self._signed
def shape(self):
return Shape(self.width, self.signed)
@ -2147,8 +2233,16 @@ class Array(MutableSequence):
class ArrayProxy(Value):
def __init__(self, elems, index, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at)
self.elems = elems
self.index = Value.cast(index)
self._elems = elems
self._index = Value.cast(index)
@property
def elems(self):
return self._elems
@property
def index(self):
return self._index
def __getattr__(self, attr):
return ArrayProxy([getattr(elem, attr) for elem in self.elems], self.index)
@ -2245,8 +2339,16 @@ class Statement:
class Assign(Statement):
def __init__(self, lhs, rhs, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.lhs = Value.cast(lhs)
self.rhs = Value.cast(rhs)
self._lhs = Value.cast(lhs)
self._rhs = Value.cast(rhs)
@property
def lhs(self):
return self._lhs
@property
def rhs(self):
return self._rhs
def _lhs_signals(self):
return self.lhs._lhs_signals()
@ -2273,13 +2375,25 @@ class Property(Statement, MustUse):
def __init__(self, kind, test, *, name=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.kind = self.Kind(kind)
self.test = Value.cast(test)
self.name = name
self._kind = self.Kind(kind)
self._test = Value.cast(test)
self._name = name
if not isinstance(self.name, str) and self.name is not None:
raise TypeError("Property name must be a string or None, not {!r}"
.format(self.name))
@property
def kind(self):
return self._kind
@property
def test(self):
return self._test
@property
def name(self):
return self._name
def _lhs_signals(self):
return set()
@ -2322,8 +2436,8 @@ class Switch(Statement):
# be automatically traced, so whatever constructs a Switch may optionally provide it.
self.case_src_locs = {}
self.test = Value.cast(test)
self.cases = OrderedDict()
self._test = Value.cast(test)
self._cases = OrderedDict()
for orig_keys, stmts in cases.items():
# Map: None -> (); key -> (key,); (key...) -> (key...)
keys = orig_keys
@ -2354,10 +2468,18 @@ class Switch(Statement):
new_keys = (*new_keys, key)
if not isinstance(stmts, Iterable):
stmts = [stmts]
self.cases[new_keys] = Statement.cast(stmts)
self._cases[new_keys] = Statement.cast(stmts)
if orig_keys in case_src_locs:
self.case_src_locs[new_keys] = case_src_locs[orig_keys]
@property
def test(self):
return self._test
@property
def cases(self):
return self._cases
def _lhs_signals(self):
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())

View file

@ -1237,10 +1237,6 @@ class SignalTestCase(FHDLTestCase):
with self.assertWarnsRegex(DeprecationWarning,
r"^`Signal.reset` is deprecated, use `Signal.init` instead$"):
self.assertEqual(s1.reset, 0b111)
with self.assertWarnsRegex(DeprecationWarning,
r"^`Signal.reset` is deprecated, use `Signal.init` instead$"):
s1.reset = 0b010
self.assertEqual(s1.init, 0b010)
with self.assertWarnsRegex(DeprecationWarning,
r"^`reset=` is deprecated, use `init=` instead$"):
s2 = Signal.like(s1, reset=3)

View file

@ -194,20 +194,6 @@ class RecordTestCase(FHDLTestCase):
r4 = Record.like(r1, name_suffix="foo")
self.assertEqual(r4.name, "r1foo")
def test_like_modifications(self):
r1 = Record([("a", 1), ("b", [("s", 1)])])
self.assertEqual(r1.a.name, "r1__a")
self.assertEqual(r1.b.name, "r1__b")
self.assertEqual(r1.b.s.name, "r1__b__s")
r1.a.init = 1
r1.b.s.init = 1
r2 = Record.like(r1)
self.assertEqual(r2.a.init, 1)
self.assertEqual(r2.b.s.init, 1)
self.assertEqual(r2.a.name, "r2__a")
self.assertEqual(r2.b.name, "r2__b")
self.assertEqual(r2.b.s.name, "r2__b__s")
def test_slice_tuple(self):
r1 = Record([("a", 1), ("b", 2), ("c", 3)])
r2 = r1["a", "c"]