hdl._ir,hdl._nir,back.rtlil: new intermediate representation.

The new intermediate representation will enable global analyses
on Amaranth code without lowering it to another representation
such as RTLIL.

This commit also changes the RTLIL builder to use the new IR.

Co-authored-by: Wanda <wanda@phinode.net>
This commit is contained in:
Catherine 2023-08-21 05:23:15 +00:00
parent 78981232d9
commit 6f44438e58
7 changed files with 2536 additions and 1048 deletions

File diff suppressed because it is too large Load diff

View file

@ -1770,25 +1770,17 @@ class Property(Statement, MustUse):
Assume = "assume"
Cover = "cover"
def __init__(self, kind, test, *, _check=None, _en=None, name=None, src_loc_at=0):
def __init__(self, kind, test, *, name=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.kind = self.Kind(kind)
self.test = Value.cast(test)
self._check = _check
self._en = _en
self.name = name
if not isinstance(self.name, str) and self.name is not None:
raise TypeError("Property name must be a string or None, not {!r}"
.format(self.name))
if self._check is None:
self._check = Signal(reset_less=True, name=f"${self.kind.value}$check")
self._check.src_loc = self.src_loc
if _en is None:
self._en = Signal(reset_less=True, name=f"${self.kind.value}$en")
self._en.src_loc = self.src_loc
def _lhs_signals(self):
return SignalSet((self._en, self._check))
return set()
def _rhs_signals(self):
return self.test._rhs_signals()

View file

@ -1,20 +1,17 @@
from abc import ABCMeta
from typing import Tuple
from collections import defaultdict, OrderedDict
from functools import reduce
import warnings
from .. import tracer
from .._utils import *
from .._unused import *
from ._ast import *
from ._ast import _StatementList
from ._cd import *
from .._utils import flatten, memoize
from .. import tracer, _unused
from . import _ast, _cd, _ir, _nir
__all__ = ["UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment", "Instance"]
class UnusedElaboratable(UnusedMustUse):
class UnusedElaboratable(_unused.UnusedMustUse):
# The warning is initially silenced. If everything that has been constructed remains unused,
# it means the application likely crashed (with an exception, or in another way that does not
# call `sys.excepthook`), and it's not necessary to show any warnings.
@ -22,7 +19,7 @@ class UnusedElaboratable(UnusedMustUse):
_MustUse__silence = True
class Elaboratable(MustUse):
class Elaboratable(_unused.MustUse):
_MustUse__warning = UnusedElaboratable
@ -64,7 +61,7 @@ class Fragment:
obj = new_obj
def __init__(self):
self.ports = SignalDict()
self.ports = _ast.SignalDict()
self.drivers = OrderedDict()
self.statements = {}
self.domains = OrderedDict()
@ -89,7 +86,7 @@ class Fragment:
def add_driver(self, signal, domain="comb"):
assert isinstance(domain, str)
if domain not in self.drivers:
self.drivers[domain] = SignalSet()
self.drivers[domain] = _ast.SignalSet()
self.drivers[domain].add(signal)
def iter_drivers(self):
@ -109,7 +106,7 @@ class Fragment:
yield domain, signal
def iter_signals(self):
signals = SignalSet()
signals = _ast.SignalSet()
signals |= self.ports.keys()
for domain, domain_signals in self.drivers.items():
if domain != "comb":
@ -122,7 +119,7 @@ class Fragment:
def add_domains(self, *domains):
for domain in flatten(domains):
assert isinstance(domain, ClockDomain)
assert isinstance(domain, _cd.ClockDomain)
assert domain.name not in self.domains
self.domains[domain.name] = domain
@ -131,9 +128,9 @@ class Fragment:
def add_statements(self, domain, *stmts):
assert isinstance(domain, str)
for stmt in Statement.cast(stmts):
for stmt in _ast.Statement.cast(stmts):
stmt._MustUse__used = True
self.statements.setdefault(domain, _StatementList()).append(stmt)
self.statements.setdefault(domain, _ast._StatementList()).append(stmt)
def add_subfragment(self, subfragment, name=None):
assert isinstance(subfragment, Fragment)
@ -186,15 +183,15 @@ class Fragment:
assert mode in ("silent", "warn", "error")
from ._mem import MemoryInstance
driver_subfrags = SignalDict()
driver_subfrags = _ast.SignalDict()
def add_subfrag(registry, entity, entry):
# Because of missing domain insertion, at the point when this code runs, we have
# a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to
# the actual signals (because the signal itself can be driven as well); but leave
# the unbound ones as it is, because there's no concrete signal for it yet anyway.
if isinstance(entity, ClockSignal) and entity.domain in self.domains:
if isinstance(entity, _ast.ClockSignal) and entity.domain in self.domains:
entity = self.domains[entity.domain].clk
elif isinstance(entity, ResetSignal) and entity.domain in self.domains:
elif isinstance(entity, _ast.ResetSignal) and entity.domain in self.domains:
entity = self.domains[entity.domain].rst
if entity not in registry:
@ -265,7 +262,7 @@ class Fragment:
return self._resolve_hierarchy_conflicts(hierarchy, mode)
# Nothing was flattened, we're done!
return SignalSet(driver_subfrags.keys())
return _ast.SignalSet(driver_subfrags.keys())
def _propagate_domains_up(self, hierarchy=("top",)):
from ._xfrm import DomainRenamer
@ -296,18 +293,18 @@ class Fragment:
if not all(names):
names = sorted(f"<unnamed #{i}>" if n is None else f"'{n}'"
for f, n, i in subfrags)
raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; "
"it is necessary to either rename subfragment domains "
"explicitly, or give names to subfragments"
.format(domain_name, ", ".join(names), ".".join(hierarchy)))
raise _cd.DomainError(
"Domain '{}' is defined by subfragments {} of fragment '{}'; it is necessary "
"to either rename subfragment domains explicitly, or give names to subfragments"
.format(domain_name, ", ".join(names), ".".join(hierarchy)))
if len(names) != len(set(names)):
names = sorted(f"#{i}" for f, n, i in subfrags)
raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', "
"some of which have identical names; it is necessary to either "
"rename subfragment domains explicitly, or give distinct names "
"to subfragments"
.format(domain_name, ", ".join(names), ".".join(hierarchy)))
raise _cd.DomainError(
"Domain '{}' is defined by subfragments {} of fragment '{}', some of which "
"have identical names; it is necessary to either rename subfragment domains "
"explicitly, or give distinct names to subfragments"
.format(domain_name, ", ".join(names), ".".join(hierarchy)))
for subfrag, name, i in subfrags:
domain_name_map = {domain_name: f"{name}_{domain_name}"}
@ -343,8 +340,8 @@ class Fragment:
continue
value = missing_domain(domain_name)
if value is None:
raise DomainError(f"Domain '{domain_name}' is used but not defined")
if type(value) is ClockDomain:
raise _cd.DomainError(f"Domain '{domain_name}' is used but not defined")
if type(value) is _cd.ClockDomain:
self.add_domains(value)
# And expose ports on the newly added clock domain, since it is added directly
# and there was no chance to add any logic driving it.
@ -353,7 +350,7 @@ class Fragment:
new_fragment = Fragment.get(value, platform=platform)
if domain_name not in new_fragment.domains:
defined = new_fragment.domains.keys()
raise DomainError(
raise _cd.DomainError(
"Fragment returned by missing domain callback does not define "
"requested domain '{}' (defines {})."
.format(domain_name, ", ".join(f"'{n}'" for n in defined)))
@ -463,12 +460,12 @@ class Fragment:
parent = {self: None}
level = {self: 0}
uses = SignalDict()
defs = SignalDict()
ios = SignalDict()
uses = _ast.SignalDict()
defs = _ast.SignalDict()
ios = _ast.SignalDict()
self._prepare_use_def_graph(parent, level, uses, defs, ios, self)
ports = SignalSet(ports)
ports = _ast.SignalSet(ports)
if all_undef_as_ports:
for sig in uses:
if sig in defs:
@ -530,7 +527,7 @@ class Fragment:
else:
self.add_ports(sig, dir="i")
def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)):
def prepare(self, ports=None, missing_domain=lambda name: _cd.ClockDomain(name)):
from ._xfrm import DomainLowerer
new_domains = self._propagate_domains(missing_domain)
@ -541,14 +538,14 @@ class Fragment:
if not isinstance(ports, tuple) and not isinstance(ports, list):
msg = "`ports` must be either a list or a tuple, not {!r}"\
.format(ports)
if isinstance(ports, Value):
if isinstance(ports, _ast.Value):
msg += " (did you mean `ports=(<signal>,)`, rather than `ports=<signal>`?)"
raise TypeError(msg)
mapped_ports = []
# Lower late bound signals like ClockSignal() to ports.
port_lowerer = DomainLowerer(fragment.domains)
for port in ports:
if not isinstance(port, (Signal, ClockSignal, ResetSignal)):
if not isinstance(port, (_ast.Signal, _ast.ClockSignal, _ast.ResetSignal)):
raise TypeError("Only signals may be added as ports, not {!r}"
.format(port))
mapped_ports.append(port_lowerer.on_value(port))
@ -573,7 +570,7 @@ class Fragment:
may get a different name.
"""
signal_names = SignalDict()
signal_names = _ast.SignalDict()
assigned_names = set()
def add_signal_name(signal):
@ -599,7 +596,7 @@ class Fragment:
for statements in self.statements.values():
for statement in statements:
for signal in statement._lhs_signals() | statement._rhs_signals():
if not isinstance(signal, (ClockSignal, ResetSignal)):
if not isinstance(signal, (_ast.ClockSignal, _ast.ResetSignal)):
add_signal_name(signal)
return signal_names
@ -657,7 +654,7 @@ class Instance(Fragment):
elif kind == "p":
self.parameters[name] = value
elif kind in ("i", "o", "io"):
self.named_ports[name] = (Value.cast(value), kind)
self.named_ports[name] = (_ast.Value.cast(value), kind)
else:
raise NameError("Instance argument {!r} should be a tuple (kind, name, value) "
"where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\""
@ -669,12 +666,763 @@ class Instance(Fragment):
elif kw.startswith("p_"):
self.parameters[kw[2:]] = arg
elif kw.startswith("i_"):
self.named_ports[kw[2:]] = (Value.cast(arg), "i")
self.named_ports[kw[2:]] = (_ast.Value.cast(arg), "i")
elif kw.startswith("o_"):
self.named_ports[kw[2:]] = (Value.cast(arg), "o")
self.named_ports[kw[2:]] = (_ast.Value.cast(arg), "o")
elif kw.startswith("io_"):
self.named_ports[kw[3:]] = (Value.cast(arg), "io")
self.named_ports[kw[3:]] = (_ast.Value.cast(arg), "io")
else:
raise NameError("Instance keyword argument {}={!r} does not start with one of "
"\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\""
.format(kw, arg))
############################################################################################### >:3
class NetlistDriver:
def __init__(self, module_idx: int, signal: _ast.Signal,
domain: '_cd.ClockDomain | None', *, src_loc):
self.module_idx = module_idx
self.signal = signal
self.domain = domain
self.src_loc = src_loc
self.assignments = []
def emit_value(self, builder):
if self.domain is None:
reset = _ast.Const(self.signal.reset, self.signal.width)
default, _signed = builder.emit_rhs(self.module_idx, reset)
else:
default = builder.emit_signal(self.signal)
if len(self.assignments) == 1:
assign, = self.assignments
if assign.cond == 1 and assign.start == 0 and len(assign.value) == len(default):
return assign.value
cell = _nir.AssignmentList(self.module_idx, default=default, assignments=self.assignments,
src_loc=self.signal.src_loc)
return builder.netlist.add_value_cell(len(default), cell)
class NetlistEmitter:
def __init__(self, netlist: _nir.Netlist, fragment_names: 'dict[_ir.Fragment, str]'):
self.netlist = netlist
self.fragment_names = fragment_names
self.drivers = _ast.SignalDict()
self.rhs_cache: dict[int, Tuple[_nir.Value, bool, _ast.Value]] = {}
# Collected for driver conflict diagnostics only.
self.late_net_to_signal = {}
self.connect_src_loc = {}
def emit_signal(self, signal) -> _nir.Value:
if signal in self.netlist.signals:
return self.netlist.signals[signal]
value = self.netlist.alloc_late_value(len(signal))
self.netlist.signals[signal] = value
for bit, net in enumerate(value):
self.late_net_to_signal[net] = (signal, bit)
return value
# Used for instance outputs and read port data, not used for actual assignments.
def emit_lhs(self, value: _ast.Value):
if isinstance(value, _ast.Signal):
return self.emit_signal(value)
elif isinstance(value, _ast.Cat):
result = []
for part in value.parts:
result += self.emit_lhs(part)
return _nir.Value(result)
elif isinstance(value, _ast.Slice):
return self.emit_lhs(value.value)[value.start:value.stop]
elif isinstance(value, _ast.Operator):
assert value.operator in ('u', 's')
return self.emit_lhs(value.operands[0])
else:
raise TypeError # :nocov:
def extend(self, value: _nir.Value, signed: bool, width: int):
nets = list(value)
while len(nets) < width:
if signed:
nets.append(nets[-1])
else:
nets.append(_nir.Net.from_const(0))
return _nir.Value(nets)
def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc):
op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc)
return self.netlist.add_value_cell(op.width, op)
def unify_shapes_bitwise(self,
operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
if signed_a == signed_b:
width = max(len(operand_a), len(operand_b))
elif signed_a:
width = max(len(operand_a), len(operand_b) + 1)
else: # signed_b
width = max(len(operand_a) + 1, len(operand_b))
operand_a = self.extend(operand_a, signed_a, width)
operand_b = self.extend(operand_b, signed_b, width)
signed = signed_a or signed_b
return (operand_a, operand_b, signed)
def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]:
"""Emits a RHS value, returns a tuple of (value, is_signed)"""
try:
result, signed, value = self.rhs_cache[id(value)]
return result, signed
except KeyError:
pass
if isinstance(value, _ast.Const):
result = _nir.Value(
_nir.Net.from_const((value.value >> bit) & 1)
for bit in range(value.width)
)
signed = value.signed
elif isinstance(value, _ast.Signal):
result = self.emit_signal(value)
signed = value.signed
elif isinstance(value, _ast.Operator):
if len(value.operands) == 1:
operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0])
if value.operator == 's':
result = operand_a
signed = True
elif value.operator == 'u':
result = operand_a
signed = False
elif value.operator == '+':
result = operand_a
signed = signed_a
elif value.operator == '-':
operand_a = self.extend(operand_a, signed_a, len(operand_a) + 1)
result = self.emit_operator(module_idx, '-', operand_a,
src_loc=value.src_loc)
signed = True
elif value.operator == '~':
result = self.emit_operator(module_idx, '~', operand_a,
src_loc=value.src_loc)
signed = signed_a
elif value.operator in ('b', 'r|', 'r&', 'r^'):
result = self.emit_operator(module_idx, value.operator, operand_a,
src_loc=value.src_loc)
signed = False
else:
assert False # :nocov:
elif len(value.operands) == 2:
operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0])
operand_b, signed_b = self.emit_rhs(module_idx, value.operands[1])
if value.operator in ('|', '&', '^'):
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
result = self.emit_operator(module_idx, value.operator, operand_a, operand_b,
src_loc=value.src_loc)
elif value.operator in ('+', '-'):
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
width = len(operand_a) + 1
operand_a = self.extend(operand_a, signed, width)
operand_b = self.extend(operand_b, signed, width)
result = self.emit_operator(module_idx, value.operator, operand_a, operand_b,
src_loc=value.src_loc)
if value.operator == '-':
signed = True
elif value.operator == '*':
width = len(operand_a) + len(operand_b)
operand_a = self.extend(operand_a, signed_a, width)
operand_b = self.extend(operand_b, signed_b, width)
result = self.emit_operator(module_idx, '*', operand_a, operand_b,
src_loc=value.src_loc)
signed = signed_a or signed_b
elif value.operator == '//':
width = len(operand_a) + signed_b
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
if len(operand_a) < width:
operand_a = self.extend(operand_a, signed, width)
operand_b = self.extend(operand_b, signed, width)
operator = 's//' if signed else 'u//'
result = _nir.Value(
self.emit_operator(module_idx, operator, operand_a, operand_b,
src_loc=value.src_loc)[:width]
)
elif value.operator == '%':
width = len(operand_b)
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
operator = 's%' if signed else 'u%'
result = _nir.Value(
self.emit_operator(module_idx, operator, operand_a, operand_b,
src_loc=value.src_loc)[:width]
)
signed = signed_b
elif value.operator == '<<':
operand_a = self.extend(operand_a, signed_a,
len(operand_a) + 2 ** len(operand_b) - 1)
result = self.emit_operator(module_idx, '<<', operand_a, operand_b,
src_loc=value.src_loc)
signed = signed_a
elif value.operator == '>>':
operator = 's>>' if signed_a else 'u>>'
result = self.emit_operator(module_idx, operator, operand_a, operand_b,
src_loc=value.src_loc)
signed = signed_a
elif value.operator in ('==', '!='):
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
result = self.emit_operator(module_idx, value.operator, operand_a, operand_b,
src_loc=value.src_loc)
signed = False
elif value.operator in ('<', '>', '<=', '>='):
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
operator = ('s' if signed else 'u') + value.operator
result = self.emit_operator(module_idx, operator, operand_a, operand_b,
src_loc=value.src_loc)
signed = False
else:
assert False # :nocov:
elif len(value.operands) == 3:
assert value.operator == 'm'
operand_s, signed_s = self.emit_rhs(module_idx, value.operands[0])
operand_a, signed_a = self.emit_rhs(module_idx, value.operands[1])
operand_b, signed_b = self.emit_rhs(module_idx, value.operands[2])
if len(operand_s) != 1:
operand_s = self.emit_operator(module_idx, 'b', operand_s,
src_loc=value.src_loc)
operand_a, operand_b, signed = \
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
result = self.emit_operator(module_idx, 'm', operand_s, operand_a, operand_b,
src_loc=value.src_loc)
else:
assert False # :nocov:
elif isinstance(value, _ast.Slice):
inner, _signed = self.emit_rhs(module_idx, value.value)
result = _nir.Value(inner[value.start:value.stop])
signed = False
elif isinstance(value, _ast.Part):
inner, signed = self.emit_rhs(module_idx, value.value)
offset, _signed = self.emit_rhs(module_idx, value.offset)
cell = _nir.Part(module_idx, value=inner, value_signed=signed, width=value.width,
stride=value.stride, offset=offset, src_loc=value.src_loc)
result = self.netlist.add_value_cell(value.width, cell)
signed = False
elif isinstance(value, _ast.ArrayProxy):
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
width = 0
signed = False
for elem, elem_signed in elems:
if elem_signed:
if not signed:
width += 1
signed = True
width = max(width, len(elem))
elif signed:
width = max(width, len(elem) + 1)
else:
width = max(width, len(elem))
elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems)
index, _signed = self.emit_rhs(module_idx, value.index)
cell = _nir.ArrayMux(module_idx, width=width, elems=elems, index=index,
src_loc=value.src_loc)
result = self.netlist.add_value_cell(width, cell)
elif isinstance(value, _ast.Cat):
nets = []
for val in value.parts:
inner, _signed = self.emit_rhs(module_idx, val)
for net in inner:
nets.append(net)
result = _nir.Value(nets)
signed = False
elif isinstance(value, _ast.AnyValue):
result = self.netlist.add_value_cell(value.width,
_nir.AnyValue(module_idx, kind=value.kind.value, width=value.width,
src_loc=value.src_loc))
signed = value.signed
elif isinstance(value, _ast.Initial):
result = self.netlist.add_value_cell(1, _nir.Initial(module_idx, src_loc=value.src_loc))
signed = False
else:
assert False # :nocov:
assert value.shape().width == len(result), \
f"Value {value!r} with shape {value.shape()!r} does not match " \
f"result with width {len(result)}"
# Add the value itself to the cache to make sure `id(value)` remains allocated and pointing
# at `value`. This would be a weakref.WeakKeyDictionary if `value` was hashable.
self.rhs_cache[id(value)] = result, signed, value
return (result, signed)
def connect(self, lhs: _nir.Value, rhs: _nir.Value, *, src_loc):
assert len(lhs) == len(rhs)
for left, right in zip(lhs, rhs):
if left in self.netlist.connections:
signal, bit = self.late_net_to_signal[left]
other_src_loc = self.connect_src_loc[left]
raise _ir.DriverConflict(f"Bit {bit} of signal {signal!r} has multiple drivers: "
f"{other_src_loc} and {src_loc}")
self.netlist.connections[left] = right
self.connect_src_loc[left] = src_loc
def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str,
stmt: _ast.Statement, cond: _nir.Net):
if domain == "comb":
cd: _cd.ClockDomain | None = None
else:
cd = fragment.domains[domain]
if isinstance(stmt, _ast.Assign):
if isinstance(stmt.lhs, _ast.Signal):
signal = stmt.lhs
start = 0
width = signal.width
elif isinstance(stmt.lhs, _ast.Slice):
signal = stmt.lhs.value
start = stmt.lhs.start
width = stmt.lhs.stop - stmt.lhs.start
else:
assert False # :nocov:
assert isinstance(signal, _ast.Signal)
if signal in self.drivers:
driver = self.drivers[signal]
if driver.domain is not cd:
raise _ir.DriverConflict(
f"Signal {signal} driven from domain {cd} at {stmt.src_loc} and domain "
f"{driver.domain} at {driver.src_loc}")
if driver.module_idx != module_idx:
mod_name = ".".join(self.netlist.modules[module_idx].name or ("<toplevel>",))
other_mod_name = \
".".join(self.netlist.modules[driver.module_idx].name or ("<toplevel>",))
raise _ir.DriverConflict(
f"Signal {signal} driven from module {mod_name} at {stmt.src_loc} and "
f"module {other_mod_name} at {driver.src_loc}")
else:
driver = NetlistDriver(module_idx, signal, domain=cd, src_loc=stmt.src_loc)
self.drivers[signal] = driver
rhs, signed = self.emit_rhs(module_idx, stmt.rhs)
if len(rhs) > width:
rhs = _nir.Value(rhs[:width])
if len(rhs) < width:
rhs = self.extend(rhs, signed, width)
driver.assignments.append(_nir.Assignment(cond=cond, start=start, value=rhs,
src_loc=stmt.src_loc))
elif isinstance(stmt, _ast.Property):
test, _signed = self.emit_rhs(module_idx, stmt.test)
if len(test) != 1:
test = self.emit_operator(module_idx, 'b', test, src_loc=stmt.src_loc)
test, = test
en_cell = _nir.AssignmentList(module_idx,
default=_nir.Value.zeros(),
assignments=[
_nir.Assignment(cond=cond, start=0, value=_nir.Value.ones(),
src_loc=stmt.src_loc)
],
src_loc=stmt.src_loc)
cond, = self.netlist.add_value_cell(1, en_cell)
if cd is None:
cell = _nir.AsyncProperty(module_idx, kind=stmt.kind.value, test=test, en=cond,
name=stmt.name, src_loc=stmt.src_loc)
else:
clk, = self.emit_signal(cd.clk)
cell = _nir.SyncProperty(module_idx, kind=stmt.kind.value, test=test, en=cond,
clk=clk, clk_edge=cd.clk_edge, name=stmt.name,
src_loc=stmt.src_loc)
self.netlist.add_cell(cell)
elif isinstance(stmt, _ast.Switch):
test, _signed = self.emit_rhs(module_idx, stmt.test)
conds = []
for patterns in stmt.cases:
if patterns:
for pattern in patterns:
assert len(pattern) == len(test)
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
src_loc=stmt.case_src_locs.get(patterns))
net, = self.netlist.add_value_cell(1, cell)
conds.append(net)
else:
conds.append(_nir.Net.from_const(1))
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds),
src_loc=stmt.src_loc)
conds = self.netlist.add_value_cell(len(conds), cell)
for subcond, substmts in zip(conds, stmt.cases.values()):
for substmt in substmts:
self.emit_stmt(module_idx, fragment, domain, substmt, subcond)
else:
assert False # :nocov:
def emit_tribuf(self, module_idx: int, instance: _ir.Instance):
pad = self.emit_lhs(instance.named_ports["Y"][0])
o, _signed = self.emit_rhs(module_idx, instance.named_ports["A"][0])
(oe,), _signed = self.emit_rhs(module_idx, instance.named_ports["EN"][0])
assert len(pad) == len(o)
cell = _nir.IOBuffer(module_idx, pad=pad, o=o, oe=oe, src_loc=instance.src_loc)
self.netlist.add_cell(cell)
def emit_memory(self, module_idx: int, fragment: '_mem.MemoryInstance', name: str):
cell = _nir.Memory(module_idx,
width=fragment._width,
depth=fragment._depth,
init=fragment._init,
name=name,
attributes=fragment._attrs,
src_loc=fragment._src_loc,
)
return self.netlist.add_cell(cell)
def emit_write_port(self, module_idx: int, fragment: '_mem.MemoryInstance',
port: '_mem.MemoryInstance._WritePort', memory: int):
data, _signed = self.emit_rhs(module_idx, port._data)
addr, _signed = self.emit_rhs(module_idx, port._addr)
en, _signed = self.emit_rhs(module_idx, port._en)
en = _nir.Value([en[bit // port._granularity] for bit in range(len(port._data))])
cd = fragment.domains[port._domain]
clk, = self.emit_signal(cd.clk)
cell = _nir.SyncWritePort(module_idx,
memory=memory,
data=data,
addr=addr,
en=en,
clk=clk,
clk_edge=cd.clk_edge,
src_loc=port._data.src_loc,
)
return self.netlist.add_cell(cell)
def emit_read_port(self, module_idx: int, fragment: '_mem.MemoryInstance',
port: '_mem.MemoryInstance._ReadPort', memory: int,
write_ports: 'list[int]'):
addr, _signed = self.emit_rhs(module_idx, port._addr)
if port._domain == "comb":
cell = _nir.AsyncReadPort(module_idx,
memory=memory,
width=len(port._data),
addr=addr,
src_loc=port._data.src_loc,
)
else:
(en,), _signed = self.emit_rhs(module_idx, port._en)
cd = fragment.domains[port._domain]
clk, = self.emit_signal(cd.clk)
cell = _nir.SyncReadPort(module_idx,
memory=memory,
width=len(port._data),
addr=addr,
en=en,
clk=clk,
clk_edge=cd.clk_edge,
transparent_for=tuple(write_ports[idx] for idx in port._transparency),
src_loc=port._data.src_loc,
)
data = self.netlist.add_value_cell(len(port._data), cell)
self.connect(self.emit_lhs(port._data), data, src_loc=port._data.src_loc)
def emit_instance(self, module_idx: int, instance: _ir.Instance, name: str):
ports_i = {}
ports_o = {}
ports_io = {}
outputs = []
next_output_bit = 0
for port_name, (port_conn, dir) in instance.named_ports.items():
if dir == 'i':
ports_i[port_name], _signed = self.emit_rhs(module_idx, port_conn)
elif dir == 'o':
port_conn = self.emit_lhs(port_conn)
ports_o[port_name] = (next_output_bit, len(port_conn))
outputs.append((next_output_bit, port_conn))
next_output_bit += len(port_conn)
elif dir == 'io':
ports_io[port_name] = self.emit_lhs(port_conn)
else:
assert False # :nocov:
cell = _nir.Instance(module_idx,
type=instance.type,
name=name,
parameters=instance.parameters,
attributes=instance.attrs,
ports_i=ports_i,
ports_o=ports_o,
ports_io=ports_io,
src_loc=instance.src_loc,
)
output_nets = self.netlist.add_value_cell(width=next_output_bit, cell=cell)
for start_bit, port_conn in outputs:
self.connect(port_conn, _nir.Value(output_nets[start_bit:start_bit + len(port_conn)]),
src_loc=instance.src_loc)
def emit_top_ports(self, fragment: _ir.Fragment, signal_names: _ast.SignalDict):
next_input_bit = 2 # 0 and 1 are reserved for constants
top = self.netlist.top
for signal, dir in fragment.ports.items():
assert signal not in self.netlist.signals
name = signal_names[signal]
if dir == 'i':
top.ports_i[name] = (next_input_bit, signal.width)
nets = _nir.Value(
_nir.Net.from_cell(0, bit)
for bit in range(next_input_bit, next_input_bit + signal.width)
)
next_input_bit += signal.width
self.netlist.signals[signal] = nets
elif dir == 'o':
top.ports_o[name] = self.emit_signal(signal)
elif dir == 'io':
top.ports_io[name] = (next_input_bit, signal.width)
nets = _nir.Value(
_nir.Net.from_cell(0, bit)
for bit in range(next_input_bit, next_input_bit + signal.width)
)
next_input_bit += signal.width
self.netlist.signals[signal] = nets
def emit_drivers(self):
for driver in self.drivers.values():
value = driver.emit_value(self)
if driver.domain is not None:
clk, = self.emit_signal(driver.domain.clk)
if driver.domain.rst is not None and driver.domain.async_reset:
arst, = self.emit_signal(driver.domain.rst)
else:
arst = _nir.Net.from_const(0)
cell = _nir.FlipFlop(driver.module_idx,
data=value,
init=driver.signal.reset,
clk=clk,
clk_edge=driver.domain.clk_edge,
arst=arst,
attributes=driver.signal.attrs,
src_loc=driver.signal.src_loc,
)
value = self.netlist.add_value_cell(len(value), cell)
if driver.assignments:
src_loc = driver.assignments[0].src_loc
else:
src_loc = driver.signal.src_loc
self.connect(self.emit_signal(driver.signal), value, src_loc=src_loc)
# Connect all undriven signal bits to their reset values. This can only happen for entirely
# undriven signals, or signals that are partially driven by instances.
for signal, value in self.netlist.signals.items():
for bit, net in enumerate(value):
if net.is_late and net not in self.netlist.connections:
self.netlist.connections[net] = _nir.Net.from_const((signal.reset >> bit) & 1)
def emit_fragment(self, fragment: _ir.Fragment, parent_module_idx: 'int | None'):
from . import _mem
fragment_name = self.fragment_names[fragment]
if isinstance(fragment, _ir.Instance):
assert parent_module_idx is not None
if fragment.type == "$tribuf":
self.emit_tribuf(parent_module_idx, fragment)
else:
self.emit_instance(parent_module_idx, fragment, name=fragment_name[-1])
elif isinstance(fragment, _mem.MemoryInstance):
assert parent_module_idx is not None
memory = self.emit_memory(parent_module_idx, fragment, name=fragment_name[-1])
write_ports = []
for port in fragment._write_ports:
write_ports.append(self.emit_write_port(parent_module_idx, fragment, port, memory))
for port in fragment._read_ports:
self.emit_read_port(parent_module_idx, fragment, port, memory, write_ports)
elif type(fragment) is _ir.Fragment:
module_idx = self.netlist.add_module(parent_module_idx, fragment_name)
signal_names = fragment._assign_names_to_signals()
self.netlist.modules[module_idx].signal_names = signal_names
if parent_module_idx is None:
self.emit_top_ports(fragment, signal_names)
for signal in signal_names:
self.emit_signal(signal)
for domain, stmts in fragment.statements.items():
for stmt in stmts:
self.emit_stmt(module_idx, fragment, domain, stmt, _nir.Net.from_const(1))
for subfragment, _name in fragment.subfragments:
self.emit_fragment(subfragment, module_idx)
if parent_module_idx is None:
self.emit_drivers()
else:
assert False # :nocov:
def _emit_netlist(netlist: _nir.Netlist, fragment, hierarchy):
fragment_names = fragment._assign_names_to_fragments(hierarchy)
NetlistEmitter(netlist, fragment_names).emit_fragment(fragment, None)
def _compute_net_flows(netlist: _nir.Netlist):
# Computes the net flows for all modules of the netlist.
#
# The rules for net flows are as follows:
#
# - the modules that have a given net in their net_flow form a subtree of the hierarchy
# - INTERNAL is used in the root of the subtree and nowhere else
# - OUTPUT is used for modules that contain the definition of the net, or are on the
# path from the definition to the root
# - remaining modules have a flow of INPUT (unless the net is a top-level inout port,
# in which case it is INOUT)
#
# In other words, the tree looks something like this:
#
# - [no flow] <<< top
# - [no flow]
# - INTERNAL
# - INPUT << use
# - [no flow]
# - INPUT
# - INPUT << use
# - OUTPUT
# - INPUT << use
# - [no flow]
# - OUTPUT << def
# - INPUT
# - INPUT
# - [no flow]
# - [no flow]
# - [no flow]
#
# This function doesn't assign the INOUT flow — that is corrected later, in compute_ports.
lca = {}
# Initialize by marking the definition point of every net.
for cell_idx, cell in enumerate(netlist.cells):
for net in cell.output_nets(cell_idx):
lca[net] = cell.module_idx
netlist.modules[cell.module_idx].net_flow[net] = _nir.ModuleNetFlow.INTERNAL
# Marks a use of a net within a given module, and adjusts its netflows in all modules
# as required.
def use_net(net, use_module):
if net.is_const:
return
# If the net is already present in the current module, we're done.
if net in netlist.modules[use_module].net_flow:
return
modules = netlist.modules
# Otherwise, we need to route the net through the hierarchy from def_module
# to use_module. We do that by treating use_module and def_module as pointers
# and moving them up the hierarchy until they meet at the new LCA.
def_module = lca[net]
# While def_module deeper than use_module, go up with def_module.
while len(modules[def_module].name) > len(modules[use_module].name):
modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT
def_module = modules[def_module].parent
# While use_module deeper than def_module, go up with use_module.
# If use_module is below def_module in the hierarchy, we may hit
# another module which already uses this net before hitting def_module,
# so check for this case.
while len(modules[def_module].name) < len(modules[use_module].name):
if net in modules[use_module].net_flow:
return
modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT
use_module = modules[use_module].parent
# Now both pointers should be at the same depth within the hierarchy.
assert len(modules[def_module].name) == len(modules[use_module].name)
# Move both pointers up until they meet.
while def_module != use_module:
modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT
def_module = modules[def_module].parent
modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT
use_module = modules[use_module].parent
assert len(modules[def_module].name) == len(modules[use_module].name)
# And mark the new LCA.
modules[def_module].net_flow[net] = _nir.ModuleNetFlow.INTERNAL
lca[net] = def_module
# Now mark all uses and flesh out the structure.
for cell in netlist.cells:
for net in cell.input_nets():
use_net(net, cell.module_idx)
# TODO: ?
for module_idx, module in enumerate(netlist.modules):
for signal in module.signal_names:
for net in netlist.signals[signal]:
use_net(net, module_idx)
def _compute_ports(netlist: _nir.Netlist):
# Compute the indexes at which the outputs of a cell should be split to create a distinct port.
# These indexes are stored here as nets.
port_starts = set()
for start, _ in netlist.top.ports_i.values():
port_starts.add(_nir.Net.from_cell(0, start))
for start, width in netlist.top.ports_io.values():
port_starts.add(_nir.Net.from_cell(0, start))
for cell_idx, cell in enumerate(netlist.cells):
if isinstance(cell, _nir.Instance):
for start, _ in cell.ports_o.values():
port_starts.add(_nir.Net.from_cell(cell_idx, start))
# Compute the set of all inout nets. Currently, a net has inout flow iff it is connected to
# a toplevel inout port.
inouts = set()
for start, width in netlist.top.ports_io.values():
for idx in range(start, start + width):
inouts.add(_nir.Net.from_cell(0, idx))
for module in netlist.modules:
# Collect preferred names for ports. If a port exactly matches a signal, we reuse
# the signal name for the port. Otherwise, we synthesize a private name.
name_table = {}
for signal, name in module.signal_names.items():
value = netlist.signals[signal]
if value not in name_table and not name.startswith('$'):
name_table[value] = name
# Gather together "adjacent" nets with the same flow into ports.
visited = set()
for net in sorted(module.net_flow):
flow = module.net_flow[net]
if flow == _nir.ModuleNetFlow.INTERNAL:
continue
if flow == _nir.ModuleNetFlow.INPUT and net in inouts:
flow = module.net_flow[net] = _nir.ModuleNetFlow.INOUT
if net in visited:
continue
# We found a net that needs a port. Keep joining the next nets output by the same
# cell into the same port, if applicable, but stop at instance/top port boundaries.
nets = [net]
while True:
succ = _nir.Net.from_cell(net.cell, net.bit + 1)
if succ in port_starts:
break
if succ not in module.net_flow:
break
if module.net_flow[succ] != module.net_flow[net]:
break
net = succ
nets.append(net)
value = _nir.Value(nets)
# Joined as many nets as we could, now name and add the port.
if value in name_table:
name = name_table[value]
else:
name = f"port${value[0].cell}${value[0].bit}"
module.ports[name] = (value, flow)
visited.update(value)
# The 0th cell and the 0th module correspond to the toplevel. Transfer the net flows from
# the toplevel cell (used for data flow) to the toplevel module (used to split netlist into
# modules in the backends).
top_module = netlist.modules[0]
for name, (start, width) in netlist.top.ports_i.items():
top_module.ports[name] = (
_nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)),
_nir.ModuleNetFlow.INPUT
)
for name, (start, width) in netlist.top.ports_io.items():
top_module.ports[name] = (
_nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)),
_nir.ModuleNetFlow.INOUT
)
for name, value in netlist.top.ports_o.items():
top_module.ports[name] = (value, _nir.ModuleNetFlow.OUTPUT)
def build_netlist(fragment, *, name="top"):
from ._xfrm import AssignmentLegalizer
fragment = AssignmentLegalizer()(fragment)
netlist = _nir.Netlist()
_emit_netlist(netlist, fragment, hierarchy=(name,))
netlist.resolve_all_nets()
_compute_net_flows(netlist)
_compute_ports(netlist)
return netlist

1003
amaranth/hdl/_nir.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -16,7 +16,6 @@ __all__ = ["ValueVisitor", "ValueTransformer",
"FragmentTransformer",
"TransformedElaboratable",
"DomainCollector", "DomainRenamer", "DomainLowerer",
"SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
"ResetInserter", "EnableInserter", "AssignmentLegalizer"]
@ -195,7 +194,7 @@ class StatementTransformer(StatementVisitor):
return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
def on_Property(self, stmt):
return Property(stmt.kind, self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name)
return Property(stmt.kind, self.on_value(stmt.test), name=stmt.name)
def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
@ -533,97 +532,6 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
return new_fragment
class SwitchCleaner(StatementVisitor):
def on_ignore(self, stmt):
return stmt
on_Assign = on_ignore
on_Property = on_ignore
def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
if any(len(s) for s in cases.values()):
return Switch(stmt.test, cases)
def on_statements(self, stmts):
stmts = flatten(self.on_statement(stmt) for stmt in stmts)
return _StatementList(stmt for stmt in stmts if stmt is not None)
class LHSGroupAnalyzer(StatementVisitor):
def __init__(self):
self.signals = SignalDict()
self.unions = OrderedDict()
def find(self, signal):
if signal not in self.signals:
self.signals[signal] = len(self.signals)
group = self.signals[signal]
while group in self.unions:
group = self.unions[group]
self.signals[signal] = group
return group
def unify(self, root, *leaves):
root_group = self.find(root)
for leaf in leaves:
leaf_group = self.find(leaf)
if root_group == leaf_group:
continue
self.unions[leaf_group] = root_group
def groups(self):
groups = OrderedDict()
for signal in self.signals:
group = self.find(signal)
if group not in groups:
groups[group] = SignalSet()
groups[group].add(signal)
return groups
def on_Assign(self, stmt):
lhs_signals = stmt._lhs_signals()
if lhs_signals:
self.unify(*stmt._lhs_signals())
def on_Property(self, stmt):
lhs_signals = stmt._lhs_signals()
if lhs_signals:
self.unify(*stmt._lhs_signals())
def on_Switch(self, stmt):
for case_stmts in stmt.cases.values():
self.on_statements(case_stmts)
def on_statements(self, stmts):
assert not isinstance(stmts, str)
for stmt in stmts:
self.on_statement(stmt)
def __call__(self, stmts):
self.on_statements(stmts)
return self.groups()
class LHSGroupFilter(SwitchCleaner):
def __init__(self, signals):
self.signals = signals
def on_Assign(self, stmt):
# The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
# on LHS are a part of the same group, so it is sufficient to check any of them.
lhs_signals = stmt.lhs._lhs_signals()
if lhs_signals:
any_lhs_signal = next(iter(lhs_signals))
if any_lhs_signal in self.signals:
return stmt
def on_Property(self, stmt):
any_lhs_signal = next(iter(stmt._lhs_signals()))
if any_lhs_signal in self.signals:
return stmt
class _ControlInserter(FragmentTransformer):
def __init__(self, controls):
self.src_loc = None
@ -655,10 +563,23 @@ class ResetInserter(_ControlInserter):
fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))
class _PropertyEnableInserter(StatementTransformer):
def __init__(self, en):
self.en = en
def on_Property(self, stmt):
return Switch(
self.en,
{1: [stmt]},
src_loc=stmt.src_loc,
)
class EnableInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals):
stmts = [s.eq(s) for s in signals]
fragment.add_statements(domain, Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc))
fragment.statements[domain] = _PropertyEnableInserter(self.controls[domain])(fragment.statements[domain])
def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment)

View file

@ -9,7 +9,6 @@ __all__ = ["ValueVisitor", "ValueTransformer",
"FragmentTransformer",
"TransformedElaboratable",
"DomainCollector", "DomainRenamer", "DomainLowerer",
"SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
"ResetInserter", "EnableInserter"]

View file

@ -244,136 +244,6 @@ class DomainLowererTestCase(FHDLTestCase):
DomainLowerer()(f)
class SwitchCleanerTestCase(FHDLTestCase):
def test_clean(self):
a = Signal()
b = Signal()
c = Signal()
stmts = [
Switch(a, {
1: a.eq(0),
0: [
b.eq(1),
Switch(b, {1: [
Switch(a|b, {})
]})
]
})
]
self.assertRepr(SwitchCleaner()(stmts), """
(
(switch (sig a)
(case 1
(eq (sig a) (const 1'd0)))
(case 0
(eq (sig b) (const 1'd1)))
)
)
""")
class LHSGroupAnalyzerTestCase(FHDLTestCase):
def test_no_group_unrelated(self):
a = Signal()
b = Signal()
stmts = [
a.eq(0),
b.eq(0),
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
SignalSet((a,)),
SignalSet((b,)),
])
def test_group_related(self):
a = Signal()
b = Signal()
stmts = [
a.eq(0),
Cat(a, b).eq(0),
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
SignalSet((a, b)),
])
def test_no_loops(self):
a = Signal()
b = Signal()
stmts = [
a.eq(0),
Cat(a, b).eq(0),
Cat(a, b).eq(0),
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
SignalSet((a, b)),
])
def test_switch(self):
a = Signal()
b = Signal()
stmts = [
a.eq(0),
Switch(a, {
1: b.eq(0),
})
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
SignalSet((a,)),
SignalSet((b,)),
])
def test_lhs_empty(self):
stmts = [
Cat().eq(0)
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
])
class LHSGroupFilterTestCase(FHDLTestCase):
def test_filter(self):
a = Signal()
b = Signal()
c = Signal()
stmts = [
Switch(a, {
1: a.eq(0),
0: [
b.eq(1),
Switch(b, {1: []})
]
})
]
self.assertRepr(LHSGroupFilter(SignalSet((a,)))(stmts), """
(
(switch (sig a)
(case 1
(eq (sig a) (const 1'd0)))
(case 0 )
)
)
""")
def test_lhs_empty(self):
stmts = [
Cat().eq(0)
]
self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()")
class ResetInserterTestCase(FHDLTestCase):
def setUp(self):
self.s1 = Signal()