amaranth/amaranth/hdl/_ir.py

680 lines
28 KiB
Python
Raw Normal View History

2021-12-09 22:39:50 -07:00
from abc import ABCMeta
from collections import defaultdict, OrderedDict
from functools import reduce
import warnings
from .. import tracer
2021-12-09 22:39:50 -07:00
from .._utils import *
from .._unused import *
from ._ast import *
from ._ast import _StatementList
from ._cd import *
2021-12-09 22:39:50 -07:00
# Public names exported by `from ._ir import *`.
__all__ = [
    "UnusedElaboratable",
    "Elaboratable",
    "DriverConflict",
    "Fragment",
    "Instance",
]
class UnusedElaboratable(UnusedMustUse):
    """Diagnostic category for elaboratables that were constructed but never used.

    The ``_MustUse__silence`` flag starts out set. If nothing that was constructed ever gets
    used, the application most likely crashed (with an exception, or in some other way that
    bypasses ``sys.excepthook``), and reporting unused elaboratables would only add noise.
    ``Fragment.get`` clears the flag as soon as elaboration starts, enabling the warning.
    """
    _MustUse__silence = True
class Elaboratable(MustUse):
    """Base class for objects that ``Fragment.get`` can elaborate into a :class:`Fragment`.

    ``Fragment.get`` marks an instance as used (``_MustUse__used``) when it elaborates it;
    ``_MustUse__warning`` selects :class:`UnusedElaboratable` as the warning category
    associated with this class.
    """
    _MustUse__warning = UnusedElaboratable
class DriverConflict(UserWarning):
    """A signal is driven from more than one fragment of the hierarchy.

    ``Fragment._resolve_hierarchy_conflicts`` raises this in ``"error"`` mode and emits it as a
    warning in ``"warn"`` mode (in which case the hierarchy is flattened to resolve it).
    """
