# amaranth/amaranth/hdl/_ir.py
#
# 1429 lines
# 63 KiB
# Python

from typing import Tuple
from collections import defaultdict, OrderedDict
from functools import reduce
import warnings
from .._utils import flatten, memoize
from .. import tracer, _unused
from . import _ast, _cd, _ir, _nir
__all__ = [
    "UnusedElaboratable",
    "Elaboratable",
    "DriverConflict",
    "Fragment",
    "Instance",
]


class UnusedElaboratable(_unused.UnusedMustUse):
    """Warning category used by :class:`Elaboratable` for objects that are never used."""

    # Reporting starts out silenced. If nothing that was constructed ever gets used, the
    # application most likely crashed (by an exception, or in some other way that bypasses
    # `sys.excepthook`), and showing warnings would only add noise. `Fragment.get` clears this
    # flag once elaboration begins.
    _MustUse__silence = True


class Elaboratable(_unused.MustUse):
    """Base class for objects that can be elaborated into a :class:`Fragment`."""

    _MustUse__warning = UnusedElaboratable


class DriverConflict(UserWarning):
    """Issued (or raised) when a signal is driven from more than one place."""
class Fragment:
    """A node of the elaborated design hierarchy.

    A fragment holds ports (signal -> direction), per-domain driver sets, per-domain statement
    lists, clock domains, attributes, generated objects, and a list of ``(subfragment, name)``
    pairs. The underscore-prefixed methods implement the passes driven by :meth:`prepare`:
    domain propagation, driver conflict resolution by flattening, port propagation, and naming.
    """

    @staticmethod
    def get(obj, platform):
        """Elaborate ``obj`` repeatedly until a :class:`Fragment` is obtained.

        Raises
        ------
        AttributeError
            If an object in the chain has no ``elaborate`` method (this includes an
            ``elaborate`` that returned ``None``, after a warning pointing at it).
        RecursionError
            If an object's ``elaborate`` method returns the object itself.
        """
        code = None
        while True:
            if isinstance(obj, Fragment):
                return obj
            elif isinstance(obj, Elaboratable):
                code = obj.elaborate.__code__
                # Elaboration has started, so unused-elaboratable warnings are now meaningful.
                UnusedElaboratable._MustUse__silence = False
                obj._MustUse__used = True
                new_obj = obj.elaborate(platform)
            elif hasattr(obj, "elaborate"):
                warnings.warn(
                    message="Class {!r} is an elaboratable that does not explicitly inherit from "
                            "Elaboratable; doing so would improve diagnostics"
                            .format(type(obj)),
                    category=RuntimeWarning,
                    stacklevel=2)
                code = obj.elaborate.__code__
                new_obj = obj.elaborate(platform)
            else:
                raise AttributeError(f"Object {obj!r} cannot be elaborated")
            if new_obj is obj:
                raise RecursionError(f"Object {obj!r} elaborates to itself")
            if new_obj is None and code is not None:
                # Attribute the warning to the `elaborate` definition that forgot to return.
                warnings.warn_explicit(
                    message=".elaborate() returned None; missing return statement?",
                    category=UserWarning,
                    filename=code.co_filename,
                    lineno=code.co_firstlineno)
            obj = new_obj

    def __init__(self):
        self.ports = _ast.SignalDict()
        self.drivers = OrderedDict()
        self.statements = {}
        self.domains = OrderedDict()
        self.subfragments = []
        self.attrs = OrderedDict()
        self.generated = OrderedDict()
        self.flatten = False

    def add_ports(self, *ports, dir):
        """Record every signal in ``ports`` (flattened) as a port with direction ``dir``."""
        assert dir in ("i", "o", "io")
        for port in flatten(ports):
            self.ports[port] = dir

    def iter_ports(self, dir=None):
        """Yield port signals, optionally only those with direction ``dir``."""
        if dir is None:
            yield from self.ports
        else:
            for port, port_dir in self.ports.items():
                if port_dir == dir:
                    yield port

    def add_driver(self, signal, domain="comb"):
        """Record that ``signal`` is driven from ``domain`` in this fragment."""
        assert isinstance(domain, str)
        if domain not in self.drivers:
            self.drivers[domain] = _ast.SignalSet()
        self.drivers[domain].add(signal)

    def iter_drivers(self):
        """Yield ``(domain, signal)`` for every driven signal."""
        for domain, signals in self.drivers.items():
            for signal in signals:
                yield domain, signal

    def iter_comb(self):
        """Yield the signals driven from the ``comb`` domain."""
        if "comb" in self.drivers:
            yield from self.drivers["comb"]

    def iter_sync(self):
        """Yield ``(domain, signal)`` for signals driven from any domain except ``comb``."""
        for domain, signals in self.drivers.items():
            if domain == "comb":
                continue
            for signal in signals:
                yield domain, signal

    def iter_signals(self):
        """Return a SignalSet of ports, driven signals, and clock/reset of used sync domains."""
        signals = _ast.SignalSet()
        signals |= self.ports.keys()
        for domain, domain_signals in self.drivers.items():
            if domain != "comb":
                cd = self.domains[domain]
                signals.add(cd.clk)
                if cd.rst is not None:
                    signals.add(cd.rst)
            signals |= domain_signals
        return signals

    def add_domains(self, *domains):
        """Add clock domains (flattened); each name must not already be defined here."""
        for domain in flatten(domains):
            assert isinstance(domain, _cd.ClockDomain)
            assert domain.name not in self.domains
            self.domains[domain.name] = domain

    def iter_domains(self):
        """Yield the names of the clock domains defined in this fragment."""
        yield from self.domains

    def add_statements(self, domain, *stmts):
        """Append statements (cast via ``Statement.cast``) to ``domain``'s statement list."""
        assert isinstance(domain, str)
        for stmt in _ast.Statement.cast(stmts):
            stmt._MustUse__used = True
            self.statements.setdefault(domain, _ast._StatementList()).append(stmt)

    def add_subfragment(self, subfragment, name=None):
        assert isinstance(subfragment, Fragment)
        self.subfragments.append((subfragment, name))

    def find_subfragment(self, name_or_index):
        """Return a subfragment by position (int) or by name.

        Raises
        ------
        NameError
            If no subfragment matches.
        """
        if isinstance(name_or_index, int):
            if name_or_index < len(self.subfragments):
                subfragment, name = self.subfragments[name_or_index]
                return subfragment
            raise NameError(f"No subfragment at index #{name_or_index}")
        else:
            for subfragment, name in self.subfragments:
                if name == name_or_index:
                    return subfragment
            raise NameError(f"No subfragment with name '{name_or_index}'")

    def find_generated(self, *path):
        """Look up ``generated[path[-1]]`` in the subfragment reached through ``path[:-1]``."""
        if len(path) > 1:
            path_component, *path = path
            return self.find_subfragment(path_component).find_generated(*path)
        else:
            item, = path
            return self.generated[item]

    def elaborate(self, platform):
        return self

    def _merge_subfragment(self, subfragment):
        """Merge ``subfragment`` into this fragment and remove it from ``self.subfragments``.

        Everything except clock domains is merged. Flattening is done after clock domain
        propagation, so the domains are already the same in every involved fragment.
        """
        self.ports.update(subfragment.ports)
        for domain, signal in subfragment.iter_drivers():
            self.add_driver(signal, domain)
        for domain, statements in subfragment.statements.items():
            # Use the same container type as `add_statements`, so a domain introduced by
            # merging is indistinguishable from one populated directly.
            self.statements.setdefault(domain, _ast._StatementList()).extend(statements)
        self.subfragments += subfragment.subfragments

        # Remove the merged subfragment.
        found = False
        for i, (check_subfrag, check_name) in enumerate(self.subfragments): # :nobr:
            if subfragment == check_subfrag:
                del self.subfragments[i]
                found = True
                break
        assert found

    def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
        """Flatten subfragments as needed so no signal is driven from several fragments.

        Parameters
        ----------
        hierarchy : tuple of str
            Hierarchical name of this fragment, used in diagnostics.
        mode : str
            ``"silent"``, ``"warn"`` (emit :class:`DriverConflict` warnings), or ``"error"``
            (raise :class:`DriverConflict`).

        Returns
        -------
        SignalSet
            Signals driven by this fragment or, transitively, by its subfragments.
        """
        assert mode in ("silent", "warn", "error")

        from ._mem import MemoryInstance

        driver_subfrags = _ast.SignalDict()

        def add_subfrag(registry, entity, entry):
            # Because of missing domain insertion, at the point when this code runs, we have
            # a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to
            # the actual signals (because the signal itself can be driven as well); but leave
            # the unbound ones as it is, because there's no concrete signal for it yet anyway.
            if isinstance(entity, _ast.ClockSignal) and entity.domain in self.domains:
                entity = self.domains[entity.domain].clk
            elif isinstance(entity, _ast.ResetSignal) and entity.domain in self.domains:
                entity = self.domains[entity.domain].rst

            if entity not in registry:
                registry[entity] = set()
            registry[entity].add(entry)

        # For each signal driven by this fragment and/or its subfragments, determine which
        # subfragments also drive it.
        for domain, signal in self.iter_drivers():
            add_subfrag(driver_subfrags, signal, (None, hierarchy))

        flatten_subfrags = set()
        for i, (subfrag, name) in enumerate(self.subfragments):
            if name is None:
                name = f"<unnamed #{i}>"
            subfrag_hierarchy = hierarchy + (name,)

            if subfrag.flatten:
                # Always flatten subfragments that explicitly request it.
                flatten_subfrags.add((subfrag, subfrag_hierarchy))

            if isinstance(subfrag, (Instance, MemoryInstance)):
                # Never flatten instances.
                continue

            # First, recurse into subfragments and let them detect driver conflicts as well.
            subfrag_drivers = \
                subfrag._resolve_hierarchy_conflicts(subfrag_hierarchy, mode)
            # Second, classify subfragments by signals they drive.
            for signal in subfrag_drivers:
                add_subfrag(driver_subfrags, signal, (subfrag, subfrag_hierarchy))

        # Find out the set of subfragments that needs to be flattened into this fragment
        # to resolve driver-driver conflicts.
        def flatten_subfrags_if_needed(subfrags):
            if len(subfrags) == 1:
                return []
            flatten_subfrags.update((f, h) for f, h in subfrags if f is not None)
            return sorted(".".join(h) for f, h in subfrags)

        for signal, subfrags in driver_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Signal '{}' is driven from multiple fragments: {}"
                       .format(signal, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *signal.src_loc)

        # Flatten hierarchy.
        for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]):
            self._merge_subfragment(subfrag)

        # If we flattened anything, we might be in a situation where we have a driver conflict
        # again, e.g. if we had a tree of fragments like A --- B --- C where only fragments
        # A and C were driving a signal S. In that case, since B is not driving S itself,
        # processing B will not result in any flattening, but since B is transitively driving S,
        # processing A will flatten B into it. Afterwards, we have a tree like AB --- C, which
        # has another conflict.
        if flatten_subfrags:
            # Try flattening again.
            return self._resolve_hierarchy_conflicts(hierarchy, mode)

        # Nothing was flattened, we're done!
        return _ast.SignalSet(driver_subfrags.keys())

    def _propagate_domains_up(self, hierarchy=("top",)):
        """Lift non-local subfragment domains into this fragment, renaming on conflicts.

        A domain name defined by several subfragments is renamed in each of them to
        ``<subfragment name>_<domain name>``.

        Raises
        ------
        DomainError
            If conflicting subfragments are unnamed or share the same name.
        """
        from ._xfrm import DomainRenamer

        domain_subfrags = defaultdict(set)

        # For each domain defined by a subfragment, determine which subfragments define it.
        for i, (subfrag, name) in enumerate(self.subfragments):
            # First, recurse into subfragments and let them propagate domains up as well.
            hier_name = name
            if hier_name is None:
                hier_name = f"<unnamed #{i}>"
            subfrag._propagate_domains_up(hierarchy + (hier_name,))

            # Second, classify subfragments by domains they define.
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                domain_subfrags[domain_name].add((subfrag, name, i))

        # For each domain defined by more than one subfragment, rename the domain in each
        # of the subfragments such that they no longer conflict.
        for domain_name, subfrags in domain_subfrags.items():
            if len(subfrags) == 1:
                continue

            names = [n for f, n, i in subfrags]
            if not all(names):
                names = sorted(f"<unnamed #{i}>" if n is None else f"'{n}'"
                               for f, n, i in subfrags)
                raise _cd.DomainError(
                    "Domain '{}' is defined by subfragments {} of fragment '{}'; it is necessary "
                    "to either rename subfragment domains explicitly, or give names to subfragments"
                    .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            if len(names) != len(set(names)):
                names = sorted(f"#{i}" for f, n, i in subfrags)
                raise _cd.DomainError(
                    "Domain '{}' is defined by subfragments {} of fragment '{}', some of which "
                    "have identical names; it is necessary to either rename subfragment domains "
                    "explicitly, or give distinct names to subfragments"
                    .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            for subfrag, name, i in subfrags:
                domain_name_map = {domain_name: f"{name}_{domain_name}"}
                self.subfragments[i] = (DomainRenamer(domain_name_map)(subfrag), name)

        # Finally, collect the (now unique) subfragment domains, and merge them into our domains.
        for subfrag, name in self.subfragments:
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                self.add_domains(domain)

    def _propagate_domains_down(self):
        """Make every domain defined here also defined (as the same object) in all subfragments."""
        # For each domain defined in this fragment, ensure it also exists in all subfragments.
        for subfrag, name in self.subfragments:
            for domain in self.iter_domains():
                if domain in subfrag.domains:
                    assert self.domains[domain] is subfrag.domains[domain]
                else:
                    subfrag.add_domains(self.domains[domain])

            subfrag._propagate_domains_down()

    def _create_missing_domains(self, missing_domain, *, platform=None):
        """Define every domain that is used but not defined, via ``missing_domain(name)``.

        The callback returns ``None`` (an error), a ``ClockDomain`` (added directly), or an
        elaboratable whose fragment must define the domain (added as subfragment
        ``cd_<name>``).

        Returns
        -------
        list of ClockDomain
            Domains that were added directly and therefore have no driving logic.
        """
        from ._xfrm import DomainCollector

        collector = DomainCollector()
        collector(self)

        new_domains = []
        # Iterate in sorted order: string hashing is randomized per interpreter run, so
        # iterating the raw set difference would create domains and `cd_*` subfragments in a
        # run-dependent order.
        for domain_name in sorted(collector.used_domains - collector.defined_domains):
            if domain_name == "comb":
                continue
            value = missing_domain(domain_name)
            if value is None:
                raise _cd.DomainError(f"Domain '{domain_name}' is used but not defined")
            if type(value) is _cd.ClockDomain:
                self.add_domains(value)
                # And expose ports on the newly added clock domain, since it is added directly
                # and there was no chance to add any logic driving it.
                new_domains.append(value)
            else:
                new_fragment = Fragment.get(value, platform=platform)
                if domain_name not in new_fragment.domains:
                    defined = new_fragment.domains.keys()
                    raise _cd.DomainError(
                        "Fragment returned by missing domain callback does not define "
                        "requested domain '{}' (defines {})."
                        .format(domain_name, ", ".join(f"'{n}'" for n in defined)))
                self.add_subfragment(new_fragment, f"cd_{domain_name}")
                self.add_domains(new_fragment.domains.values())
        return new_domains

    def _propagate_domains(self, missing_domain, *, platform=None):
        """Run all domain passes; return the domains created directly by ``missing_domain``."""
        self._propagate_domains_up()
        self._propagate_domains_down()
        self._resolve_hierarchy_conflicts()
        new_domains = self._create_missing_domains(missing_domain, platform=platform)
        self._propagate_domains_down()
        return new_domains

    def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
        """Fill the tree and use/def maps consumed by :meth:`_propagate_ports`.

        ``parent``/``level`` receive each ordinary subfragment's parent and depth. ``uses``
        maps a signal to the set of fragments reading it, ``defs`` to the one fragment
        driving it, and ``ios`` to the fragment connecting it as an ``io`` port. Instance and
        MemoryInstance subfragments are not entered into the tree: their connections are
        recorded against this (parent) fragment, and the corresponding ports are added to
        the instance itself. ``top`` is passed through but not otherwise used here.
        """
        from ._mem import MemoryInstance

        def add_uses(*sigs, self=self):
            for sig in flatten(sigs):
                if sig not in uses:
                    uses[sig] = set()
                uses[sig].add(self)

        def add_defs(*sigs):
            for sig in flatten(sigs):
                if sig not in defs:
                    defs[sig] = self
                else:
                    assert defs[sig] is self

        def add_io(*sigs):
            for sig in flatten(sigs):
                if sig not in ios:
                    ios[sig] = self
                else:
                    assert ios[sig] is self

        # Collect all signals we're driving (on LHS of statements), and signals we're using
        # (on RHS of statements, or in clock domains).
        for stmts in self.statements.values():
            for stmt in stmts:
                add_uses(stmt._rhs_signals())
                add_defs(stmt._lhs_signals())

        for domain, _ in self.iter_sync():
            cd = self.domains[domain]
            add_uses(cd.clk)
            if cd.rst is not None:
                add_uses(cd.rst)

        # Repeat for subfragments.
        for subfrag, name in self.subfragments:
            if isinstance(subfrag, Instance):
                for port_name, (value, dir) in subfrag.named_ports.items():
                    if dir == "i":
                        # Prioritize defs over uses.
                        rhs_without_outputs = value._rhs_signals() - subfrag.iter_ports(dir="o")
                        subfrag.add_ports(rhs_without_outputs, dir=dir)
                        add_uses(value._rhs_signals())
                    if dir == "o":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_defs(value._lhs_signals())
                    if dir == "io":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_io(value._lhs_signals())
            elif isinstance(subfrag, MemoryInstance):
                for port in subfrag._read_ports:
                    subfrag.add_ports(port._data._lhs_signals(), dir="o")
                    add_defs(port._data._lhs_signals())
                    for value in [port._addr, port._en]:
                        subfrag.add_ports(value._rhs_signals(), dir="i")
                        add_uses(value._rhs_signals())
                for port in subfrag._write_ports:
                    for value in [port._addr, port._en, port._data]:
                        subfrag.add_ports(value._rhs_signals(), dir="i")
                        add_uses(value._rhs_signals())
                for domain, _ in subfrag.iter_sync():
                    cd = subfrag.domains[domain]
                    add_uses(cd.clk)
                    if cd.rst is not None:
                        add_uses(cd.rst)
            else:
                parent[subfrag] = self
                level[subfrag] = level[self] + 1

                subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)

    def _propagate_ports(self, ports, all_undef_as_ports):
        """Add the input/output/io ports each fragment needs for signals crossing its boundary.

        Parameters
        ----------
        ports : iterable of Signal
            Signals that must become ports of this (top-level) fragment.
        all_undef_as_ports : bool
            If true, every signal used but never driven also becomes a top-level port.
        """
        # Take this fragment graph:
        #
        #    __ B (def: q, use: p r)
        #   /
        #  A (def: p, use: q r)
        #   \
        #    \_ C (def: r, use: p q)
        #
        # We need to consider three cases.
        #   1. Signal p requires an input port in B;
        #   2. Signal r requires an output port in C;
        #   3. Signal r requires an output port in C and an input port in B.
        #
        # Adding these ports can be in general done in three steps for each signal:
        #   1. Find the least common ancestor of all uses and defs.
        #   2. Going upwards from the single def, add output ports.
        #   3. Going upwards from all uses, add input ports.

        parent = {self: None}
        level = {self: 0}
        uses = _ast.SignalDict()
        defs = _ast.SignalDict()
        ios = _ast.SignalDict()
        self._prepare_use_def_graph(parent, level, uses, defs, ios, self)

        ports = _ast.SignalSet(ports)
        if all_undef_as_ports:
            for sig in uses:
                if sig in defs:
                    continue
                ports.add(sig)
        for sig in ports:
            if sig not in uses:
                uses[sig] = set()
            uses[sig].add(self)

        @memoize
        def lca_of(fragu, fragv):
            # Normalize fragu to be deeper than fragv.
            if level[fragu] < level[fragv]:
                fragu, fragv = fragv, fragu
            # Find ancestor of fragu on the same level as fragv.
            for _ in range(level[fragu] - level[fragv]):
                fragu = parent[fragu]
            # If fragv was the ancestor of fragu, we're done.
            if fragu == fragv:
                return fragu
            # Otherwise, they are at the same level but in different branches. Step both fragu
            # and fragv until we find the common ancestor.
            while parent[fragu] != parent[fragv]:
                fragu = parent[fragu]
                fragv = parent[fragv]
            return parent[fragu]

        for sig in uses:
            if sig in defs:
                lca = reduce(lca_of, uses[sig], defs[sig])
            else:
                lca = reduce(lca_of, uses[sig])

            for frag in uses[sig]:
                if sig in defs and frag is defs[sig]:
                    continue
                while frag != lca:
                    frag.add_ports(sig, dir="i")
                    frag = parent[frag]

            if sig in defs:
                frag = defs[sig]
                while frag != lca:
                    frag.add_ports(sig, dir="o")
                    frag = parent[frag]

        for sig in ios:
            frag = ios[sig]
            while frag is not None:
                frag.add_ports(sig, dir="io")
                frag = parent[frag]

        for sig in ports:
            if sig in ios:
                continue
            if sig in defs:
                self.add_ports(sig, dir="o")
            else:
                self.add_ports(sig, dir="i")

    def prepare(self, ports=None, missing_domain=lambda name: _cd.ClockDomain(name)):
        """Run domain propagation and lowering, then port propagation; return the new fragment.

        Raises
        ------
        TypeError
            If ``ports`` is not a list/tuple, or contains something other than a
            Signal, ClockSignal, or ResetSignal.
        """
        from ._xfrm import DomainLowerer

        new_domains = self._propagate_domains(missing_domain)
        fragment = DomainLowerer()(self)
        if ports is None:
            fragment._propagate_ports(ports=(), all_undef_as_ports=True)
        else:
            if not isinstance(ports, tuple) and not isinstance(ports, list):
                msg = "`ports` must be either a list or a tuple, not {!r}"\
                        .format(ports)
                if isinstance(ports, _ast.Value):
                    msg += " (did you mean `ports=(<signal>,)`, rather than `ports=<signal>`?)"
                raise TypeError(msg)
            mapped_ports = []
            # Lower late bound signals like ClockSignal() to ports.
            port_lowerer = DomainLowerer(fragment.domains)
            for port in ports:
                if not isinstance(port, (_ast.Signal, _ast.ClockSignal, _ast.ResetSignal)):
                    raise TypeError("Only signals may be added as ports, not {!r}"
                                    .format(port))
                mapped_ports.append(port_lowerer.on_value(port))
            # Add ports for all newly created missing clock domains, since not doing so defeats
            # the purpose of domain auto-creation. (It's possible to refer to these ports before
            # the domain actually exists through late binding, but it's inconvenient.)
            for cd in new_domains:
                mapped_ports.append(cd.clk)
                if cd.rst is not None:
                    mapped_ports.append(cd.rst)
            fragment._propagate_ports(ports=mapped_ports, all_undef_as_ports=False)
        return fragment

    def _assign_names_to_signals(self):
        """Assign names to signals used in this fragment.

        Returns
        -------
        SignalDict of Signal to str
            A mapping from signals used in this fragment to their local names. Because names are
            deduplicated using local information only, the same signal used in a different fragment
            may get a different name.
        """
        signal_names = _ast.SignalDict()
        assigned_names = set()

        def add_signal_name(signal):
            if signal not in signal_names:
                if signal.name not in assigned_names:
                    name = signal.name
                else:
                    name = f"{signal.name}${len(assigned_names)}"
                    assert name not in assigned_names
                signal_names[signal] = name
                assigned_names.add(name)

        for port in self.ports.keys():
            add_signal_name(port)

        # Only the domain names matter here; the driven signals are picked up from statements.
        for domain_name in self.drivers:
            if domain_name != "comb":
                domain = self.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

        for statements in self.statements.values():
            for statement in statements:
                for signal in statement._lhs_signals() | statement._rhs_signals():
                    if not isinstance(signal, (_ast.ClockSignal, _ast.ResetSignal)):
                        add_signal_name(signal)

        return signal_names

    def _assign_names_to_fragments(self, hierarchy=("top",), *, _names=None):
        """Assign names to this fragment and its subfragments.

        Subfragments may not necessarily have a name. This method assigns every such subfragment
        a name, ``U$<number>``, where ``<number>`` is based on its location in the hierarchy.

        Subfragment names may collide with signal names safely in Amaranth, but this may confuse
        backends. This method renames every subfragment whose name collides with a signal name
        to ``<name>$U$<number>``, where ``<name>`` is its original name, and ``<number>`` is
        based on its location in the hierarchy.

        Arguments
        ---------
        hierarchy : tuple of str
            Name of this fragment.

        Returns
        -------
        dict of Fragment to tuple of str
            A mapping from this fragment and its subfragments to their full hierarchical names.
        """
        if _names is None:
            _names = dict()
        _names[self] = hierarchy

        signal_names = set(self._assign_names_to_signals().values())
        for subfragment_index, (subfragment, subfragment_name) in enumerate(self.subfragments):
            if subfragment_name is None:
                subfragment_name = f"U${subfragment_index}"
            elif subfragment_name in signal_names:
                subfragment_name = f"{subfragment_name}$U${subfragment_index}"
            assert subfragment_name not in signal_names
            subfragment._assign_names_to_fragments(hierarchy=(*hierarchy, subfragment_name),
                                                   _names=_names)

        return _names
class Instance(Fragment):
    """A fragment that instantiates an external module of type ``type``.

    Positional ``args`` are ``(kind, name, value)`` triples. Keyword arguments are spelled
    ``<kind>_<name>=value``. In both forms ``kind`` is ``"a"`` (attribute), ``"p"``
    (parameter), or ``"i"``/``"o"``/``"io"`` (port, value cast with ``Value.cast``).

    Raises
    ------
    NameError
        On an unknown kind or keyword prefix.
    """

    def __init__(self, type, *args, src_loc=None, src_loc_at=0, **kwargs):
        super().__init__()

        self.type = type
        self.parameters = OrderedDict()
        self.named_ports = OrderedDict()
        self.src_loc = src_loc or tracer.get_src_loc(src_loc_at)

        port_kinds = ("i", "o", "io")

        for kind, name, value in args:
            if kind in port_kinds:
                self.named_ports[name] = (_ast.Value.cast(value), kind)
            elif kind == "p":
                self.parameters[name] = value
            elif kind == "a":
                self.attrs[name] = value
            else:
                raise NameError("Instance argument {!r} should be a tuple (kind, name, value) "
                                "where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\""
                                .format((kind, name, value)))

        for keyword, argument in kwargs.items():
            # The kind is everything before the first underscore; e.g. "io_x" -> ("io", "x").
            prefix, separator, port_name = keyword.partition("_")
            if separator and prefix in port_kinds:
                self.named_ports[port_name] = (_ast.Value.cast(argument), prefix)
            elif separator and prefix == "p":
                self.parameters[port_name] = argument
            elif separator and prefix == "a":
                self.attrs[port_name] = argument
            else:
                raise NameError("Instance keyword argument {}={!r} does not start with one of "
                                "\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\""
                                .format(keyword, argument))
############################################################################################### >:3
class NetlistDriver:
    """Collects the assignments to one signal made from one module and one domain.

    ``domain`` is ``None`` for combinational drivers and a ClockDomain otherwise.
    """

    def __init__(self, module_idx: int, signal: _ast.Signal,
                 domain: '_cd.ClockDomain | None', *, src_loc):
        self.module_idx = module_idx
        self.signal = signal
        self.domain = domain
        self.src_loc = src_loc
        self.assignments = []

    def emit_value(self, builder):
        """Return the netlist value driven onto the signal by the collected assignments.

        The fallback value is the signal's init constant for combinational drivers, and the
        signal's own (current) value for synchronous ones. A lone, unconditional assignment
        covering the full width is returned as-is; otherwise an ``AssignmentList`` cell is
        added to ``builder.netlist``.
        """
        if self.domain is not None:
            default = builder.emit_signal(self.signal)
        else:
            init_const = _ast.Const(self.signal.init, self.signal.width)
            default = builder.emit_rhs(self.module_idx, init_const)[0]

        if len(self.assignments) == 1:
            only = self.assignments[0]
            covers_everything = (only.cond == 1 and only.start == 0
                                 and len(only.value) == len(default))
            if covers_everything:
                return only.value

        width = len(default)
        assignment_list = _nir.AssignmentList(self.module_idx, default=default,
                                              assignments=self.assignments,
                                              src_loc=self.signal.src_loc)
        return builder.netlist.add_value_cell(width, assignment_list)
class NetlistEmitter:
def __init__(self, netlist: _nir.Netlist, fragment_names: 'dict[_ir.Fragment, str]'):
    """Prepare to emit into ``netlist``; ``fragment_names`` is stored for later use."""
    self.netlist = netlist
    self.fragment_names = fragment_names
    # Signal -> NetlistDriver, created on first assignment to the signal.
    self.drivers = _ast.SignalDict()
    # id(ast value) -> (emitted value, signedness, ast value); see emit_rhs.
    self.rhs_cache: dict[int, Tuple[_nir.Value, bool, _ast.Value]] = {}

    # The two maps below are collected for driver conflict diagnostics only.
    # late net -> (signal, bit index), filled by emit_signal.
    self.late_net_to_signal = {}
    # net -> src_loc of the connect() call that first drove it.
    self.connect_src_loc = {}
def emit_signal(self, signal) -> _nir.Value:
    """Return the netlist value for ``signal``, allocating a late value on first use.

    Each net of a newly allocated value is recorded in ``late_net_to_signal`` as
    ``(signal, bit)`` so that driver conflicts can be reported by name.
    """
    known_signals = self.netlist.signals
    if signal in known_signals:
        return known_signals[signal]

    late_value = self.netlist.alloc_late_value(len(signal))
    known_signals[signal] = late_value
    self.late_net_to_signal.update(
        (net, (signal, bit)) for bit, net in enumerate(late_value))
    return late_value
# Used for instance outputs and read port data, not used for actual assignments.
def emit_lhs(self, value: _ast.Value):
    """Return the netlist nets that ``value`` denotes as an assignment target.

    Supports signals, concatenations, slices and ``as_unsigned``/``as_signed`` wrappers.
    This is not the path used for ordinary statement assignments.
    """
    if isinstance(value, _ast.Signal):
        return self.emit_signal(value)

    if isinstance(value, _ast.Cat):
        nets = []
        for part in value.parts:
            nets.extend(self.emit_lhs(part))
        return _nir.Value(nets)

    if isinstance(value, _ast.Slice):
        whole = self.emit_lhs(value.value)
        return whole[value.start:value.stop]

    if isinstance(value, _ast.Operator):
        # Only sign reinterpretation is a valid assignment target.
        assert value.operator in ('u', 's')
        return self.emit_lhs(value.operands[0])

    raise TypeError # :nocov:
def extend(self, value: _nir.Value, signed: bool, width: int):
    """Widen ``value`` to at least ``width`` bits.

    Signed values are sign-extended by repeating their top net; unsigned values are
    zero-extended. Values that are already ``width`` bits or wider are returned unchanged
    (as a new ``_nir.Value``), never truncated. A zero-width value has no top net to
    repeat; since its value is 0 either way, it is zero-extended even when ``signed`` is
    true (previously this raised ``IndexError``).
    """
    nets = list(value)
    missing = width - len(nets)
    if missing > 0:
        if signed and nets:
            fill = nets[-1]
        else:
            fill = _nir.Net.from_const(0)
        nets.extend([fill] * missing)
    return _nir.Value(nets)
def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc):
    """Add an ``Operator`` cell applying ``operator`` to ``inputs``; return its output value."""
    cell = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc)
    output_width = cell.width
    return self.netlist.add_value_cell(output_width, cell)
