hdl.ir: associate statements with domains.

Fixes #1079.
This commit is contained in:
Wanda 2024-02-09 01:53:45 +01:00 committed by Catherine
parent 09854fa775
commit 6e06fc013f
12 changed files with 313 additions and 196 deletions

View file

@ -667,6 +667,7 @@ class _StatementCompiler(_xfrm.StatementVisitor):
self.rhs_compiler = rhs_compiler
self.lhs_compiler = lhs_compiler
self._domain = None
self._case = None
self._test_cache = {}
self._has_rhs = False
@ -865,8 +866,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# Register all signals driven in the current fragment. This must be done first, as it
# affects further codegen; e.g. whether \sig$next signals will be generated and used.
for domain, signal in fragment.iter_drivers():
compiler_state.add_driven(signal, sync=domain is not None)
for domain, statements in fragment.statements.items():
for signal in statements._lhs_signals():
compiler_state.add_driven(signal, sync=domain is not None)
# Transform all signals used as ports in the current fragment eagerly and outside of
# any hierarchy, to make sure they get sensible (non-prefixed) names.
@ -925,32 +927,32 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# Therefore, we translate the fragment as many times as there are independent groups
# of signals (a group is a transitive closure of signals that appear together on LHS),
# splitting them into many RTLIL (and thus Verilog) processes.
lhs_grouper = _xfrm.LHSGroupAnalyzer()
lhs_grouper.on_statements(fragment.statements)
for domain, statements in fragment.statements.items():
lhs_grouper = _xfrm.LHSGroupAnalyzer()
lhs_grouper.on_statements(statements)
for group, group_signals in lhs_grouper.groups().items():
lhs_group_filter = _xfrm.LHSGroupFilter(group_signals)
group_stmts = lhs_group_filter(fragment.statements)
for group, group_signals in lhs_grouper.groups().items():
lhs_group_filter = _xfrm.LHSGroupFilter(group_signals)
group_stmts = lhs_group_filter(statements)
with module.process(name=f"$group_{group}") as process:
with process.case() as case:
# For every signal in comb domain, assign \sig$next to the reset value.
# For every signal in sync domains, assign \sig$next to the current
# value (\sig).
for domain, signal in fragment.iter_drivers():
if signal not in group_signals:
continue
if domain is None:
prev_value = _ast.Const(signal.reset, signal.width)
else:
prev_value = signal
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
with module.process(name=f"$group_{group}") as process:
with process.case() as case:
# For every signal in comb domain, assign \sig$next to the reset value.
# For every signal in sync domains, assign \sig$next to the current
# value (\sig).
for signal in group_signals:
if domain is None:
prev_value = _ast.Const(signal.reset, signal.width)
else:
prev_value = signal
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
# Convert statements into decision trees.
stmt_compiler._case = case
stmt_compiler._has_rhs = False
stmt_compiler._wrap_assign = False
stmt_compiler(group_stmts)
# Convert statements into decision trees.
stmt_compiler._domain = domain
stmt_compiler._case = case
stmt_compiler._has_rhs = False
stmt_compiler._wrap_assign = False
stmt_compiler(group_stmts)
# For every driven signal in the sync domain, create a flop of appropriate type. Which type
# is appropriate depends on the domain: for domains with sync reset, it is a $dff, for
@ -998,8 +1000,8 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# to drive it to reset value arbitrarily) or to replace them with their reset value (which
# removes valuable source location information).
driven = _ast.SignalSet()
for domain, signals in fragment.iter_drivers():
driven.update(flatten(signal._lhs_signals() for signal in signals))
for domain, statements in fragment.statements.items():
driven.update(statements._lhs_signals())
driven.update(fragment.iter_ports(dir="i"))
driven.update(fragment.iter_ports(dir="io"))
for subfragment, sub_name in fragment.subfragments:

View file