class Fragment:
    """A unit of the design hierarchy, as produced by elaborating an elaboratable.

    Attributes
    ----------
    ports : SignalDict
        Maps each port signal to its direction, one of ``"i"``, ``"o"`` or ``"io"``.
    drivers : OrderedDict
        Maps a domain name (``None`` for combinational logic) to the :class:`SignalSet` of
        signals this fragment drives in that domain.
    statements : dict
        Maps a domain name (or ``None``) to the statements added for that domain.
    domains : OrderedDict
        Maps a domain name to its :class:`ClockDomain`.
    subfragments : list of (Fragment, str or None)
        Child fragments, each with an optional name.
    attrs : OrderedDict
        Attributes attached to this fragment.
    generated : OrderedDict
        Items that can be looked up through :meth:`find_generated`.
    flatten : bool
        If true, :meth:`_resolve_hierarchy_conflicts` always merges this fragment into its
        parent.
    """

    @staticmethod
    def get(obj, platform):
        """Elaborate ``obj`` repeatedly until a :class:`Fragment` is obtained.

        Each intermediate object's ``elaborate(platform)`` is called in turn.

        Raises
        ------
        AttributeError
            If an object in the chain is neither a fragment nor has an ``elaborate`` method.
        RecursionError
            If an object's ``elaborate`` returns that same object.
        """
        code = None
        while True:
            if isinstance(obj, Fragment):
                return obj
            elif isinstance(obj, Elaboratable):
                code = obj.elaborate.__code__
                # Elaboration has started, so from now on unused elaboratables are reported.
                UnusedElaboratable._MustUse__silence = False
                obj._MustUse__used = True
                new_obj = obj.elaborate(platform)
            elif hasattr(obj, "elaborate"):
                warnings.warn(
                    message="Class {!r} is an elaboratable that does not explicitly inherit from "
                            "Elaboratable; doing so would improve diagnostics"
                            .format(type(obj)),
                    category=RuntimeWarning,
                    stacklevel=2)
                code = obj.elaborate.__code__
                new_obj = obj.elaborate(platform)
            else:
                raise AttributeError(f"Object {obj!r} cannot be elaborated")
            if new_obj is obj:
                raise RecursionError(f"Object {obj!r} elaborates to itself")
            if new_obj is None and code is not None:
                # Point the warning at the offending `elaborate` definition.
                warnings.warn_explicit(
                    message=".elaborate() returned None; missing return statement?",
                    category=UserWarning,
                    filename=code.co_filename,
                    lineno=code.co_firstlineno)
            obj = new_obj

    def __init__(self):
        self.ports = SignalDict()
        self.drivers = OrderedDict()
        self.statements = {}
        self.domains = OrderedDict()
        self.subfragments = []
        self.attrs = OrderedDict()
        self.generated = OrderedDict()
        self.flatten = False

    def add_ports(self, *ports, dir):
        """Mark every signal in (flattened) ``ports`` as a port with direction ``dir``.

        The direction of a signal that is already a port is overwritten.
        """
        assert dir in ("i", "o", "io")
        for port in flatten(ports):
            self.ports[port] = dir

    def iter_ports(self, dir=None):
        """Yield port signals; only those with direction ``dir`` if it is given."""
        if dir is None:
            yield from self.ports
        else:
            for port, port_dir in self.ports.items():
                if port_dir == dir:
                    yield port

    def add_driver(self, signal, domain=None):
        """Record that this fragment drives ``signal`` in ``domain`` (``None`` for comb)."""
        if domain not in self.drivers:
            self.drivers[domain] = SignalSet()
        self.drivers[domain].add(signal)

    def iter_drivers(self):
        """Yield ``(domain, signal)`` for every driven signal, including combinational ones."""
        for domain, signals in self.drivers.items():
            for signal in signals:
                yield domain, signal

    def iter_comb(self):
        """Yield the signals driven combinationally (in the ``None`` domain)."""
        if None in self.drivers:
            yield from self.drivers[None]

    def iter_sync(self):
        """Yield ``(domain, signal)`` for every signal driven in a named domain."""
        for domain, signals in self.drivers.items():
            if domain is None:
                continue
            for signal in signals:
                yield domain, signal

    def iter_signals(self):
        """Return a :class:`SignalSet` of ports, driven signals, and the clock/reset signals
        of every domain that drives something."""
        signals = SignalSet()
        signals |= self.ports.keys()
        for domain, domain_signals in self.drivers.items():
            if domain is not None:
                cd = self.domains[domain]
                signals.add(cd.clk)
                if cd.rst is not None:
                    signals.add(cd.rst)
            signals |= domain_signals
        return signals

    def add_domains(self, *domains):
        """Add each :class:`ClockDomain` in (flattened) ``domains``; names must be new."""
        for domain in flatten(domains):
            assert isinstance(domain, ClockDomain)
            assert domain.name not in self.domains
            self.domains[domain.name] = domain

    def iter_domains(self):
        """Yield the names of the domains defined in this fragment."""
        yield from self.domains

    def add_statements(self, domain, *stmts):
        """Append ``stmts`` (cast to statements) to the statement list of ``domain``."""
        assert domain is None or isinstance(domain, str)
        for stmt in Statement.cast(stmts):
            stmt._MustUse__used = True
            self.statements.setdefault(domain, _StatementList()).append(stmt)

    def add_subfragment(self, subfragment, name=None):
        """Append ``subfragment`` as a child of this fragment, optionally named ``name``."""
        assert isinstance(subfragment, Fragment)
        self.subfragments.append((subfragment, name))

    def find_subfragment(self, name_or_index):
        """Return a direct subfragment, selected by position or by name.

        An ``int`` indexes :attr:`subfragments` (negative values count from the end, as with
        lists); any other value is compared with subfragment names and the first match wins.

        Raises
        ------
        NameError
            If there is no subfragment at that index or with that name.
        """
        if isinstance(name_or_index, int):
            # Check both bounds so that an out-of-range negative index reports NameError,
            # like every other missing subfragment, instead of leaking an IndexError.
            if -len(self.subfragments) <= name_or_index < len(self.subfragments):
                subfragment, name = self.subfragments[name_or_index]
                return subfragment
            raise NameError(f"No subfragment at index #{name_or_index}")
        else:
            for subfragment, name in self.subfragments:
                if name == name_or_index:
                    return subfragment
            raise NameError(f"No subfragment with name '{name_or_index}'")

    def find_generated(self, *path):
        """Look up a generated item; leading path components select subfragments."""
        if len(path) > 1:
            path_component, *path = path
            return self.find_subfragment(path_component).find_generated(*path)
        else:
            item, = path
            return self.generated[item]

    def elaborate(self, platform):
        return self

    def _merge_subfragment(self, subfragment):
        """Merge everything except clock domains of ``subfragment`` into this fragment, then
        remove it from :attr:`subfragments` (its own children become our children)."""
        # Flattening is done after clock domain propagation, so we can assume the domains
        # are already the same in every involved fragment in the first place.
        self.ports.update(subfragment.ports)
        for domain, signal in subfragment.iter_drivers():
            self.add_driver(signal, domain)
        for domain, statements in subfragment.statements.items():
            # Use the same container type as `add_statements`.
            self.statements.setdefault(domain, _StatementList()).extend(statements)
        self.subfragments += subfragment.subfragments

        # Remove the merged subfragment.
        found = False
        for i, (check_subfrag, check_name) in enumerate(self.subfragments): # :nobr:
            if subfragment == check_subfrag:
                del self.subfragments[i]
                found = True
                break
        assert found

    def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
        """Flatten subfragments until no signal is driven from more than one fragment.

        ``mode`` selects what happens on a conflict: ``"silent"`` flattens quietly, ``"warn"``
        flattens and emits :class:`DriverConflict` as a warning, ``"error"`` raises it.
        Returns the :class:`SignalSet` of signals driven by this fragment or its
        (non-instance) subfragments.
        """
        assert mode in ("silent", "warn", "error")

        from ._mem import MemoryInstance

        driver_subfrags = SignalDict()

        def add_subfrag(registry, entity, entry):
            # Because of missing domain insertion, at the point when this code runs, we have
            # a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to
            # the actual signals (because the signal itself can be driven as well); but leave
            # the unbound ones as it is, because there's no concrete signal for it yet anyway.
            if isinstance(entity, ClockSignal) and entity.domain in self.domains:
                entity = self.domains[entity.domain].clk
            elif isinstance(entity, ResetSignal) and entity.domain in self.domains:
                entity = self.domains[entity.domain].rst

            if entity not in registry:
                registry[entity] = set()
            registry[entity].add(entry)

        # For each signal driven by this fragment and/or its subfragments, determine which
        # subfragments also drive it.
        for domain, signal in self.iter_drivers():
            add_subfrag(driver_subfrags, signal, (None, hierarchy))

        flatten_subfrags = set()
        for i, (subfrag, name) in enumerate(self.subfragments):
            if name is None:
                name = f"<unnamed #{i}>"
            subfrag_hierarchy = hierarchy + (name,)

            if subfrag.flatten:
                # Always flatten subfragments that explicitly request it.
                flatten_subfrags.add((subfrag, subfrag_hierarchy))

            if isinstance(subfrag, (Instance, MemoryInstance)):
                # Never flatten instances.
                continue

            # First, recurse into subfragments and let them detect driver conflicts as well.
            subfrag_drivers = subfrag._resolve_hierarchy_conflicts(subfrag_hierarchy, mode)
            # Second, classify subfragments by signals they drive.
            for signal in subfrag_drivers:
                add_subfrag(driver_subfrags, signal, (subfrag, subfrag_hierarchy))

        # Find out the set of subfragments that needs to be flattened into this fragment
        # to resolve driver-driver conflicts.
        def flatten_subfrags_if_needed(subfrags):
            if len(subfrags) == 1:
                return []
            flatten_subfrags.update((f, h) for f, h in subfrags if f is not None)
            return sorted(".".join(h) for f, h in subfrags)

        for signal, subfrags in driver_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Signal '{}' is driven from multiple fragments: {}"
                       .format(signal, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *signal.src_loc)

        # Flatten hierarchy.
        for subfrag, _ in sorted(flatten_subfrags, key=lambda x: x[1]):
            self._merge_subfragment(subfrag)

        # If we flattened anything, we might be in a situation where we have a driver conflict
        # again, e.g. if we had a tree of fragments like A --- B --- C where only fragments
        # A and C were driving a signal S. In that case, since B is not driving S itself,
        # processing B will not result in any flattening, but since B is transitively driving S,
        # processing A will flatten B into it. Afterwards, we have a tree like AB --- C, which
        # has another conflict.
        if flatten_subfrags:
            # Try flattening again.
            return self._resolve_hierarchy_conflicts(hierarchy, mode)

        # Nothing was flattened, we're done!
        return SignalSet(driver_subfrags.keys())

    def _propagate_domains_up(self, hierarchy=("top",)):
        """Hoist non-local domains defined by subfragments into this fragment.

        A domain defined by several subfragments is renamed in each of them to
        ``<subfragment name>_<domain name>``.

        Raises
        ------
        DomainError
            If such a domain is defined by an unnamed subfragment, or by subfragments that
            share a name.
        """
        from ._xfrm import DomainRenamer

        domain_subfrags = defaultdict(set)

        # For each domain defined by a subfragment, determine which subfragments define it.
        for i, (subfrag, name) in enumerate(self.subfragments):
            # First, recurse into subfragments and let them propagate domains up as well.
            hier_name = name
            if hier_name is None:
                hier_name = f"<unnamed #{i}>"
            subfrag._propagate_domains_up(hierarchy + (hier_name,))

            # Second, classify subfragments by domains they define.
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                domain_subfrags[domain_name].add((subfrag, name, i))

        # For each domain defined by more than one subfragment, rename the domain in each
        # of the subfragments such that they no longer conflict.
        for domain_name, subfrags in domain_subfrags.items():
            if len(subfrags) == 1:
                continue

            names = [n for f, n, i in subfrags]
            if not all(names):
                names = sorted(f"<unnamed #{i}>" if n is None else f"'{n}'"
                               for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; "
                                  "it is necessary to either rename subfragment domains "
                                  "explicitly, or give names to subfragments"
                                  .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            if len(names) != len(set(names)):
                names = sorted(f"#{i}" for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', "
                                  "some of which have identical names; it is necessary to either "
                                  "rename subfragment domains explicitly, or give distinct names "
                                  "to subfragments"
                                  .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            for _, name, i in subfrags:
                # A subfragment may need renaming for several conflicting domains; rename the
                # entry currently stored at index `i` (which may already be the result of an
                # earlier rename) so that earlier renames are not discarded.
                current_subfrag, _ = self.subfragments[i]
                domain_name_map = {domain_name: f"{name}_{domain_name}"}
                self.subfragments[i] = (DomainRenamer(domain_name_map)(current_subfrag), name)

        # Finally, collect the (now unique) subfragment domains, and merge them into our domains.
        for subfrag, name in self.subfragments:
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                self.add_domains(domain)

    def _propagate_domains_down(self):
        """Ensure every domain defined here is also defined (as the same object) in every
        subfragment, recursively."""
        for subfrag, name in self.subfragments:
            for domain in self.iter_domains():
                if domain in subfrag.domains:
                    assert self.domains[domain] is subfrag.domains[domain]
                else:
                    subfrag.add_domains(self.domains[domain])

            subfrag._propagate_domains_down()

    def _create_missing_domains(self, missing_domain, *, platform=None):
        """Define every domain that is used but not defined, using ``missing_domain(name)``.

        The callback may return ``None`` (an error), a :class:`ClockDomain` (added directly),
        or an elaboratable whose fragment must define the domain (added as subfragment
        ``cd_<name>``).

        Returns
        -------
        list of ClockDomain
            The domains that were added directly; they have no logic driving them yet.

        Raises
        ------
        DomainError
            If the callback returns ``None``, or returns a fragment that does not define the
            requested domain.
        """
        from ._xfrm import DomainCollector

        collector = DomainCollector()
        collector(self)

        # Sort the missing names: iterating a set-like difference of strings follows hash
        # order, which is randomized per interpreter run, and this loop decides the order in
        # which domains and subfragments are added.
        missing_names = sorted(name
                               for name in collector.used_domains - collector.defined_domains
                               if name is not None)

        new_domains = []
        for domain_name in missing_names:
            value = missing_domain(domain_name)
            if value is None:
                raise DomainError(f"Domain '{domain_name}' is used but not defined")
            if type(value) is ClockDomain:
                self.add_domains(value)
                # And expose ports on the newly added clock domain, since it is added directly
                # and there was no chance to add any logic driving it.
                new_domains.append(value)
            else:
                new_fragment = Fragment.get(value, platform=platform)
                if domain_name not in new_fragment.domains:
                    defined = new_fragment.domains.keys()
                    raise DomainError(
                        "Fragment returned by missing domain callback does not define "
                        "requested domain '{}' (defines {})."
                        .format(domain_name, ", ".join(f"'{n}'" for n in defined)))
                self.add_subfragment(new_fragment, f"cd_{domain_name}")
                self.add_domains(new_fragment.domains.values())
        return new_domains

    def _propagate_domains(self, missing_domain, *, platform=None):
        """Propagate domains up and down the hierarchy, resolve driver conflicts, and create
        missing domains; returns the directly added domains (see
        :meth:`_create_missing_domains`)."""
        self._propagate_domains_up()
        self._propagate_domains_down()
        self._resolve_hierarchy_conflicts()
        new_domains = self._create_missing_domains(missing_domain, platform=platform)
        self._propagate_domains_down()
        return new_domains

    def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
        """Fill the use/def maps for this fragment and, recursively, its subfragments.

        ``uses`` maps a signal to the set of fragments reading it, ``defs`` to the single
        fragment driving it, ``ios`` to the single fragment with it as a bidirectional port.
        ``parent`` and ``level`` record each visited subfragment's parent and depth. Ports of
        :class:`Instance` and ``MemoryInstance`` subfragments are added here as well.
        """
        from ._mem import MemoryInstance

        def add_uses(*sigs, self=self):
            for sig in flatten(sigs):
                if sig not in uses:
                    uses[sig] = set()
                uses[sig].add(self)

        def add_defs(*sigs):
            for sig in flatten(sigs):
                if sig not in defs:
                    defs[sig] = self
                else:
                    assert defs[sig] is self

        def add_io(*sigs):
            for sig in flatten(sigs):
                if sig not in ios:
                    ios[sig] = self
                else:
                    assert ios[sig] is self

        # Collect all signals we're driving (on LHS of statements), and signals we're using
        # (on RHS of statements, or in clock domains).
        for stmts in self.statements.values():
            for stmt in stmts:
                add_uses(stmt._rhs_signals())
                add_defs(stmt._lhs_signals())

        for domain, _ in self.iter_sync():
            cd = self.domains[domain]
            add_uses(cd.clk)
            if cd.rst is not None:
                add_uses(cd.rst)

        # Repeat for subfragments.
        for subfrag, name in self.subfragments:
            if isinstance(subfrag, Instance):
                for port_name, (value, dir) in subfrag.named_ports.items():
                    if dir == "i":
                        # Prioritize defs over uses.
                        rhs_without_outputs = value._rhs_signals() - subfrag.iter_ports(dir="o")
                        subfrag.add_ports(rhs_without_outputs, dir=dir)
                        add_uses(value._rhs_signals())
                    if dir == "o":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_defs(value._lhs_signals())
                    if dir == "io":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_io(value._lhs_signals())
            elif isinstance(subfrag, MemoryInstance):
                for port in subfrag._read_ports:
                    subfrag.add_ports(port._data._lhs_signals(), dir="o")
                    add_defs(port._data._lhs_signals())

                    for value in [port._addr, port._en]:
                        subfrag.add_ports(value._rhs_signals(), dir="i")
                        add_uses(value._rhs_signals())
                for port in subfrag._write_ports:
                    for value in [port._addr, port._en, port._data]:
                        subfrag.add_ports(value._rhs_signals(), dir="i")
                        add_uses(value._rhs_signals())
                for domain, _ in subfrag.iter_sync():
                    cd = subfrag.domains[domain]
                    add_uses(cd.clk)
                    if cd.rst is not None:
                        add_uses(cd.rst)
            else:
                parent[subfrag] = self
                level[subfrag] = level[self] + 1

                subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)

    def _propagate_ports(self, ports, all_undef_as_ports):
        """Add the ports each fragment needs so every signal can flow from its driver to its
        users; ``ports`` (and, if ``all_undef_as_ports``, every used-but-undriven signal)
        become ports of this top fragment."""
        # Take this fragment graph:
        #
        #    __ B (def: q, use: p r)
        #   /
        #  A (def: p, use: q r)
        #   \
        #    \_ C (def: r, use: p q)
        #
        # We need to consider three cases.
        #   1. Signal p requires an input port in B;
        #   2. Signal r requires an output port in C;
        #   3. Signal r requires an output port in C and an input port in B.
        #
        # Adding these ports can be in general done in three steps for each signal:
        #   1. Find the least common ancestor of all uses and defs.
        #   2. Going upwards from the single def, add output ports.
        #   3. Going upwards from all uses, add input ports.

        parent = {self: None}
        level = {self: 0}
        uses = SignalDict()
        defs = SignalDict()
        ios = SignalDict()
        self._prepare_use_def_graph(parent, level, uses, defs, ios, self)

        ports = SignalSet(ports)
        if all_undef_as_ports:
            for sig in uses:
                if sig in defs:
                    continue
                ports.add(sig)
        for sig in ports:
            if sig not in uses:
                uses[sig] = set()
            uses[sig].add(self)

        @memoize
        def lca_of(fragu, fragv):
            # Normalize fragu to be deeper than fragv.
            if level[fragu] < level[fragv]:
                fragu, fragv = fragv, fragu
            # Find ancestor of fragu on the same level as fragv.
            for _ in range(level[fragu] - level[fragv]):
                fragu = parent[fragu]
            # If fragv was the ancestor of fragu, we're done.
            if fragu == fragv:
                return fragu
            # Otherwise, they are at the same level but in different branches. Step both fragu
            # and fragv until we find the common ancestor.
            while parent[fragu] != parent[fragv]:
                fragu = parent[fragu]
                fragv = parent[fragv]
            return parent[fragu]

        for sig in uses:
            if sig in defs:
                lca = reduce(lca_of, uses[sig], defs[sig])
            else:
                lca = reduce(lca_of, uses[sig])

            for frag in uses[sig]:
                if sig in defs and frag is defs[sig]:
                    continue
                while frag != lca:
                    frag.add_ports(sig, dir="i")
                    frag = parent[frag]

            if sig in defs:
                frag = defs[sig]
                while frag != lca:
                    frag.add_ports(sig, dir="o")
                    frag = parent[frag]

        for sig in ios:
            frag = ios[sig]
            while frag is not None:
                frag.add_ports(sig, dir="io")
                frag = parent[frag]

        for sig in ports:
            if sig in ios:
                continue
            if sig in defs:
                self.add_ports(sig, dir="o")
            else:
                self.add_ports(sig, dir="i")

    def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)):
        """Propagate domains, lower domain references, and add ports; return the new fragment.

        Parameters
        ----------
        ports : list or tuple of Signal, ClockSignal or ResetSignal, or None
            Signals to expose as top-level ports. If ``None``, every signal that is used but
            not driven becomes a port.
        missing_domain : callable
            Called with the name of each domain that is used but not defined (see
            :meth:`_create_missing_domains`). By default a plain :class:`ClockDomain` is made.

        Raises
        ------
        TypeError
            If ``ports`` is not a list or tuple, or contains something other than signals.
        """
        from ._xfrm import DomainLowerer

        new_domains = self._propagate_domains(missing_domain)
        fragment = DomainLowerer()(self)
        if ports is None:
            fragment._propagate_ports(ports=(), all_undef_as_ports=True)
        else:
            if not isinstance(ports, (tuple, list)):
                msg = "`ports` must be either a list or a tuple, not {!r}"\
                      .format(ports)
                if isinstance(ports, Value):
                    msg += " (did you mean `ports=(<signal>,)`, rather than `ports=<signal>`?)"
                raise TypeError(msg)
            mapped_ports = []
            # Lower late bound signals like ClockSignal() to ports.
            port_lowerer = DomainLowerer(fragment.domains)
            for port in ports:
                if not isinstance(port, (Signal, ClockSignal, ResetSignal)):
                    raise TypeError("Only signals may be added as ports, not {!r}"
                                    .format(port))
                mapped_ports.append(port_lowerer.on_value(port))
            # Add ports for all newly created missing clock domains, since not doing so defeats
            # the purpose of domain auto-creation. (It's possible to refer to these ports before
            # the domain actually exists through late binding, but it's inconvenient.)
            for cd in new_domains:
                mapped_ports.append(cd.clk)
                if cd.rst is not None:
                    mapped_ports.append(cd.rst)
            fragment._propagate_ports(ports=mapped_ports, all_undef_as_ports=False)
        return fragment

    def _assign_names_to_signals(self):
        """Assign names to signals used in this fragment.

        Returns
        -------
        SignalDict of Signal to str
            A mapping from signals used in this fragment to their local names. Because names are
            deduplicated using local information only, the same signal used in a different fragment
            may get a different name.
        """

        signal_names = SignalDict()
        assigned_names = set()

        def add_signal_name(signal):
            if signal not in signal_names:
                if signal.name not in assigned_names:
                    name = signal.name
                else:
                    # Start from the number of names assigned so far; if a signal is literally
                    # named like the candidate (e.g. "x$2"), keep counting until it is unique.
                    suffix = len(assigned_names)
                    name = f"{signal.name}${suffix}"
                    while name in assigned_names:
                        suffix += 1
                        name = f"{signal.name}${suffix}"
                signal_names[signal] = name
                assigned_names.add(name)

        for port in self.ports.keys():
            add_signal_name(port)

        for domain_name, domain_signals in self.drivers.items():
            if domain_name is not None:
                domain = self.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

        for statements in self.statements.values():
            for statement in statements:
                for signal in statement._lhs_signals() | statement._rhs_signals():
                    if not isinstance(signal, (ClockSignal, ResetSignal)):
                        add_signal_name(signal)

        return signal_names

    def _assign_names_to_fragments(self, hierarchy=("top",), *, _names=None):
        """Assign names to this fragment and its subfragments.

        Subfragments may not necessarily have a name. This method assigns every such subfragment
        a name, ``U$<number>``, where ``<number>`` is based on its location in the hierarchy.

        Subfragment names may collide with signal names safely in Amaranth, but this may confuse
        backends. This method renames every such subfragment to ``<name>$U$<number>``, where
        ``<name>`` is its original name, and ``<number>`` is based on its location in the hierarchy.

        Arguments
        ---------
        hierarchy : tuple of str
            Name of this fragment.

        Returns
        -------
        dict of Fragment to tuple of str
            A mapping from this fragment and its subfragments to their full hierarchical names.
        """
        if _names is None:
            _names = dict()
        _names[self] = hierarchy

        signal_names = set(self._assign_names_to_signals().values())
        for subfragment_index, (subfragment, subfragment_name) in enumerate(self.subfragments):
            if subfragment_name is None:
                subfragment_name = f"U${subfragment_index}"
            elif subfragment_name in signal_names:
                subfragment_name = f"{subfragment_name}$U${subfragment_index}"
            assert subfragment_name not in signal_names
            subfragment._assign_names_to_fragments(hierarchy=(*hierarchy, subfragment_name),
                                                   _names=_names)

        return _names
