hdl.ir: only pull explicitly specified ports to toplevel, if any.

Fixes #30.
This commit is contained in:
whitequark 2019-05-12 05:21:23 +00:00
parent 6a77122c2e
commit 958cb18b88
3 changed files with 156 additions and 65 deletions

View file

@ -1,5 +1,6 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from functools import reduce
import warnings import warnings
import traceback import traceback
import sys import sys
@ -340,69 +341,140 @@ class Fragment:
return DomainLowerer(self.domains)(self) return DomainLowerer(self.domains)(self)
def _propagate_ports(self, ports): def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
def add_uses(*sigs, self=self):
for sig in flatten(sigs):
if sig not in uses:
uses[sig] = set()
uses[sig].add(self)
def add_defs(*sigs):
for sig in flatten(sigs):
if sig not in defs:
defs[sig] = self
else:
assert defs[sig] is self
def add_io(sig):
assert sig not in ios
ios[sig] = self
# Collect all signals we're driving (on LHS of statements), and signals we're using # Collect all signals we're driving (on LHS of statements), and signals we're using
# (on RHS of statements, or in clock domains). # (on RHS of statements, or in clock domains).
if isinstance(self, Instance): if isinstance(self, Instance):
self_driven = SignalSet()
self_used = SignalSet()
for port_name, (value, dir) in self.named_ports.items(): for port_name, (value, dir) in self.named_ports.items():
if dir == "i": if dir == "i":
for signal in value._rhs_signals(): add_uses(value._rhs_signals())
self_used.add(signal)
self.add_ports(signal, dir="i")
if dir == "o": if dir == "o":
for signal in value._lhs_signals(): add_defs(value._lhs_signals())
self_driven.add(signal)
self.add_ports(signal, dir="o")
if dir == "io": if dir == "io":
self.add_ports(value, dir="io") add_io(value)
else: else:
self_driven = union((s._lhs_signals() for s in self.statements), start=SignalSet()) for stmt in self.statements:
self_used = union((s._rhs_signals() for s in self.statements), start=SignalSet()) add_uses(stmt._rhs_signals())
add_defs(stmt._lhs_signals())
for domain, _ in self.iter_sync(): for domain, _ in self.iter_sync():
cd = self.domains[domain] cd = self.domains[domain]
self_used.add(cd.clk) add_uses(cd.clk)
if cd.rst is not None: if cd.rst is not None:
self_used.add(cd.rst) add_uses(cd.rst)
# Our input ports are all the signals we're using but not driving. This is an over- # Repeat for subfragments.
# approximation: some of these signals may be driven by our subfragments.
ins = self_used - self_driven
# Our output ports are all the signals we're asked to provide that we're driving. This is
# an underapproximation: some of these signals may be driven by subfragments.
outs = ports & self_driven
# Go through subfragments and refine our approximation for inputs.
for subfrag, name in self.subfragments: for subfrag, name in self.subfragments:
# Refine the input port approximation: if a subfragment requires a signal as an input, parent[subfrag] = self
# and we aren't driving it, it has to be our input as well. level [subfrag] = level[self] + 1
sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=())
ins |= sub_ins - self_driven
for subfrag, name in self.subfragments: subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)
# Always ask subfragments to provide all signals that are our inputs.
# If the subfragment is not driving it, it will silently ignore it.
sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=ins | ports)
# Refine the input port appropximation further: if any subfragment is driving a signal
# that we currently think should be our input, it shouldn't actually be our input.
ins -= sub_outs
# Refine the output port approximation: if a subfragment is driving a signal,
# and we're asked to provide it, we can provide it now.
outs |= ports & sub_outs
# All of our subfragments' bidirectional ports are also our bidirectional ports,
# since these are only used for pins.
self.add_ports(sub_inouts, dir="io")
# We've computed the precise set of input and output ports. def _propagate_ports(self, ports, all_undef_as_ports):
self.add_ports(ins, dir="i") # Take this fragment graph:
self.add_ports(outs, dir="o") #
# __ B (def: q, use: p r)
# /
# A (def: p, use: q r)
# \
# \_ C (def: r, use: p q)
#
# We need to consider three cases.
# 1. Signal p requires an input port in B;
# 2. Signal r requires an output port in C;
# 3. Signal r requires an output port in C and an input port in B.
#
# Adding these ports can be in general done in three steps for each signal:
# 1. Find the least common ancestor of all uses and defs.
# 2. Going upwards from the single def, add output ports.
# 3. Going upwards from all uses, add input ports.
return (SignalSet(self.iter_ports("i")), parent = {self: None}
SignalSet(self.iter_ports("o")), level = {self: 0}
SignalSet(self.iter_ports("io"))) uses = SignalDict()
defs = SignalDict()
ios = SignalDict()
self._prepare_use_def_graph(parent, level, uses, defs, ios, self)
def prepare(self, ports=(), ensure_sync_exists=True): ports = SignalSet(ports)
if all_undef_as_ports:
for sig in uses:
if sig in defs:
continue
ports.add(sig)
for sig in ports:
if sig not in uses:
uses[sig] = set()
uses[sig].add(self)
@memoize
def lca_of(fragu, fragv):
# Normalize fragu to be deeper than fragv.
if level[fragu] < level[fragv]:
fragu, fragv = fragv, fragu
# Find ancestor of fragu on the same level as fragv.
for _ in range(level[fragu] - level[fragv]):
fragu = parent[fragu]
# If fragv was the ancestor of fragv, we're done.
if fragu == fragv:
return fragu
# Otherwise, they are at the same level but in different branches. Step both fragu
# and fragv until we find the common ancestor.
while parent[fragu] != parent[fragv]:
fragu = parent[fragu]
fragv = parent[fragv]
return parent[fragu]
for sig in uses:
if sig in defs:
lca = reduce(lca_of, uses[sig], defs[sig])
frag = defs[sig]
while frag != lca:
frag.add_ports(sig, dir="o")
frag = parent[frag]
else:
lca = reduce(lca_of, uses[sig])
for frag in uses[sig]:
if sig in defs and frag is defs[sig]:
continue
while frag != lca:
frag.add_ports(sig, dir="i")
frag = parent[frag]
for sig in ios:
frag = ios[sig]
while frag is not None:
frag.add_ports(sig, dir="io")
frag = parent[frag]
for sig in ports:
if sig in ios:
continue
if sig in defs:
self.add_ports(sig, dir="o")
else:
self.add_ports(sig, dir="i")
def prepare(self, ports=None, ensure_sync_exists=True):
from .xfrm import SampleLowerer from .xfrm import SampleLowerer
fragment = SampleLowerer()(self) fragment = SampleLowerer()(self)
@ -410,7 +482,10 @@ class Fragment:
fragment._resolve_hierarchy_conflicts() fragment._resolve_hierarchy_conflicts()
fragment = fragment._insert_domain_resets() fragment = fragment._insert_domain_resets()
fragment = fragment._lower_domain_signals() fragment = fragment._lower_domain_signals()
fragment._propagate_ports(ports) if ports is None:
fragment._propagate_ports(ports=(), all_undef_as_ports=True)
else:
fragment._propagate_ports(ports=ports, all_undef_as_ports=False)
return fragment return fragment