@ -1718,6 +1718,12 @@ class _StatementList(list):
def __repr__(self):
return "({})".format(" ".join(map(repr, self)))
def _lhs_signals(self):
return union((s._lhs_signals() for s in self), start=SignalSet())
def _rhs_signals(self):
return union((s._rhs_signals() for s in self), start=SignalSet())
class Statement:
def __init__(self, *, src_loc_at=0):
@ -1849,13 +1855,10 @@ class Switch(Statement):
self.case_src_locs[new_keys] = case_src_locs[orig_keys]
def _lhs_signals(self):
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
start=SignalSet())
return signals
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())
def _rhs_signals(self):
signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss),
start=SignalSet())
signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet())
return self.test._rhs_signals() | signals
def __repr__(self):

View file

@ -170,7 +170,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self.submodules = _ModuleBuilderSubmodules(self)
self.domains = _ModuleBuilderDomainSet(self)
self._statements = Statement.cast([])
self._statements = {}
self._ctrl_context = None
self._ctrl_stack = []
@ -234,7 +234,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
"src_locs": [],
})
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self.domain._depth += 1
yield
self._flush_ctrl()
@ -254,7 +254,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if if_data is None or if_data["depth"] != self.domain._depth:
raise SyntaxError("Elif without preceding If")
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self.domain._depth += 1
yield
self._flush_ctrl()
@ -273,7 +273,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if if_data is None or if_data["depth"] != self.domain._depth:
raise SyntaxError("Else without preceding If/Elif")
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self.domain._depth += 1
yield
self._flush_ctrl()
@ -341,7 +341,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
continue
new_patterns = (*new_patterns, pattern.value)
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
yield
self._flush_ctrl()
@ -364,7 +364,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
warnings.warn("A case defined after the default case will never be active",
SyntaxWarning, stacklevel=3)
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
yield
self._flush_ctrl()
@ -416,7 +416,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if name not in fsm_data["encoding"]:
fsm_data["encoding"][name] = len(fsm_data["encoding"])
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
yield
self._flush_ctrl()
@ -453,28 +453,42 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if_tests, if_bodies = data["tests"], data["bodies"]
if_src_locs = data["src_locs"]
tests, cases = [], OrderedDict()
for if_test, if_case in zip(if_tests + [None], if_bodies):
if if_test is not None:
if len(if_test) != 1:
if_test = if_test.bool()
tests.append(if_test)
domains = set()
for if_case in if_bodies:
domains |= set(if_case)
if if_test is not None:
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
else:
match = None
cases[match] = if_case
for domain in domains:
tests, cases = [], OrderedDict()
for if_test, if_case in zip(if_tests + [None], if_bodies):
if if_test is not None:
if len(if_test) != 1:
if_test = if_test.bool()
tests.append(if_test)
self._statements.append(Switch(Cat(tests), cases,
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
if if_test is not None:
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
else:
match = None
cases[match] = if_case.get(domain, [])
self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases,
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
if name == "Switch":
switch_test, switch_cases = data["test"], data["cases"]
switch_case_src_locs = data["case_src_locs"]
self._statements.append(Switch(switch_test, switch_cases,
src_loc=src_loc, case_src_locs=switch_case_src_locs))
domains = set()
for stmts in switch_cases.values():
domains |= set(stmts)
for domain in domains:
domain_cases = OrderedDict()
for patterns, stmts in switch_cases.items():
domain_cases[patterns] = stmts.get(domain, [])
self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases,
src_loc=src_loc, case_src_locs=switch_case_src_locs))
if name == "FSM":
fsm_signal, fsm_reset, fsm_encoding, fsm_decoding, fsm_states = \
@ -490,10 +504,20 @@ class Module(_ModuleBuilderRoot, Elaboratable):
# The FSM is encoded such that the state with encoding 0 is always the reset state.
fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
self._statements.append(Switch(fsm_signal,
OrderedDict((fsm_encoding[name], stmts) for name, stmts in fsm_states.items()),
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
for name in fsm_states}))
domains = set()
for stmts in fsm_states.values():
domains |= set(stmts)
for domain in domains:
domain_states = OrderedDict()
for state, stmts in fsm_states.items():
domain_states[state] = stmts.get(domain, [])
self._statements.setdefault(domain, []).append(Switch(fsm_signal,
OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()),
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
for name in fsm_states}))
def _add_statement(self, assigns, domain, depth):
def domain_name(domain):
@ -523,7 +547,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
"already driven from d.{}"
.format(signal, domain_name(domain), domain_name(cd_curr)))
self._statements.append(stmt)
self._statements.setdefault(domain, []).append(stmt)
def _add_submodule(self, submodule, name=None):
if not hasattr(submodule, "elaborate"):
@ -559,7 +583,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
fragment.add_subfragment(Fragment.get(self._named_submodules[name], platform), name)
for submodule in self._anon_submodules:
fragment.add_subfragment(Fragment.get(submodule, platform), None)
fragment.add_statements(self._statements)
for domain, statements in self._statements.items():
fragment.add_statements(domain, statements)
for signal, domain in self._driving.items():
fragment.add_driver(signal, domain)
fragment.add_domains(self._domains.values())

