hdl.ir: only pull explicitly specified ports to toplevel, if any.

Fixes #30.
This commit is contained in:
whitequark 2019-05-12 05:21:23 +00:00
parent 6a77122c2e
commit 958cb18b88
3 changed files with 156 additions and 65 deletions

View file

@ -1,5 +1,6 @@
from abc import ABCMeta, abstractmethod
from collections import defaultdict, OrderedDict
from functools import reduce
import warnings
import traceback
import sys
@ -340,69 +341,140 @@ class Fragment:
return DomainLowerer(self.domains)(self)
def _propagate_ports(self, ports):
def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
def add_uses(*sigs, self=self):
for sig in flatten(sigs):
if sig not in uses:
uses[sig] = set()
uses[sig].add(self)
def add_defs(*sigs):
for sig in flatten(sigs):
if sig not in defs:
defs[sig] = self
else:
assert defs[sig] is self
def add_io(sig):
assert sig not in ios
ios[sig] = self
# Collect all signals we're driving (on LHS of statements), and signals we're using
# (on RHS of statements, or in clock domains).
if isinstance(self, Instance):
self_driven = SignalSet()
self_used = SignalSet()
for port_name, (value, dir) in self.named_ports.items():
if dir == "i":
for signal in value._rhs_signals():
self_used.add(signal)
self.add_ports(signal, dir="i")
add_uses(value._rhs_signals())
if dir == "o":
for signal in value._lhs_signals():
self_driven.add(signal)
self.add_ports(signal, dir="o")
add_defs(value._lhs_signals())
if dir == "io":
self.add_ports(value, dir="io")
add_io(value)
else:
self_driven = union((s._lhs_signals() for s in self.statements), start=SignalSet())
self_used = union((s._rhs_signals() for s in self.statements), start=SignalSet())
for stmt in self.statements:
add_uses(stmt._rhs_signals())
add_defs(stmt._lhs_signals())
for domain, _ in self.iter_sync():
cd = self.domains[domain]
self_used.add(cd.clk)
add_uses(cd.clk)
if cd.rst is not None:
self_used.add(cd.rst)
add_uses(cd.rst)
# Our input ports are all the signals we're using but not driving. This is an over-
# approximation: some of these signals may be driven by our subfragments.
ins = self_used - self_driven
# Our output ports are all the signals we're asked to provide that we're driving. This is
# an underapproximation: some of these signals may be driven by subfragments.
outs = ports & self_driven
# Go through subfragments and refine our approximation for inputs.
# Repeat for subfragments.
for subfrag, name in self.subfragments:
# Refine the input port approximation: if a subfragment requires a signal as an input,
# and we aren't driving it, it has to be our input as well.
sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=())
ins |= sub_ins - self_driven
parent[subfrag] = self
level [subfrag] = level[self] + 1
for subfrag, name in self.subfragments:
# Always ask subfragments to provide all signals that are our inputs.
# If the subfragment is not driving it, it will silently ignore it.
sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=ins | ports)
# Refine the input port appropximation further: if any subfragment is driving a signal
# that we currently think should be our input, it shouldn't actually be our input.
ins -= sub_outs
# Refine the output port approximation: if a subfragment is driving a signal,
# and we're asked to provide it, we can provide it now.
outs |= ports & sub_outs
# All of our subfragments' bidirectional ports are also our bidirectional ports,
# since these are only used for pins.
self.add_ports(sub_inouts, dir="io")
subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)
# We've computed the precise set of input and output ports.
self.add_ports(ins, dir="i")
self.add_ports(outs, dir="o")
def _propagate_ports(self, ports, all_undef_as_ports):
# Take this fragment graph:
#
# __ B (def: q, use: p r)
# /
# A (def: p, use: q r)
# \
# \_ C (def: r, use: p q)
#
# We need to consider three cases.
# 1. Signal p requires an input port in B;
# 2. Signal r requires an output port in C;
# 3. Signal r requires an output port in C and an input port in B.
#
# Adding these ports can be in general done in three steps for each signal:
# 1. Find the least common ancestor of all uses and defs.
# 2. Going upwards from the single def, add output ports.
# 3. Going upwards from all uses, add input ports.
return (SignalSet(self.iter_ports("i")),
SignalSet(self.iter_ports("o")),
SignalSet(self.iter_ports("io")))
parent = {self: None}
level = {self: 0}
uses = SignalDict()
defs = SignalDict()
ios = SignalDict()
self._prepare_use_def_graph(parent, level, uses, defs, ios, self)
def prepare(self, ports=(), ensure_sync_exists=True):
ports = SignalSet(ports)
if all_undef_as_ports:
for sig in uses:
if sig in defs:
continue
ports.add(sig)
for sig in ports:
if sig not in uses:
uses[sig] = set()
uses[sig].add(self)
@memoize
def lca_of(fragu, fragv):
# Normalize fragu to be deeper than fragv.
if level[fragu] < level[fragv]:
fragu, fragv = fragv, fragu
# Find ancestor of fragu on the same level as fragv.
for _ in range(level[fragu] - level[fragv]):
fragu = parent[fragu]
# If fragv was the ancestor of fragv, we're done.
if fragu == fragv:
return fragu
# Otherwise, they are at the same level but in different branches. Step both fragu
# and fragv until we find the common ancestor.
while parent[fragu] != parent[fragv]:
fragu = parent[fragu]
fragv = parent[fragv]
return parent[fragu]
for sig in uses:
if sig in defs:
lca = reduce(lca_of, uses[sig], defs[sig])
frag = defs[sig]
while frag != lca:
frag.add_ports(sig, dir="o")
frag = parent[frag]
else:
lca = reduce(lca_of, uses[sig])
for frag in uses[sig]:
if sig in defs and frag is defs[sig]:
continue
while frag != lca:
frag.add_ports(sig, dir="i")
frag = parent[frag]
for sig in ios:
frag = ios[sig]
while frag is not None:
frag.add_ports(sig, dir="io")
frag = parent[frag]
for sig in ports:
if sig in ios:
continue
if sig in defs:
self.add_ports(sig, dir="o")
else:
self.add_ports(sig, dir="i")
def prepare(self, ports=None, ensure_sync_exists=True):
from .xfrm import SampleLowerer
fragment = SampleLowerer()(self)
@ -410,7 +482,10 @@ class Fragment:
fragment._resolve_hierarchy_conflicts()
fragment = fragment._insert_domain_resets()
fragment = fragment._lower_domain_signals()
fragment._propagate_ports(ports)
if ports is None:
fragment._propagate_ports(ports=(), all_undef_as_ports=True)
else:
fragment._propagate_ports(ports=ports, all_undef_as_ports=False)
return fragment