class Instance(Fragment):
    """A fragment that instantiates a cell or module of type ``type`` by name.

    Attributes, parameters and ports are given either positionally, as
    ``(kind, name, value)`` tuples where ``kind`` is ``"a"`` (attribute), ``"p"``
    (parameter) or ``"i"``/``"o"``/``"io"`` (port), or as keyword arguments whose prefix
    (``a_``, ``p_``, ``i_``, ``o_`` or ``io_``) selects the kind. Port values are cast with
    :meth:`Value.cast` and stored as ``(value, direction)`` in :attr:`named_ports`.

    Raises
    ------
    NameError
        If a positional kind or keyword prefix is not one of those listed above.
    """
    def __init__(self, type, *args, src_loc=None, src_loc_at=0, **kwargs):
        super().__init__()

        self.type = type
        self.parameters = OrderedDict()
        self.named_ports = OrderedDict()
        # Must be called directly from __init__ so that `src_loc_at` counts frames correctly.
        self.src_loc = src_loc or tracer.get_src_loc(src_loc_at)

        for kind, name, value in args:
            if kind in ("i", "o", "io"):
                self.named_ports[name] = (Value.cast(value), kind)
            elif kind == "p":
                self.parameters[name] = value
            elif kind == "a":
                self.attrs[name] = value
            else:
                raise NameError("Instance argument {!r} should be a tuple (kind, name, value) "
                                "where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\""
                                .format((kind, name, value)))

        for kw, arg in kwargs.items():
            # "io_x" -> ("io", "_", "x"); a keyword without any "_" yields an empty separator.
            prefix, separator, name = kw.partition("_")
            if separator and prefix == "a":
                self.attrs[name] = arg
            elif separator and prefix == "p":
                self.parameters[name] = arg
            elif separator and prefix in ("i", "o", "io"):
                self.named_ports[name] = (Value.cast(arg), prefix)
            else:
                raise NameError("Instance keyword argument {}={!r} does not start with one of "
                                "\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\""
                                .format(kw, arg))