View file

@ -7,6 +7,7 @@ from .. import tracer
from .._utils import *
from .._unused import *
from ._ast import *
from ._ast import _StatementList
from ._cd import *
@ -65,7 +66,7 @@ class Fragment:
def __init__(self):
self.ports = SignalDict()
self.drivers = OrderedDict()
self.statements = []
self.statements = {}
self.domains = OrderedDict()
self.subfragments = []
self.attrs = OrderedDict()
@ -127,10 +128,11 @@ class Fragment:
def iter_domains(self):
yield from self.domains
def add_statements(self, *stmts):
def add_statements(self, domain, *stmts):
assert domain is None or isinstance(domain, str)
for stmt in Statement.cast(stmts):
stmt._MustUse__used = True
self.statements.append(stmt)
self.statements.setdefault(domain, _StatementList()).append(stmt)
def add_subfragment(self, subfragment, name=None):
assert isinstance(subfragment, Fragment)
@ -166,7 +168,8 @@ class Fragment:
self.ports.update(subfragment.ports)
for domain, signal in subfragment.iter_drivers():
self.add_driver(signal, domain)
self.statements += subfragment.statements
for domain, statements in subfragment.statements.items():
self.statements.setdefault(domain, []).extend(statements)
self.subfragments += subfragment.subfragments
# Remove the merged subfragment.
@ -387,9 +390,10 @@ class Fragment:
# Collect all signals we're driving (on LHS of statements), and signals we're using
# (on RHS of statements, or in clock domains).
for stmt in self.statements:
add_uses(stmt._rhs_signals())
add_defs(stmt._lhs_signals())
for stmts in self.statements.values():
for stmt in stmts:
add_uses(stmt._rhs_signals())
add_defs(stmt._lhs_signals())
for domain, _ in self.iter_sync():
cd = self.domains[domain]
@ -572,10 +576,11 @@ class Fragment:
if domain.rst is not None:
add_signal_name(domain.rst)
for statement in self.statements:
for signal in statement._lhs_signals() | statement._rhs_signals():
if not isinstance(signal, (ClockSignal, ResetSignal)):
add_signal_name(signal)
for statements in self.statements.values():
for statement in statements:
for signal in statement._lhs_signals() | statement._rhs_signals():
if not isinstance(signal, (ClockSignal, ResetSignal)):
add_signal_name(signal)
return signal_names

View file