def unify_shapes_bitwise(self,
        operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
    """Extend both operands to a common width; return ``(a, b, signed)``.

    When exactly one operand is signed, the unsigned one needs an extra bit to stay
    non-negative in the common signed representation.
    """
    len_a = len(operand_a)
    len_b = len(operand_b)
    if signed_a == signed_b:
        common_width = max(len_a, len_b)
    elif signed_a:
        # Only b is unsigned.
        common_width = max(len_a, len_b + 1)
    else:
        # Only a is unsigned.
        common_width = max(len_a + 1, len_b)
    widened_a = self.extend(operand_a, signed_a, common_width)
    widened_b = self.extend(operand_b, signed_b, common_width)
    return (widened_a, widened_b, signed_a or signed_b)
def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]:
    """Emits a RHS value, returns a tuple of (value, is_signed)

    Cells created along the way are attributed to module ``module_idx``. Results are
    memoized by object identity of ``value`` (``rhs_cache``), so a value object that was
    already emitted is not emitted again, even if ``module_idx`` differs.
    The width of the returned value is asserted to equal ``value.shape().width``.
    """
    # Memoization lookup; see where the cache is filled at the end for why the value
    # itself is stored alongside the result.
    try:
        result, signed, value = self.rhs_cache[id(value)]
        return result, signed
    except KeyError:
        pass
    if isinstance(value, _ast.Const):
        # One constant net per bit, least significant bit first.
        result = _nir.Value(
            _nir.Net.from_const((value.value >> bit) & 1)
            for bit in range(value.width)
        )
        signed = value.signed
    elif isinstance(value, _ast.Signal):
        result = self.emit_signal(value)
        signed = value.signed
    elif isinstance(value, _ast.Operator):
        if len(value.operands) == 1:
            operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0])
            if value.operator == 's':
                # Signedness reinterpretation: same nets, no cell.
                result = operand_a
                signed = True
            elif value.operator == 'u':
                result = operand_a
                signed = False
            elif value.operator == '+':
                result = operand_a
                signed = signed_a
            elif value.operator == '-':
                # Widen by one bit before negating, so that the negation of the most
                # negative signed value (or of any unsigned value) is representable.
                operand_a = self.extend(operand_a, signed_a, len(operand_a) + 1)
                result = self.emit_operator(module_idx, '-', operand_a,
                                            src_loc=value.src_loc)
                signed = True
            elif value.operator == '~':
                result = self.emit_operator(module_idx, '~', operand_a,
                                            src_loc=value.src_loc)
                signed = signed_a
            elif value.operator in ('b', 'r|', 'r&', 'r^'):
                # Boolean conversion and reductions produce an unsigned result.
                result = self.emit_operator(module_idx, value.operator, operand_a,
                                            src_loc=value.src_loc)
                signed = False
            else:
                assert False # :nocov:
        elif len(value.operands) == 2:
            operand_a, signed_a = self.emit_rhs(module_idx, value.operands[0])
            operand_b, signed_b = self.emit_rhs(module_idx, value.operands[1])
            if value.operator in ('|', '&', '^'):
                operand_a, operand_b, signed = \
                    self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
                result = self.emit_operator(module_idx, value.operator, operand_a, operand_b,
                                            src_loc=value.src_loc)
            elif value.operator in ('+', '-'):
                # Unify, then add one bit for the carry/borrow.
                operand_a, operand_b, signed = \
                    self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
                width = len(operand_a) + 1
                operand_a = self.extend(operand_a, signed, width)
                operand_b = self.extend(operand_b, signed, width)
                result = self.emit_operator(module_idx, value.operator, operand_a, operand_b,
                                            src_loc=value.src_loc)
                if value.operator == '-':
                    signed = True
            elif value.operator == '*':
                # The full product needs len(a) + len(b) bits; each operand is extended
                # according to its own signedness.
                width = len(operand_a) + len(operand_b)
                operand_a = self.extend(operand_a, signed_a, width)
                operand_b = self.extend(operand_b, signed_b, width)
                result = self.emit_operator(module_idx, '*', operand_a, operand_b,
                                            src_loc=value.src_loc)
                signed = signed_a or signed_b
            elif value.operator == '//':
                # Result width is the dividend's width, plus one bit if the divisor is signed
                # (e.g. the most negative dividend divided by -1). The division is done at
                # the unified width (widened further if needed) and then truncated.
                width = len(operand_a) + signed_b
                operand_a, operand_b, signed = \
                    self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
                if len(operand_a) < width:
                    operand_a = self.extend(operand_a, signed, width)
                    operand_b = self.extend(operand_b, signed, width)
                operator = 's//' if signed else 'u//'
                result = _nir.Value(
                    self.emit_operator(module_idx, operator, operand_a, operand_b,
                                       src_loc=value.src_loc)[:width]
                )
            elif value.operator == '%':
                # The remainder takes the divisor's width and signedness.
                width = len(operand_b)
                operand_a, operand_b, signed = \
                    self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
                operator = 's%' if signed else 'u%'
                result = _nir.Value(
                    self.emit_operator(module_idx, operator, operand_a, operand_b,
                                       src_loc=value.src_loc)[:width]
                )
                signed = signed_b
            elif value.operator == '<<':
                # Pre-widen so that the largest possible shift amount (2**len(b) - 1)
                # loses no bits.
                operand_a = self.extend(operand_a, signed_a,
                                        len(operand_a) + 2 ** len(operand_b) - 1)
                result = self.emit_operator(module_idx, '<<', operand_a, operand_b,
                                            src_loc=value.src_loc)
                signed = signed_a
            elif value.operator == '>>':
                operator = 's>>' if signed_a else 'u>>'
                result = self.emit_operator(module_idx, operator, operand_a, operand_b,
                                            src_loc=value.src_loc)
                signed = signed_a
            elif value.operator in ('==', '!='):
                operand_a, operand_b, signed = \
                    self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
                result = self.emit_operator(module_idx, value.operator, operand_a, operand_b,
                                            src_loc=value.src_loc)
                signed = False
            elif value.operator in ('<', '>', '<=', '>='):
                # Ordering comparisons pick the signed or unsigned variant after unification.
                operand_a, operand_b, signed = \
                    self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
                operator = ('s' if signed else 'u') + value.operator
                result = self.emit_operator(module_idx, operator, operand_a, operand_b,
                                            src_loc=value.src_loc)
                signed = False
            else:
                assert False # :nocov:
        elif len(value.operands) == 3:
            # Mux: operands are (select, value if true, value if false).
            assert value.operator == 'm'
            operand_s, signed_s = self.emit_rhs(module_idx, value.operands[0])
            operand_a, signed_a = self.emit_rhs(module_idx, value.operands[1])
            operand_b, signed_b = self.emit_rhs(module_idx, value.operands[2])
            if len(operand_s) != 1:
                # Reduce a multi-bit select to a single boolean bit.
                operand_s = self.emit_operator(module_idx, 'b', operand_s,
                                               src_loc=value.src_loc)
            operand_a, operand_b, signed = \
                self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
            result = self.emit_operator(module_idx, 'm', operand_s, operand_a, operand_b,
                                        src_loc=value.src_loc)
        else:
            assert False # :nocov:
    elif isinstance(value, _ast.Slice):
        inner, _signed = self.emit_rhs(module_idx, value.value)
        result = _nir.Value(inner[value.start:value.stop])
        signed = False
    elif isinstance(value, _ast.Part):
        # The inner value's signedness is passed to the Part cell; the result is unsigned.
        inner, signed = self.emit_rhs(module_idx, value.value)
        offset, _signed = self.emit_rhs(module_idx, value.offset)
        cell = _nir.Part(module_idx, value=inner, value_signed=signed, width=value.width,
                         stride=value.stride, offset=offset, src_loc=value.src_loc)
        result = self.netlist.add_value_cell(value.width, cell)
        signed = False
    elif isinstance(value, _ast.ArrayProxy):
        elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
        # Compute a common shape able to represent every element: once any element is
        # signed the result is signed, and unsigned elements then need one extra bit.
        width = 0
        signed = False
        for elem, elem_signed in elems:
            if elem_signed:
                if not signed:
                    # Switching to signed: add a bit so earlier unsigned elements still fit.
                    width += 1
                    signed = True
                width = max(width, len(elem))
            elif signed:
                width = max(width, len(elem) + 1)
            else:
                width = max(width, len(elem))
        elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems)
        index, _signed = self.emit_rhs(module_idx, value.index)
        cell = _nir.ArrayMux(module_idx, width=width, elems=elems, index=index,
                             src_loc=value.src_loc)
        result = self.netlist.add_value_cell(width, cell)
    elif isinstance(value, _ast.Cat):
        # Concatenate parts least significant first; the result is unsigned.
        nets = []
        for val in value.parts:
            inner, _signed = self.emit_rhs(module_idx, val)
            for net in inner:
                nets.append(net)
        result = _nir.Value(nets)
        signed = False
    elif isinstance(value, _ast.AnyValue):
        result = self.netlist.add_value_cell(value.width,
            _nir.AnyValue(module_idx, kind=value.kind.value, width=value.width,
                          src_loc=value.src_loc))
        signed = value.signed
    elif isinstance(value, _ast.Initial):
        result = self.netlist.add_value_cell(1, _nir.Initial(module_idx, src_loc=value.src_loc))
        signed = False
    else:
        assert False # :nocov:
    # Sanity check: the emitted width must match the AST value's shape.
    assert value.shape().width == len(result), \
        f"Value {value!r} with shape {value.shape()!r} does not match " \
        f"result with width {len(result)}"
    # Add the value itself to the cache to make sure `id(value)` remains allocated and pointing
    # at `value`. This would be a weakref.WeakKeyDictionary if `value` was hashable.
    self.rhs_cache[id(value)] = result, signed, value
    return (result, signed)