View file

@ -65,7 +65,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f = Fragment()
self.assertEqual(list(f.iter_ports()), [])
f._propagate_ports(ports=())
f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([]))
def test_iter_signals(self):
@ -80,7 +80,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.s1.eq(self.c1)
)
f._propagate_ports(ports=())
f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([]))
def test_infer_input(self):
@ -89,7 +89,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.c1.eq(self.s1)
)
f._propagate_ports(ports=())
f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i")
]))
@ -100,7 +100,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.c1.eq(self.s1)
)
f._propagate_ports(ports=(self.c1,))
f._propagate_ports(ports=(self.c1,), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i"),
(self.c1, "o")
@ -116,7 +116,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.s1.eq(0)
)
f1.add_subfragment(f2)
f1._propagate_ports(ports=())
f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict())
self.assertEqual(f2.ports, SignalDict([
(self.s1, "o"),
@ -129,7 +129,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.c1.eq(self.s1)
)
f1.add_subfragment(f2)
f1._propagate_ports(ports=())
f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([
(self.s1, "i"),
]))
@ -148,7 +148,7 @@ class FragmentPortsTestCase(FHDLTestCase):
)
f1.add_subfragment(f2)
f1._propagate_ports(ports=(self.c2,))
f1._propagate_ports(ports=(self.c2,), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([
(self.c2, "o"),
]))
@ -170,7 +170,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f3.add_driver(self.c2)
f1.add_subfragment(f3)
f1._propagate_ports(ports=())
f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict())
def test_output_input_sibling(self):
@ -187,7 +187,7 @@ class FragmentPortsTestCase(FHDLTestCase):
)
f1.add_subfragment(f3)
f1._propagate_ports(ports=())
f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict())
def test_input_cd(self):
@ -199,7 +199,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f.add_domains(sync)
f.add_driver(self.c1, "sync")
f._propagate_ports(ports=())
f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i"),
(sync.clk, "i"),
@ -215,7 +215,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f.add_domains(sync)
f.add_driver(self.c1, "sync")
f._propagate_ports(ports=())
f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i"),
(sync.clk, "i"),
@ -224,11 +224,10 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_inout(self):
s = Signal()
f1 = Fragment()
f2 = Fragment()
f2.add_ports(s, dir="io")
f2 = Instance("foo", io_x=s)
f1.add_subfragment(f2)
f1._propagate_ports(ports=())
f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([
(s, "io")
]))
@ -557,8 +556,14 @@ class InstanceTestCase(FHDLTestCase):
self.assertEqual(f.ports, SignalDict([
(clk, "i"),
(self.rst, "i"),
(self.stb, "o"),
(self.datal, "o"),
(self.datah, "o"),
(self.pins, "io"),
]))
def test_prepare_explicit_ports(self):
self.setUp_cpu()
f = self.inst.prepare(ports=[self.rst, self.stb])
self.assertEqual(f.ports, SignalDict([
(self.rst, "i"),
(self.stb, "o"),
(self.pins, "io"),
]))

View file

@ -1,11 +1,12 @@
import contextlib
import functools
import warnings
from collections import OrderedDict
from collections.abc import Iterable
from contextlib import contextmanager
__all__ = ["flatten", "union", "log2_int", "bits_for", "deprecated"]
__all__ = ["flatten", "union", "log2_int", "bits_for", "memoize", "deprecated"]
def flatten(i):
@ -46,6 +47,16 @@ def bits_for(n, require_sign_bit=False):
return r
def memoize(f):
memo = OrderedDict()
@functools.wraps(f)
def g(*args):
if args not in memo:
memo[args] = f(*args)
return memo[args]
return g
def deprecated(message, stacklevel=2):
def decorator(f):
@functools.wraps(f)