View file

@ -65,7 +65,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f = Fragment() f = Fragment()
self.assertEqual(list(f.iter_ports()), []) self.assertEqual(list(f.iter_ports()), [])
f._propagate_ports(ports=()) f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([])) self.assertEqual(f.ports, SignalDict([]))
def test_iter_signals(self): def test_iter_signals(self):
@ -80,7 +80,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.s1.eq(self.c1) self.s1.eq(self.c1)
) )
f._propagate_ports(ports=()) f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([])) self.assertEqual(f.ports, SignalDict([]))
def test_infer_input(self): def test_infer_input(self):
@ -89,7 +89,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f._propagate_ports(ports=()) f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i") (self.s1, "i")
])) ]))
@ -100,7 +100,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f._propagate_ports(ports=(self.c1,)) f._propagate_ports(ports=(self.c1,), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
(self.c1, "o") (self.c1, "o")
@ -116,7 +116,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.s1.eq(0) self.s1.eq(0)
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1._propagate_ports(ports=()) f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict()) self.assertEqual(f1.ports, SignalDict())
self.assertEqual(f2.ports, SignalDict([ self.assertEqual(f2.ports, SignalDict([
(self.s1, "o"), (self.s1, "o"),
@ -129,7 +129,7 @@ class FragmentPortsTestCase(FHDLTestCase):
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1._propagate_ports(ports=()) f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([ self.assertEqual(f1.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
])) ]))
@ -148,7 +148,7 @@ class FragmentPortsTestCase(FHDLTestCase):
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1._propagate_ports(ports=(self.c2,)) f1._propagate_ports(ports=(self.c2,), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([ self.assertEqual(f1.ports, SignalDict([
(self.c2, "o"), (self.c2, "o"),
])) ]))
@ -170,7 +170,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f3.add_driver(self.c2) f3.add_driver(self.c2)
f1.add_subfragment(f3) f1.add_subfragment(f3)
f1._propagate_ports(ports=()) f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict()) self.assertEqual(f1.ports, SignalDict())
def test_output_input_sibling(self): def test_output_input_sibling(self):
@ -187,7 +187,7 @@ class FragmentPortsTestCase(FHDLTestCase):
) )
f1.add_subfragment(f3) f1.add_subfragment(f3)
f1._propagate_ports(ports=()) f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict()) self.assertEqual(f1.ports, SignalDict())
def test_input_cd(self): def test_input_cd(self):
@ -199,7 +199,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f.add_domains(sync) f.add_domains(sync)
f.add_driver(self.c1, "sync") f.add_driver(self.c1, "sync")
f._propagate_ports(ports=()) f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
(sync.clk, "i"), (sync.clk, "i"),
@ -215,7 +215,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f.add_domains(sync) f.add_domains(sync)
f.add_driver(self.c1, "sync") f.add_driver(self.c1, "sync")
f._propagate_ports(ports=()) f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
(sync.clk, "i"), (sync.clk, "i"),
@ -224,11 +224,10 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_inout(self): def test_inout(self):
s = Signal() s = Signal()
f1 = Fragment() f1 = Fragment()
f2 = Fragment() f2 = Instance("foo", io_x=s)
f2.add_ports(s, dir="io")
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1._propagate_ports(ports=()) f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([ self.assertEqual(f1.ports, SignalDict([
(s, "io") (s, "io")
])) ]))
@ -557,8 +556,14 @@ class InstanceTestCase(FHDLTestCase):
self.assertEqual(f.ports, SignalDict([ self.assertEqual(f.ports, SignalDict([
(clk, "i"), (clk, "i"),
(self.rst, "i"), (self.rst, "i"),
(self.stb, "o"), (self.pins, "io"),
(self.datal, "o"), ]))
(self.datah, "o"),
def test_prepare_explicit_ports(self):
self.setUp_cpu()
f = self.inst.prepare(ports=[self.rst, self.stb])
self.assertEqual(f.ports, SignalDict([
(self.rst, "i"),
(self.stb, "o"),
(self.pins, "io"), (self.pins, "io"),
])) ]))

View file

@ -1,11 +1,12 @@
import contextlib import contextlib
import functools import functools
import warnings import warnings
from collections import OrderedDict
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import contextmanager from contextlib import contextmanager
__all__ = ["flatten", "union", "log2_int", "bits_for", "deprecated"] __all__ = ["flatten", "union", "log2_int", "bits_for", "memoize", "deprecated"]
def flatten(i): def flatten(i):
@ -46,6 +47,16 @@ def bits_for(n, require_sign_bit=False):
return r return r
def memoize(f):
memo = OrderedDict()
@functools.wraps(f)
def g(*args):
if args not in memo:
memo[args] = f(*args)
return memo[args]
return g
def deprecated(message, stacklevel=2): def deprecated(message, stacklevel=2):
def decorator(f): def decorator(f):
@functools.wraps(f) @functools.wraps(f)