def connect(self, lhs: _nir.Value, rhs: _nir.Value, *, src_loc):
    """Drive each late net of ``lhs`` from the corresponding net of ``rhs``.

    Raises
    ------
    DriverConflict
        If a net of ``lhs`` already has a driver; the message names the signal bit and both
        source locations.
    """
    assert len(lhs) == len(rhs)
    connections = self.netlist.connections
    for target, source in zip(lhs, rhs):
        if target not in connections:
            connections[target] = source
            self.connect_src_loc[target] = src_loc
            continue
        signal, bit = self.late_net_to_signal[target]
        earlier_src_loc = self.connect_src_loc[target]
        raise _ir.DriverConflict(f"Bit {bit} of signal {signal!r} has multiple drivers: "
                                 f"{earlier_src_loc} and {src_loc}")
def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str,
              stmt: _ast.Statement, cond: _nir.Net):
    """Lower one statement of ``fragment`` belonging to ``domain`` into the netlist.

    ``cond`` is the 1-bit net under which ``stmt`` takes effect: callers pass a constant 1 at
    the top level, and ``Switch`` recursion passes each case's net.

    - ``Assign``: appends a ``_nir.Assignment`` to the target signal's ``NetlistDriver``,
      creating the driver on first use. Raises ``_ir.DriverConflict`` if the signal is already
      driven from a different domain or from a different module.
    - ``Property``: emits an ``AsyncProperty`` cell for the ``comb`` domain, or a
      ``SyncProperty`` cell clocked by the domain's clock otherwise.
    - ``Switch``: emits one ``Matches`` cell per non-empty case plus a ``PriorityMatch`` cell,
      then recurses into every case body with the corresponding output net.
    """
    # "comb" has no ClockDomain object; any other domain is looked up up front.
    cd = None if domain == "comb" else fragment.domains[domain]

    if isinstance(stmt, _ast.Assign):
        target = stmt.lhs
        if isinstance(target, _ast.Signal):
            signal, start, width = target, 0, target.width
        elif isinstance(target, _ast.Slice):
            signal, start, width = target.value, target.start, target.stop - target.start
        else:
            assert False # :nocov:
        assert isinstance(signal, _ast.Signal)

        if signal not in self.drivers:
            driver = NetlistDriver(module_idx, signal, domain=cd, src_loc=stmt.src_loc)
            self.drivers[signal] = driver
        else:
            driver = self.drivers[signal]
            # A signal may only be driven from one domain and from one module.
            if driver.domain is not cd:
                raise _ir.DriverConflict(
                    f"Signal {signal} driven from domain {cd} at {stmt.src_loc} and domain "
                    f"{driver.domain} at {driver.src_loc}")
            if driver.module_idx != module_idx:
                mod_name, other_mod_name = (
                    ".".join(self.netlist.modules[idx].name or ("<toplevel>",))
                    for idx in (module_idx, driver.module_idx))
                raise _ir.DriverConflict(
                    f"Signal {signal} driven from module {mod_name} at {stmt.src_loc} and "
                    f"module {other_mod_name} at {driver.src_loc}")

        value, signed = self.emit_rhs(module_idx, stmt.rhs)
        # Fit the right-hand side to the assigned bit range: truncate it, or extend it
        # according to its signedness.
        if len(value) > width:
            value = _nir.Value(value[:width])
        elif len(value) < width:
            value = self.extend(value, signed, width)
        driver.assignments.append(_nir.Assignment(
            cond=cond, start=start, value=value, src_loc=stmt.src_loc))

    elif isinstance(stmt, _ast.Property):
        test, _signed = self.emit_rhs(module_idx, stmt.test)
        if len(test) != 1:
            # Collapse a multi-bit test to a single bit with the 'b' operator.
            test = self.emit_operator(module_idx, 'b', test, src_loc=stmt.src_loc)
        test_net, = test
        # 1-bit enable: defaults to 0 and is assigned 1 under `cond`.
        enable_list = _nir.AssignmentList(
            module_idx,
            default=_nir.Value.zeros(),
            assignments=[
                _nir.Assignment(cond=cond, start=0, value=_nir.Value.ones(),
                                src_loc=stmt.src_loc),
            ],
            src_loc=stmt.src_loc)
        enable, = self.netlist.add_value_cell(1, enable_list)
        common = dict(kind=stmt.kind.value, test=test_net, en=enable, name=stmt.name,
                      src_loc=stmt.src_loc)
        if cd is None:
            prop_cell = _nir.AsyncProperty(module_idx, **common)
        else:
            clk, = self.emit_signal(cd.clk)
            prop_cell = _nir.SyncProperty(module_idx, clk=clk, clk_edge=cd.clk_edge, **common)
        self.netlist.add_cell(prop_cell)

    elif isinstance(stmt, _ast.Switch):
        test, _signed = self.emit_rhs(module_idx, stmt.test)
        case_nets = []
        for patterns in stmt.cases:
            if not patterns:
                # A case with no patterns always matches.
                case_nets.append(_nir.Net.from_const(1))
                continue
            assert all(len(pattern) == len(test) for pattern in patterns)
            matches = _nir.Matches(module_idx, value=test, patterns=patterns,
                                   src_loc=stmt.case_src_locs.get(patterns))
            match_net, = self.netlist.add_value_cell(1, matches)
            case_nets.append(match_net)
        priority = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(case_nets),
                                      src_loc=stmt.src_loc)
        case_conds = self.netlist.add_value_cell(len(case_nets), priority)
        for case_cond, case_body in zip(case_conds, stmt.cases.values()):
            for case_stmt in case_body:
                self.emit_stmt(module_idx, fragment, domain, case_stmt, case_cond)

    else:
        assert False # :nocov:
def emit_tribuf(self, module_idx: int, instance: _ir.Instance):
    """Lower a ``$tribuf`` instance into an ``IOBuffer`` cell of module ``module_idx``.

    Port ``Y`` becomes the pad, ``A`` the output value, and the single bit of ``EN`` the
    output enable. ``Y`` and ``A`` must have the same width.
    """
    ports = instance.named_ports
    pad = self.emit_lhs(ports["Y"][0])
    output, _ = self.emit_rhs(module_idx, ports["A"][0])
    enable, _ = self.emit_rhs(module_idx, ports["EN"][0])
    (oe,) = enable
    assert len(pad) == len(output)
    self.netlist.add_cell(
        _nir.IOBuffer(module_idx, pad=pad, o=output, oe=oe, src_loc=instance.src_loc))
def emit_memory(self, module_idx: int, fragment: '_mem.MemoryInstance', name: str):
    """Add a ``_nir.Memory`` cell describing ``fragment`` and return what ``add_cell`` returns.

    The returned value is what the port emitters receive as their ``memory`` argument.
    """
    memory_cell = _nir.Memory(
        module_idx,
        name=name,
        width=fragment._width,
        depth=fragment._depth,
        init=fragment._init,
        attributes=fragment._attrs,
        src_loc=fragment._src_loc,
    )
    return self.netlist.add_cell(memory_cell)
