hdl.ir: only pull explicitly specified ports to toplevel, if any.
Fixes #30.
This commit is contained in:
parent
6a77122c2e
commit
958cb18b88
169
nmigen/hdl/ir.py
169
nmigen/hdl/ir.py
|
@ -1,5 +1,6 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from collections import defaultdict, OrderedDict
|
||||
from functools import reduce
|
||||
import warnings
|
||||
import traceback
|
||||
import sys
|
||||
|
@ -340,69 +341,140 @@ class Fragment:
|
|||
|
||||
return DomainLowerer(self.domains)(self)
|
||||
|
||||
def _propagate_ports(self, ports):
|
||||
def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
|
||||
def add_uses(*sigs, self=self):
|
||||
for sig in flatten(sigs):
|
||||
if sig not in uses:
|
||||
uses[sig] = set()
|
||||
uses[sig].add(self)
|
||||
|
||||
def add_defs(*sigs):
|
||||
for sig in flatten(sigs):
|
||||
if sig not in defs:
|
||||
defs[sig] = self
|
||||
else:
|
||||
assert defs[sig] is self
|
||||
|
||||
def add_io(sig):
|
||||
assert sig not in ios
|
||||
ios[sig] = self
|
||||
|
||||
# Collect all signals we're driving (on LHS of statements), and signals we're using
|
||||
# (on RHS of statements, or in clock domains).
|
||||
if isinstance(self, Instance):
|
||||
self_driven = SignalSet()
|
||||
self_used = SignalSet()
|
||||
for port_name, (value, dir) in self.named_ports.items():
|
||||
if dir == "i":
|
||||
for signal in value._rhs_signals():
|
||||
self_used.add(signal)
|
||||
self.add_ports(signal, dir="i")
|
||||
add_uses(value._rhs_signals())
|
||||
if dir == "o":
|
||||
for signal in value._lhs_signals():
|
||||
self_driven.add(signal)
|
||||
self.add_ports(signal, dir="o")
|
||||
add_defs(value._lhs_signals())
|
||||
if dir == "io":
|
||||
self.add_ports(value, dir="io")
|
||||
add_io(value)
|
||||
else:
|
||||
self_driven = union((s._lhs_signals() for s in self.statements), start=SignalSet())
|
||||
self_used = union((s._rhs_signals() for s in self.statements), start=SignalSet())
|
||||
for stmt in self.statements:
|
||||
add_uses(stmt._rhs_signals())
|
||||
add_defs(stmt._lhs_signals())
|
||||
|
||||
for domain, _ in self.iter_sync():
|
||||
cd = self.domains[domain]
|
||||
self_used.add(cd.clk)
|
||||
add_uses(cd.clk)
|
||||
if cd.rst is not None:
|
||||
self_used.add(cd.rst)
|
||||
add_uses(cd.rst)
|
||||
|
||||
# Our input ports are all the signals we're using but not driving. This is an over-
|
||||
# approximation: some of these signals may be driven by our subfragments.
|
||||
ins = self_used - self_driven
|
||||
# Our output ports are all the signals we're asked to provide that we're driving. This is
|
||||
# an underapproximation: some of these signals may be driven by subfragments.
|
||||
outs = ports & self_driven
|
||||
|
||||
# Go through subfragments and refine our approximation for inputs.
|
||||
# Repeat for subfragments.
|
||||
for subfrag, name in self.subfragments:
|
||||
# Refine the input port approximation: if a subfragment requires a signal as an input,
|
||||
# and we aren't driving it, it has to be our input as well.
|
||||
sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=())
|
||||
ins |= sub_ins - self_driven
|
||||
parent[subfrag] = self
|
||||
level [subfrag] = level[self] + 1
|
||||
|
||||
for subfrag, name in self.subfragments:
|
||||
# Always ask subfragments to provide all signals that are our inputs.
|
||||
# If the subfragment is not driving it, it will silently ignore it.
|
||||
sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=ins | ports)
|
||||
# Refine the input port appropximation further: if any subfragment is driving a signal
|
||||
# that we currently think should be our input, it shouldn't actually be our input.
|
||||
ins -= sub_outs
|
||||
# Refine the output port approximation: if a subfragment is driving a signal,
|
||||
# and we're asked to provide it, we can provide it now.
|
||||
outs |= ports & sub_outs
|
||||
# All of our subfragments' bidirectional ports are also our bidirectional ports,
|
||||
# since these are only used for pins.
|
||||
self.add_ports(sub_inouts, dir="io")
|
||||
subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)
|
||||
|
||||
# We've computed the precise set of input and output ports.
|
||||
self.add_ports(ins, dir="i")
|
||||
self.add_ports(outs, dir="o")
|
||||
def _propagate_ports(self, ports, all_undef_as_ports):
|
||||
# Take this fragment graph:
|
||||
#
|
||||
# __ B (def: q, use: p r)
|
||||
# /
|
||||
# A (def: p, use: q r)
|
||||
# \
|
||||
# \_ C (def: r, use: p q)
|
||||
#
|
||||
# We need to consider three cases.
|
||||
# 1. Signal p requires an input port in B;
|
||||
# 2. Signal r requires an output port in C;
|
||||
# 3. Signal r requires an output port in C and an input port in B.
|
||||
#
|
||||
# Adding these ports can be in general done in three steps for each signal:
|
||||
# 1. Find the least common ancestor of all uses and defs.
|
||||
# 2. Going upwards from the single def, add output ports.
|
||||
# 3. Going upwards from all uses, add input ports.
|
||||
|
||||
return (SignalSet(self.iter_ports("i")),
|
||||
SignalSet(self.iter_ports("o")),
|
||||
SignalSet(self.iter_ports("io")))
|
||||
parent = {self: None}
|
||||
level = {self: 0}
|
||||
uses = SignalDict()
|
||||
defs = SignalDict()
|
||||
ios = SignalDict()
|
||||
self._prepare_use_def_graph(parent, level, uses, defs, ios, self)
|
||||
|
||||
def prepare(self, ports=(), ensure_sync_exists=True):
|
||||
ports = SignalSet(ports)
|
||||
if all_undef_as_ports:
|
||||
for sig in uses:
|
||||
if sig in defs:
|
||||
continue
|
||||
ports.add(sig)
|
||||
for sig in ports:
|
||||
if sig not in uses:
|
||||
uses[sig] = set()
|
||||
uses[sig].add(self)
|
||||
|
||||
@memoize
|
||||
def lca_of(fragu, fragv):
|
||||
# Normalize fragu to be deeper than fragv.
|
||||
if level[fragu] < level[fragv]:
|
||||
fragu, fragv = fragv, fragu
|
||||
# Find ancestor of fragu on the same level as fragv.
|
||||
for _ in range(level[fragu] - level[fragv]):
|
||||
fragu = parent[fragu]
|
||||
# If fragv was the ancestor of fragv, we're done.
|
||||
if fragu == fragv:
|
||||
return fragu
|
||||
# Otherwise, they are at the same level but in different branches. Step both fragu
|
||||
# and fragv until we find the common ancestor.
|
||||
while parent[fragu] != parent[fragv]:
|
||||
fragu = parent[fragu]
|
||||
fragv = parent[fragv]
|
||||
return parent[fragu]
|
||||
|
||||
for sig in uses:
|
||||
if sig in defs:
|
||||
lca = reduce(lca_of, uses[sig], defs[sig])
|
||||
|
||||
frag = defs[sig]
|
||||
while frag != lca:
|
||||
frag.add_ports(sig, dir="o")
|
||||
frag = parent[frag]
|
||||
else:
|
||||
lca = reduce(lca_of, uses[sig])
|
||||
|
||||
for frag in uses[sig]:
|
||||
if sig in defs and frag is defs[sig]:
|
||||
continue
|
||||
while frag != lca:
|
||||
frag.add_ports(sig, dir="i")
|
||||
frag = parent[frag]
|
||||
|
||||
for sig in ios:
|
||||
frag = ios[sig]
|
||||
while frag is not None:
|
||||
frag.add_ports(sig, dir="io")
|
||||
frag = parent[frag]
|
||||
|
||||
for sig in ports:
|
||||
if sig in ios:
|
||||
continue
|
||||
if sig in defs:
|
||||
self.add_ports(sig, dir="o")
|
||||
else:
|
||||
self.add_ports(sig, dir="i")
|
||||
|
||||
def prepare(self, ports=None, ensure_sync_exists=True):
|
||||
from .xfrm import SampleLowerer
|
||||
|
||||
fragment = SampleLowerer()(self)
|
||||
|
@ -410,7 +482,10 @@ class Fragment:
|
|||
fragment._resolve_hierarchy_conflicts()
|
||||
fragment = fragment._insert_domain_resets()
|
||||
fragment = fragment._lower_domain_signals()
|
||||
fragment._propagate_ports(ports)
|
||||
if ports is None:
|
||||
fragment._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
else:
|
||||
fragment._propagate_ports(ports=ports, all_undef_as_ports=False)
|
||||
return fragment
|
||||
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
f = Fragment()
|
||||
self.assertEqual(list(f.iter_ports()), [])
|
||||
|
||||
f._propagate_ports(ports=())
|
||||
f._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f.ports, SignalDict([]))
|
||||
|
||||
def test_iter_signals(self):
|
||||
|
@ -80,7 +80,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
self.s1.eq(self.c1)
|
||||
)
|
||||
|
||||
f._propagate_ports(ports=())
|
||||
f._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f.ports, SignalDict([]))
|
||||
|
||||
def test_infer_input(self):
|
||||
|
@ -89,7 +89,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
self.c1.eq(self.s1)
|
||||
)
|
||||
|
||||
f._propagate_ports(ports=())
|
||||
f._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f.ports, SignalDict([
|
||||
(self.s1, "i")
|
||||
]))
|
||||
|
@ -100,7 +100,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
self.c1.eq(self.s1)
|
||||
)
|
||||
|
||||
f._propagate_ports(ports=(self.c1,))
|
||||
f._propagate_ports(ports=(self.c1,), all_undef_as_ports=True)
|
||||
self.assertEqual(f.ports, SignalDict([
|
||||
(self.s1, "i"),
|
||||
(self.c1, "o")
|
||||
|
@ -116,7 +116,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
self.s1.eq(0)
|
||||
)
|
||||
f1.add_subfragment(f2)
|
||||
f1._propagate_ports(ports=())
|
||||
f1._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f1.ports, SignalDict())
|
||||
self.assertEqual(f2.ports, SignalDict([
|
||||
(self.s1, "o"),
|
||||
|
@ -129,7 +129,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
self.c1.eq(self.s1)
|
||||
)
|
||||
f1.add_subfragment(f2)
|
||||
f1._propagate_ports(ports=())
|
||||
f1._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f1.ports, SignalDict([
|
||||
(self.s1, "i"),
|
||||
]))
|
||||
|
@ -148,7 +148,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
)
|
||||
f1.add_subfragment(f2)
|
||||
|
||||
f1._propagate_ports(ports=(self.c2,))
|
||||
f1._propagate_ports(ports=(self.c2,), all_undef_as_ports=True)
|
||||
self.assertEqual(f1.ports, SignalDict([
|
||||
(self.c2, "o"),
|
||||
]))
|
||||
|
@ -170,7 +170,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
f3.add_driver(self.c2)
|
||||
f1.add_subfragment(f3)
|
||||
|
||||
f1._propagate_ports(ports=())
|
||||
f1._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f1.ports, SignalDict())
|
||||
|
||||
def test_output_input_sibling(self):
|
||||
|
@ -187,7 +187,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
)
|
||||
f1.add_subfragment(f3)
|
||||
|
||||
f1._propagate_ports(ports=())
|
||||
f1._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f1.ports, SignalDict())
|
||||
|
||||
def test_input_cd(self):
|
||||
|
@ -199,7 +199,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
f.add_domains(sync)
|
||||
f.add_driver(self.c1, "sync")
|
||||
|
||||
f._propagate_ports(ports=())
|
||||
f._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f.ports, SignalDict([
|
||||
(self.s1, "i"),
|
||||
(sync.clk, "i"),
|
||||
|
@ -215,7 +215,7 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
f.add_domains(sync)
|
||||
f.add_driver(self.c1, "sync")
|
||||
|
||||
f._propagate_ports(ports=())
|
||||
f._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f.ports, SignalDict([
|
||||
(self.s1, "i"),
|
||||
(sync.clk, "i"),
|
||||
|
@ -224,11 +224,10 @@ class FragmentPortsTestCase(FHDLTestCase):
|
|||
def test_inout(self):
|
||||
s = Signal()
|
||||
f1 = Fragment()
|
||||
f2 = Fragment()
|
||||
f2.add_ports(s, dir="io")
|
||||
f2 = Instance("foo", io_x=s)
|
||||
f1.add_subfragment(f2)
|
||||
|
||||
f1._propagate_ports(ports=())
|
||||
f1._propagate_ports(ports=(), all_undef_as_ports=True)
|
||||
self.assertEqual(f1.ports, SignalDict([
|
||||
(s, "io")
|
||||
]))
|
||||
|
@ -557,8 +556,14 @@ class InstanceTestCase(FHDLTestCase):
|
|||
self.assertEqual(f.ports, SignalDict([
|
||||
(clk, "i"),
|
||||
(self.rst, "i"),
|
||||
(self.stb, "o"),
|
||||
(self.datal, "o"),
|
||||
(self.datah, "o"),
|
||||
(self.pins, "io"),
|
||||
]))
|
||||
|
||||
def test_prepare_explicit_ports(self):
|
||||
self.setUp_cpu()
|
||||
f = self.inst.prepare(ports=[self.rst, self.stb])
|
||||
self.assertEqual(f.ports, SignalDict([
|
||||
(self.rst, "i"),
|
||||
(self.stb, "o"),
|
||||
(self.pins, "io"),
|
||||
]))
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import contextlib
|
||||
import functools
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
__all__ = ["flatten", "union", "log2_int", "bits_for", "deprecated"]
|
||||
__all__ = ["flatten", "union", "log2_int", "bits_for", "memoize", "deprecated"]
|
||||
|
||||
|
||||
def flatten(i):
|
||||
|
@ -46,6 +47,16 @@ def bits_for(n, require_sign_bit=False):
|
|||
return r
|
||||
|
||||
|
||||
def memoize(f):
|
||||
memo = OrderedDict()
|
||||
@functools.wraps(f)
|
||||
def g(*args):
|
||||
if args not in memo:
|
||||
memo[args] = f(*args)
|
||||
return memo[args]
|
||||
return g
|
||||
|
||||
|
||||
def deprecated(message, stacklevel=2):
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
|
|
Loading…
Reference in a new issue