hdl.ast: use keyword-only arguments as appropriate.

As a motivation/related refactor, make sure each AST node exposes
src_loc_at in the constructor.
This commit is contained in:
whitequark 2019-07-08 09:23:33 +00:00
parent 70f3563b5f
commit dac6275493
2 changed files with 43 additions and 39 deletions

View file

@ -41,7 +41,7 @@ class Value(metaclass=ABCMeta):
else: else:
raise TypeError("Object '{!r}' is not an nMigen value".format(obj)) raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
def __init__(self, src_loc_at=0): def __init__(self, *, src_loc_at=0):
super().__init__() super().__init__()
self.src_loc = tracer.get_src_loc(1 + src_loc_at) self.src_loc = tracer.get_src_loc(1 + src_loc_at)
@ -242,6 +242,7 @@ class Const(Value):
return value return value
def __init__(self, value, shape=None): def __init__(self, value, shape=None):
# We deliberately do not call Value.__init__ here.
self.value = int(value) self.value = int(value)
if shape is None: if shape is None:
shape = bits_for(self.value), self.value < 0 shape = bits_for(self.value), self.value < 0
@ -270,8 +271,8 @@ C = Const # shorthand
class AnyValue(Value, DUID): class AnyValue(Value, DUID):
def __init__(self, shape): def __init__(self, shape, *, src_loc_at=0):
super().__init__(src_loc_at=0) super().__init__(src_loc_at=src_loc_at)
if isinstance(shape, int): if isinstance(shape, int):
shape = shape, False shape = shape, False
self.nbits, self.signed = shape self.nbits, self.signed = shape
@ -300,7 +301,7 @@ class AnySeq(AnyValue):
@final @final
class Operator(Value): class Operator(Value):
def __init__(self, op, operands, src_loc_at=0): def __init__(self, op, operands, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at) super().__init__(src_loc_at=1 + src_loc_at)
self.op = op self.op = op
self.operands = [Value.wrap(o) for o in operands] self.operands = [Value.wrap(o) for o in operands]
@ -395,7 +396,7 @@ def Mux(sel, val1, val0):
@final @final
class Slice(Value): class Slice(Value):
def __init__(self, value, start, end): def __init__(self, value, start, end, *, src_loc_at=0):
if not isinstance(start, int): if not isinstance(start, int):
raise TypeError("Slice start must be an integer, not '{!r}'".format(start)) raise TypeError("Slice start must be an integer, not '{!r}'".format(start))
if not isinstance(end, int): if not isinstance(end, int):
@ -413,7 +414,7 @@ class Slice(Value):
if start > end: if start > end:
raise IndexError("Slice start {} must be less than slice end {}".format(start, end)) raise IndexError("Slice start {} must be less than slice end {}".format(start, end))
super().__init__() super().__init__(src_loc_at=src_loc_at)
self.value = Value.wrap(value) self.value = Value.wrap(value)
self.start = start self.start = start
self.end = end self.end = end
@ -433,11 +434,11 @@ class Slice(Value):
@final @final
class Part(Value): class Part(Value):
def __init__(self, value, offset, width): def __init__(self, value, offset, width, *, src_loc_at=0):
if not isinstance(width, int) or width < 0: if not isinstance(width, int) or width < 0:
raise TypeError("Part width must be a non-negative integer, not '{!r}'".format(width)) raise TypeError("Part width must be a non-negative integer, not '{!r}'".format(width))
super().__init__() super().__init__(src_loc_at=src_loc_at)
self.value = value self.value = value
self.offset = Value.wrap(offset) self.offset = Value.wrap(offset)
self.width = width self.width = width
@ -480,8 +481,8 @@ class Cat(Value):
Value, inout Value, inout
Resulting ``Value`` obtained by concatentation. Resulting ``Value`` obtained by concatentation.
""" """
def __init__(self, *args): def __init__(self, *args, src_loc_at=0):
super().__init__() super().__init__(src_loc_at=src_loc_at)
self.parts = [Value.wrap(v) for v in flatten(args)] self.parts = [Value.wrap(v) for v in flatten(args)]
def shape(self): def shape(self):
@ -525,12 +526,12 @@ class Repl(Value):
Repl, out Repl, out
Replicated value. Replicated value.
""" """
def __init__(self, value, count): def __init__(self, value, count, *, src_loc_at=0):
if not isinstance(count, int) or count < 0: if not isinstance(count, int) or count < 0:
raise TypeError("Replication count must be a non-negative integer, not '{!r}'" raise TypeError("Replication count must be a non-negative integer, not '{!r}'"
.format(count)) .format(count))
super().__init__() super().__init__(src_loc_at=src_loc_at)
self.value = Value.wrap(value) self.value = Value.wrap(value)
self.count = count self.count = count
@ -592,7 +593,7 @@ class Signal(Value, DUID):
attrs : dict attrs : dict
""" """
def __init__(self, shape=None, name=None, reset=0, reset_less=False, min=None, max=None, def __init__(self, shape=None, name=None, *, reset=0, reset_less=False, min=None, max=None,
attrs=None, decoder=None, src_loc_at=0): attrs=None, decoder=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at) super().__init__(src_loc_at=src_loc_at)
@ -641,7 +642,7 @@ class Signal(Value, DUID):
self.decoder = decoder self.decoder = decoder
@classmethod @classmethod
def like(cls, other, name=None, name_suffix=None, src_loc_at=0, **kwargs): def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
"""Create Signal based on another. """Create Signal based on another.
Parameters Parameters
@ -688,8 +689,8 @@ class ClockSignal(Value):
domain : str domain : str
Clock domain to obtain a clock signal for. Defaults to ``"sync"``. Clock domain to obtain a clock signal for. Defaults to ``"sync"``.
""" """
def __init__(self, domain="sync"): def __init__(self, domain="sync", *, src_loc_at=0):
super().__init__() super().__init__(src_loc_at=src_loc_at)
if not isinstance(domain, str): if not isinstance(domain, str):
raise TypeError("Clock domain name must be a string, not '{!r}'".format(domain)) raise TypeError("Clock domain name must be a string, not '{!r}'".format(domain))
self.domain = domain self.domain = domain
@ -722,8 +723,8 @@ class ResetSignal(Value):
allow_reset_less : bool allow_reset_less : bool
If the clock domain is reset-less, act as a constant ``0`` instead of reporting an error. If the clock domain is reset-less, act as a constant ``0`` instead of reporting an error.
""" """
def __init__(self, domain="sync", allow_reset_less=False): def __init__(self, domain="sync", allow_reset_less=False, *, src_loc_at=0):
super().__init__() super().__init__(src_loc_at=src_loc_at)
if not isinstance(domain, str): if not isinstance(domain, str):
raise TypeError("Clock domain name must be a string, not '{!r}'".format(domain)) raise TypeError("Clock domain name must be a string, not '{!r}'".format(domain))
self.domain = domain self.domain = domain
@ -832,8 +833,8 @@ class Array(MutableSequence):
@final @final
class ArrayProxy(Value): class ArrayProxy(Value):
def __init__(self, elems, index): def __init__(self, elems, index, *, src_loc_at=0):
super().__init__(src_loc_at=1) super().__init__(src_loc_at=1 + src_loc_at)
self.elems = elems self.elems = elems
self.index = Value.wrap(index) self.index = Value.wrap(index)
@ -885,7 +886,7 @@ class UserValue(Value):
* Indexing or iterating through individual bits; * Indexing or iterating through individual bits;
* Adding an assignment to the value to a ``Module`` using ``m.d.<domain> +=``. * Adding an assignment to the value to a ``Module`` using ``m.d.<domain> +=``.
""" """
def __init__(self, src_loc_at=1): def __init__(self, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at) super().__init__(src_loc_at=1 + src_loc_at)
self.__lowered = None self.__lowered = None
@ -917,8 +918,8 @@ class Sample(Value):
of the ``domain`` clock back. If that moment is before the beginning of time, it is equal of the ``domain`` clock back. If that moment is before the beginning of time, it is equal
to the value of the expression calculated as if each signal had its reset value. to the value of the expression calculated as if each signal had its reset value.
""" """
def __init__(self, expr, clocks, domain): def __init__(self, expr, clocks, domain, *, src_loc_at=0):
super().__init__(src_loc_at=1) super().__init__(src_loc_at=1 + src_loc_at)
self.value = Value.wrap(expr) self.value = Value.wrap(expr)
self.clocks = int(clocks) self.clocks = int(clocks)
self.domain = domain self.domain = domain
@ -962,6 +963,9 @@ class _StatementList(list):
class Statement: class Statement:
def __init__(self, *, src_loc_at=0):
self.src_loc = tracer.get_src_loc(1 + src_loc_at)
@staticmethod @staticmethod
def wrap(obj): def wrap(obj):
if isinstance(obj, Iterable): if isinstance(obj, Iterable):
@ -975,9 +979,8 @@ class Statement:
@final @final
class Assign(Statement): class Assign(Statement):
def __init__(self, lhs, rhs, src_loc_at=0): def __init__(self, lhs, rhs, *, src_loc_at=0):
self.src_loc = tracer.get_src_loc(src_loc_at) super().__init__(src_loc_at=src_loc_at)
self.lhs = Value.wrap(lhs) self.lhs = Value.wrap(lhs)
self.rhs = Value.wrap(rhs) self.rhs = Value.wrap(rhs)
@ -992,17 +995,14 @@ class Assign(Statement):
class Property(Statement): class Property(Statement):
def __init__(self, test, _check=None, _en=None): def __init__(self, test, *, _check=None, _en=None, src_loc_at=0):
self.src_loc = tracer.get_src_loc() super().__init__(src_loc_at=src_loc_at)
self.test = Value.wrap(test)
self.test = Value.wrap(test)
self._check = _check self._check = _check
self._en = _en
if self._check is None: if self._check is None:
self._check = Signal(reset_less=True, name="${}$check".format(self._kind)) self._check = Signal(reset_less=True, name="${}$check".format(self._kind))
self._check.src_loc = self.src_loc self._check.src_loc = self.src_loc
self._en = _en
if _en is None: if _en is None:
self._en = Signal(reset_less=True, name="${}$en".format(self._kind)) self._en = Signal(reset_less=True, name="${}$en".format(self._kind))
self._en.src_loc = self.src_loc self._en.src_loc = self.src_loc
@ -1029,9 +1029,8 @@ class Assume(Property):
# @final # @final
class Switch(Statement): class Switch(Statement):
def __init__(self, test, cases, src_loc_at=0): def __init__(self, test, cases, *, src_loc_at=0):
self.src_loc = tracer.get_src_loc(src_loc_at) super().__init__(src_loc_at=src_loc_at)
self.test = Value.wrap(test) self.test = Value.wrap(test)
self.cases = OrderedDict() self.cases = OrderedDict()
for keys, stmts in cases.items(): for keys, stmts in cases.items():
@ -1081,7 +1080,8 @@ class Switch(Statement):
@final @final
class Delay(Statement): class Delay(Statement):
def __init__(self, interval=None): def __init__(self, interval=None, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.interval = None if interval is None else float(interval) self.interval = None if interval is None else float(interval)
def _rhs_signals(self): def _rhs_signals(self):
@ -1096,7 +1096,8 @@ class Delay(Statement):
@final @final
class Tick(Statement): class Tick(Statement):
def __init__(self, domain="sync"): def __init__(self, domain="sync", *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.domain = str(domain) self.domain = str(domain)
def _rhs_signals(self): def _rhs_signals(self):
@ -1108,6 +1109,9 @@ class Tick(Statement):
@final @final
class Passive(Statement): class Passive(Statement):
def __init__(self, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
def _rhs_signals(self): def _rhs_signals(self):
return ValueSet() return ValueSet()

View file

@ -207,7 +207,7 @@ class StatementVisitor(metaclass=ABCMeta):
new_stmt = self.on_statements(stmt) new_stmt = self.on_statements(stmt)
else: else:
new_stmt = self.on_unknown_statement(stmt) new_stmt = self.on_unknown_statement(stmt)
if hasattr(stmt, "src_loc") and hasattr(new_stmt, "src_loc"): if isinstance(new_stmt, Statement):
new_stmt.src_loc = stmt.src_loc new_stmt.src_loc = stmt.src_loc
return new_stmt return new_stmt