def emit_write_port(self, module_idx: int, fragment: '_mem.MemoryInstance',
                    port: '_mem.MemoryInstance._WritePort', memory: int):
    """Emit a ``SyncWritePort`` cell for ``port`` of ``memory`` and return ``add_cell``'s result.

    The port is clocked by the clock of ``fragment.domains[port._domain]``.
    """
    data, _ = self.emit_rhs(module_idx, port._data)
    addr, _ = self.emit_rhs(module_idx, port._addr)
    en_bits, _ = self.emit_rhs(module_idx, port._en)
    # Widen the enable to one bit per data bit: data bit `i` uses enable bit
    # `i // granularity`.
    granularity = port._granularity
    en = _nir.Value([en_bits[i // granularity] for i in range(len(port._data))])
    domain = fragment.domains[port._domain]
    clk, = self.emit_signal(domain.clk)
    write_port = _nir.SyncWritePort(
        module_idx,
        memory=memory,
        data=data,
        addr=addr,
        en=en,
        clk=clk,
        clk_edge=domain.clk_edge,
        src_loc=port._data.src_loc,
    )
    return self.netlist.add_cell(write_port)
def emit_read_port(self, module_idx: int, fragment: '_mem.MemoryInstance',
                   port: '_mem.MemoryInstance._ReadPort', memory: int,
                   write_ports: 'list[int]'):
    """Emit a read port cell for ``port`` of ``memory`` and connect its output to ``port._data``.

    A ``comb``-domain port becomes an ``AsyncReadPort``. Any other domain gives a
    ``SyncReadPort`` with a 1-bit enable, that domain's clock, and ``transparent_for``
    set to the entries of ``write_ports`` selected by ``port._transparency``.
    """
    addr, _ = self.emit_rhs(module_idx, port._addr)
    width = len(port._data)
    src_loc = port._data.src_loc
    if port._domain == "comb":
        read_port = _nir.AsyncReadPort(module_idx,
            memory=memory,
            width=width,
            addr=addr,
            src_loc=src_loc,
        )
    else:
        (en,), _ = self.emit_rhs(module_idx, port._en)
        domain = fragment.domains[port._domain]
        clk, = self.emit_signal(domain.clk)
        transparent_for = tuple(write_ports[idx] for idx in port._transparency)
        read_port = _nir.SyncReadPort(module_idx,
            memory=memory,
            width=width,
            addr=addr,
            en=en,
            clk=clk,
            clk_edge=domain.clk_edge,
            transparent_for=transparent_for,
            src_loc=src_loc,
        )
    read_data = self.netlist.add_value_cell(width, read_port)
    self.connect(self.emit_lhs(port._data), read_data, src_loc=src_loc)
def emit_instance(self, module_idx: int, instance: _ir.Instance, name: str):
    """Lower a black-box ``instance`` into a ``_nir.Instance`` cell named ``name``.

    Input ports are passed as values. Inout ports are passed as their left-hand-side nets.
    Output ports are packed back to back into the cell's outputs, and each slice is then
    connected to the signal nets the port drives.
    """
    ports_i, ports_o, ports_io = {}, {}, {}
    # (first output bit, target nets) for each output port, in declaration order.
    output_slices = []
    output_width = 0
    for port_name, (conn, direction) in instance.named_ports.items():
        if direction == 'i':
            value, _ = self.emit_rhs(module_idx, conn)
            ports_i[port_name] = value
        elif direction == 'o':
            target = self.emit_lhs(conn)
            ports_o[port_name] = (output_width, len(target))
            output_slices.append((output_width, target))
            output_width += len(target)
        elif direction == 'io':
            ports_io[port_name] = self.emit_lhs(conn)
        else:
            assert False # :nocov:
    instance_cell = _nir.Instance(module_idx,
        type=instance.type,
        name=name,
        parameters=instance.parameters,
        attributes=instance.attrs,
        ports_i=ports_i,
        ports_o=ports_o,
        ports_io=ports_io,
        src_loc=instance.src_loc,
    )
    output_nets = self.netlist.add_value_cell(width=output_width, cell=instance_cell)
    for first_bit, target in output_slices:
        driven_by = _nir.Value(output_nets[first_bit:first_bit + len(target)])
        self.connect(target, driven_by, src_loc=instance.src_loc)
def emit_top_ports(self, fragment: _ir.Fragment, signal_names: _ast.SignalDict):
    """Create the toplevel ports of the netlist for every port of ``fragment``.

    Input ('i') and inout ('io') ports are each given a run of consecutive output bits of the
    toplevel cell (cell 0), starting at bit 2, and the port's signal is bound to those nets.
    Output ('o') ports record the signal's own nets. Any other direction violates an
    internal invariant, which is checked the same way as in ``emit_instance``.
    """
    next_input_bit = 2 # 0 and 1 are reserved for constants
    top = self.netlist.top
    for signal, direction in fragment.ports.items():
        assert signal not in self.netlist.signals
        name = signal_names[signal]
        if direction == 'o':
            top.ports_o[name] = self.emit_signal(signal)
            continue
        # Inputs and inouts are allocated identically; only the port table differs.
        if direction == 'i':
            port_table = top.ports_i
        elif direction == 'io':
            port_table = top.ports_io
        else:
            assert False # :nocov:
        port_table[name] = (next_input_bit, signal.width)
        self.netlist.signals[signal] = _nir.Value(
            _nir.Net.from_cell(0, bit)
            for bit in range(next_input_bit, next_input_bit + signal.width)
        )
        next_input_bit += signal.width
def emit_drivers(self):
    """Turn every recorded ``NetlistDriver`` into logic connected to its signal.

    Each driver's value comes from ``driver.emit_value(self)``. Drivers that belong to a clock
    domain get a ``FlipFlop`` cell on that domain's clock. Its ``arst`` input is the domain
    reset only when the domain has a reset and ``async_reset`` is set; otherwise ``arst`` is a
    constant 0. The result is connected to the driven signal. Afterwards, every late signal
    net that is still unconnected is tied to the matching bit of the signal's ``init``.
    """
    for driver in self.drivers.values():
        value = driver.emit_value(self)
        domain = driver.domain
        if domain is not None:
            clk, = self.emit_signal(domain.clk)
            if domain.rst is None or not domain.async_reset:
                arst = _nir.Net.from_const(0)
            else:
                arst, = self.emit_signal(domain.rst)
            flop = _nir.FlipFlop(driver.module_idx,
                data=value,
                init=driver.signal.init,
                clk=clk,
                clk_edge=domain.clk_edge,
                arst=arst,
                attributes=driver.signal.attrs,
                src_loc=driver.signal.src_loc,
            )
            value = self.netlist.add_value_cell(len(value), flop)
        # Attribute the connection to the first assignment, or to the signal if there is none.
        src_loc = driver.assignments[0].src_loc if driver.assignments else driver.signal.src_loc
        self.connect(self.emit_signal(driver.signal), value, src_loc=src_loc)
    # Connect all undriven signal bits to their initial values. This can only happen for entirely
    # undriven signals, or signals that are partially driven by instances.
    connections = self.netlist.connections
    for signal, nets in self.netlist.signals.items():
        for bit, net in enumerate(nets):
            if net.is_late and net not in connections:
                connections[net] = _nir.Net.from_const((signal.init >> bit) & 1)
def emit_fragment(self, fragment: _ir.Fragment, parent_module_idx: 'int | None'):
    """Emit ``fragment`` into the netlist, recursing into its subfragments.

    ``Instance`` and ``MemoryInstance`` fragments become cells of the parent module. A plain
    ``Fragment`` becomes a new module, and its signals, statements and subfragments are emitted
    into it. For the toplevel fragment (``parent_module_idx is None``), the top ports are
    created first and all drivers are emitted at the end.
    """
    from . import _mem
    fragment_name = self.fragment_names[fragment]
    if isinstance(fragment, _ir.Instance):
        assert parent_module_idx is not None
        if fragment.type != "$tribuf":
            self.emit_instance(parent_module_idx, fragment, name=fragment_name[-1])
        else:
            self.emit_tribuf(parent_module_idx, fragment)
    elif isinstance(fragment, _mem.MemoryInstance):
        assert parent_module_idx is not None
        memory = self.emit_memory(parent_module_idx, fragment, name=fragment_name[-1])
        # Write ports are emitted first: read ports refer to them for transparency.
        write_ports = [self.emit_write_port(parent_module_idx, fragment, wport, memory)
                       for wport in fragment._write_ports]
        for rport in fragment._read_ports:
            self.emit_read_port(parent_module_idx, fragment, rport, memory, write_ports)
    elif type(fragment) is _ir.Fragment:
        is_toplevel = parent_module_idx is None
        module_idx = self.netlist.add_module(parent_module_idx, fragment_name)
        signal_names = fragment._assign_names_to_signals()
        self.netlist.modules[module_idx].signal_names = signal_names
        if is_toplevel:
            self.emit_top_ports(fragment, signal_names)
        for signal in signal_names:
            self.emit_signal(signal)
        always = _nir.Net.from_const(1)
        for domain, stmts in fragment.statements.items():
            for stmt in stmts:
                self.emit_stmt(module_idx, fragment, domain, stmt, always)
        for subfragment, _ in fragment.subfragments:
            self.emit_fragment(subfragment, module_idx)
        if is_toplevel:
            self.emit_drivers()
    else:
        assert False # :nocov:
def _emit_netlist(netlist: _nir.Netlist, fragment, hierarchy):
    """Emit ``fragment`` (as the toplevel) into ``netlist``, naming fragments under ``hierarchy``."""
    emitter = NetlistEmitter(netlist, fragment._assign_names_to_fragments(hierarchy))
    emitter.emit_fragment(fragment, None)
def _compute_net_flows(netlist: _nir.Netlist):
    """Populate ``module.net_flow`` for every module of ``netlist`` (mutates in place).

    Each net output by a cell starts out INTERNAL in that cell's module. Every use of the net
    (as a cell input, or as part of a signal named in a module) then routes it through the
    module hierarchy: modules on the path up from the current root of the net's subtree are
    marked OUTPUT, modules on the path up from the use are marked INPUT, and the module where
    the two paths meet becomes the new INTERNAL root.
    """
    # Computes the net flows for all modules of the netlist.
    #
    # The rules for net flows are as follows:
    #
    # - the modules that have a given net in their net_flow form a subtree of the hierarchy
    # - INTERNAL is used in the root of the subtree and nowhere else
    # - OUTPUT is used for modules that contain the definition of the net, or are on the
    #   path from the definition to the root
    # - remaining modules have a flow of INPUT (unless the net is a top-level inout port,
    #   in which case it is INOUT)
    #
    # In other words, the tree looks something like this:
    #
    # - [no flow] <<< top
    #   - [no flow]
    #     - INTERNAL
    #       - INPUT << use
    #         - [no flow]
    #       - INPUT
    #         - INPUT << use
    #       - OUTPUT
    #         - INPUT << use
    #           - [no flow]
    #         - OUTPUT << def
    #           - INPUT
    #           - INPUT
    #         - [no flow]
    #   - [no flow]
    #     - [no flow]
    #
    # This function doesn't assign the INOUT flow — that is corrected later, in compute_ports.

    # Maps each net to the module that is currently the root (INTERNAL) of its subtree.
    lca = {}
    # Initialize by marking the definition point of every net.
    for cell_idx, cell in enumerate(netlist.cells):
        for net in cell.output_nets(cell_idx):
            lca[net] = cell.module_idx
            netlist.modules[cell.module_idx].net_flow[net] = _nir.ModuleNetFlow.INTERNAL

    # Marks a use of a net within a given module, and adjusts its netflows in all modules
    # as required.
    # Module depth is measured by the length of the module's hierarchical `name`.
    def use_net(net, use_module):
        # Constant nets need no routing.
        if net.is_const:
            return
        # If the net is already present in the current module, we're done.
        if net in netlist.modules[use_module].net_flow:
            return
        modules = netlist.modules
        # Otherwise, we need to route the net through the hierarchy from def_module
        # to use_module. We do that by treating use_module and def_module as pointers
        # and moving them up the hierarchy until they meet at the new LCA.
        def_module = lca[net]
        # While def_module deeper than use_module, go up with def_module.
        while len(modules[def_module].name) > len(modules[use_module].name):
            modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT
            def_module = modules[def_module].parent
        # While use_module deeper than def_module, go up with use_module.
        # If use_module is below def_module in the hierarchy, we may hit
        # another module which already uses this net before hitting def_module,
        # so check for this case.
        while len(modules[def_module].name) < len(modules[use_module].name):
            if net in modules[use_module].net_flow:
                return
            modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT
            use_module = modules[use_module].parent
        # Now both pointers should be at the same depth within the hierarchy.
        assert len(modules[def_module].name) == len(modules[use_module].name)
        # Move both pointers up until they meet.
        while def_module != use_module:
            modules[def_module].net_flow[net] = _nir.ModuleNetFlow.OUTPUT
            def_module = modules[def_module].parent
            modules[use_module].net_flow[net] = _nir.ModuleNetFlow.INPUT
            use_module = modules[use_module].parent
            assert len(modules[def_module].name) == len(modules[use_module].name)
        # And mark the new LCA.
        modules[def_module].net_flow[net] = _nir.ModuleNetFlow.INTERNAL
        lca[net] = def_module

    # Now mark all uses and flesh out the structure.
    for cell in netlist.cells:
        for net in cell.input_nets():
            use_net(net, cell.module_idx)
    # Also count every signal named in a module as a use of that signal's nets in the module,
    # so those nets are routed into the module even if no cell there reads them.
    # NOTE(review): the original comment here was a bare "TODO: ?"; presumably this exists so
    # backends can emit named wires for such signals — confirm.
    for module_idx, module in enumerate(netlist.modules):
        for signal in module.signal_names:
            for net in netlist.signals[signal]:
                use_net(net, module_idx)
def _compute_ports(netlist: _nir.Netlist):
    """Group each module's non-INTERNAL nets into named ports (mutates modules in place).

    Within a module, consecutive output bits of the same cell that have the same flow are
    joined into one port. Joining stops at toplevel input/inout port starts and at instance
    output port starts. A port whose nets exactly match a signal named in the module (with
    a name not starting with ``$``) takes that signal's name; otherwise it is named
    ``port$<cell>$<bit>``. Nets that are INPUT in a module but belong to a toplevel inout port
    are re-flagged INOUT in ``module.net_flow``. Finally, the toplevel module's ports are
    copied from the toplevel cell's ports.
    """
    # Compute the indexes at which the outputs of a cell should be split to create a distinct port.
    # These indexes are stored here as nets.
    port_starts = set()
    for start, _width in netlist.top.ports_i.values():
        port_starts.add(_nir.Net.from_cell(0, start))
    for start, _width in netlist.top.ports_io.values():
        port_starts.add(_nir.Net.from_cell(0, start))
    for cell_idx, cell in enumerate(netlist.cells):
        if isinstance(cell, _nir.Instance):
            for start, _width in cell.ports_o.values():
                port_starts.add(_nir.Net.from_cell(cell_idx, start))

    # Compute the set of all inout nets. Currently, a net has inout flow iff it is connected to
    # a toplevel inout port.
    inouts = set()
    for start, width in netlist.top.ports_io.values():
        for idx in range(start, start + width):
            inouts.add(_nir.Net.from_cell(0, idx))

    for module in netlist.modules:
        # Collect preferred names for ports. If a port exactly matches a signal, we reuse
        # the signal name for the port. Otherwise, we synthesize a private name.
        name_table = {}
        for signal, name in module.signal_names.items():
            value = netlist.signals[signal]
            if value not in name_table and not name.startswith('$'):
                name_table[value] = name

        # Re-flag inout nets *before* grouping. The grouping loop below only joins a net with
        # its successor when both have the same flow. Converting INPUT -> INOUT lazily (as each
        # net is visited) would compare every already-converted bit of a multi-bit inout port
        # with a successor that is still INPUT. The port would then be split into 1-bit ports
        # that no longer match the signal name in `name_table`.
        for net in inouts:
            if net in module.net_flow and module.net_flow[net] == _nir.ModuleNetFlow.INPUT:
                module.net_flow[net] = _nir.ModuleNetFlow.INOUT

        # Gather together "adjacent" nets with the same flow into ports.
        visited = set()
        for net in sorted(module.net_flow):
            flow = module.net_flow[net]
            if flow == _nir.ModuleNetFlow.INTERNAL or net in visited:
                continue
            # We found a net that needs a port. Keep joining the next nets output by the same
            # cell into the same port, if applicable, but stop at instance/top port boundaries.
            nets = [net]
            while True:
                succ = _nir.Net.from_cell(net.cell, net.bit + 1)
                if succ in port_starts:
                    break
                if succ not in module.net_flow:
                    break
                if module.net_flow[succ] != module.net_flow[net]:
                    break
                net = succ
                nets.append(net)
            value = _nir.Value(nets)
            # Joined as many nets as we could, now name and add the port.
            if value in name_table:
                name = name_table[value]
            else:
                name = f"port${value[0].cell}${value[0].bit}"
            module.ports[name] = (value, flow)
            visited.update(value)

    # The 0th cell and the 0th module correspond to the toplevel. Transfer the net flows from
    # the toplevel cell (used for data flow) to the toplevel module (used to split netlist into
    # modules in the backends).
    top_module = netlist.modules[0]
    for name, (start, width) in netlist.top.ports_i.items():
        top_module.ports[name] = (
            _nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)),
            _nir.ModuleNetFlow.INPUT
        )
    for name, (start, width) in netlist.top.ports_io.items():
        top_module.ports[name] = (
            _nir.Value(_nir.Net.from_cell(0, start + bit) for bit in range(width)),
            _nir.ModuleNetFlow.INOUT
        )
    for name, value in netlist.top.ports_o.items():
        top_module.ports[name] = (value, _nir.ModuleNetFlow.OUTPUT)
def build_netlist(fragment, *, name="top"):
    """Lower ``fragment`` into a new ``_nir.Netlist`` whose toplevel module is named ``name``.

    The fragment's assignments are first legalized with ``AssignmentLegalizer``. The result is
    then emitted, its nets are resolved, and the module net flows and ports are computed.
    """
    from ._xfrm import AssignmentLegalizer
    legalized = AssignmentLegalizer()(fragment)
    netlist = _nir.Netlist()
    _emit_netlist(netlist, legalized, hierarchy=(name,))
    netlist.resolve_all_nets()
    # Net flows must be known before ports can be grouped from them.
    for analysis in (_compute_net_flows, _compute_ports):
        analysis(netlist)
    return netlist