@ -124,7 +124,7 @@ class Memory(Elaboratable):
port._MustUse__used = True
if port.domain == "comb":
# Asynchronous port
f.add_statements(port.data.eq(self._array[port.addr]))
f.add_statements(None, port.data.eq(self._array[port.addr]))
f.add_driver(port.data)
else:
# Synchronous port
@ -143,6 +143,7 @@ class Memory(Elaboratable):
cond = write_port.en & (port.addr == write_port.addr)
data = Mux(cond, write_port.data, data)
f.add_statements(
port.domain,
Switch(port.en, {
1: port.data.eq(data)
})
@ -155,10 +156,10 @@ class Memory(Elaboratable):
offset = index * port.granularity
bits = slice(offset, offset + port.granularity)
write_data = self._array[port.addr][bits].eq(port.data[bits])
f.add_statements(Switch(en_bit, { 1: write_data }))
f.add_statements(port.domain, Switch(en_bit, { 1: write_data }))
else:
write_data = self._array[port.addr].eq(port.data)
f.add_statements(Switch(port.en, { 1: write_data }))
f.add_statements(port.domain, Switch(port.en, { 1: write_data }))
for signal in self._array:
f.add_driver(signal, port.domain)
return f

View file

@ -228,9 +228,11 @@ class FragmentTransformer:
def map_statements(self, fragment, new_fragment):
if hasattr(self, "on_statement"):
new_fragment.add_statements(map(self.on_statement, fragment.statements))
for domain, statements in fragment.statements.items():
new_fragment.add_statements(domain, map(self.on_statement, statements))
else:
new_fragment.add_statements(fragment.statements)
for domain, statements in fragment.statements.items():
new_fragment.add_statements(domain, statements)
def map_drivers(self, fragment, new_fragment):
for domain, signal in fragment.iter_drivers():
@ -397,9 +399,9 @@ class DomainCollector(ValueVisitor, StatementVisitor):
else:
self.defined_domains.add(domain_name)
self.on_statements(fragment.statements)
for domain_name in fragment.drivers:
for domain_name, statements in fragment.statements.items():
self._add_used_domain(domain_name)
self.on_statements(statements)
for subfragment, name in fragment.subfragments:
self.on_fragment(subfragment)
@ -442,6 +444,13 @@ class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer)
assert cd.name == self.domain_map[domain]
new_fragment.add_domains(cd)
def map_statements(self, fragment, new_fragment):
for domain, statements in fragment.statements.items():
new_fragment.add_statements(
self.domain_map.get(domain, domain),
map(self.on_statement, statements)
)
def map_drivers(self, fragment, new_fragment):
for domain, signals in fragment.drivers.items():
if domain in self.domain_map:
@ -499,7 +508,7 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
continue
stmts = [signal.eq(Const(signal.reset, signal.width))
for signal in signals if not signal.reset_less]
fragment.add_statements(Switch(domain.rst, {1: stmts}))
fragment.add_statements(domain_name, Switch(domain.rst, {1: stmts}))
def on_fragment(self, fragment):
self.domains = fragment.domains
@ -571,6 +580,7 @@ class LHSGroupAnalyzer(StatementVisitor):
self.on_statements(case_stmts)
def on_statements(self, stmts):
assert not isinstance(stmts, str)
for stmt in stmts:
self.on_statement(stmt)
@ -624,13 +634,13 @@ class _ControlInserter(FragmentTransformer):
class ResetInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals):
stmts = [s.eq(Const(s.reset, s.width)) for s in signals if not s.reset_less]
fragment.add_statements(Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))
fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))
class EnableInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals):
stmts = [s.eq(s) for s in signals]
fragment.add_statements(Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc))
fragment.add_statements(domain, Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc))
def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment)

View file

@ -5,7 +5,7 @@ import sys
from ..hdl import *
from ..hdl._ast import SignalSet
from ..hdl._xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
from ..hdl._xfrm import ValueVisitor, StatementVisitor
from ._base import BaseProcess
@ -409,9 +409,9 @@ class _FragmentCompiler:
def __call__(self, fragment):
processes = set()
for domain_name, domain_signals in fragment.drivers.items():
domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
for domain_name, domain_stmts in fragment.statements.items():
domain_process = PyRTLProcess(is_comb=domain_name is None)
domain_signals = domain_stmts._lhs_signals()
emitter = _PythonEmitter()
emitter.append(f"def run():")