hdl.ir: associate statements with domains.

Fixes #1079.
This commit is contained in:
Wanda 2024-02-09 01:53:45 +01:00 committed by Catherine
parent 09854fa775
commit 6e06fc013f
12 changed files with 313 additions and 196 deletions

View file

@ -667,6 +667,7 @@ class _StatementCompiler(_xfrm.StatementVisitor):
self.rhs_compiler = rhs_compiler self.rhs_compiler = rhs_compiler
self.lhs_compiler = lhs_compiler self.lhs_compiler = lhs_compiler
self._domain = None
self._case = None self._case = None
self._test_cache = {} self._test_cache = {}
self._has_rhs = False self._has_rhs = False
@ -865,8 +866,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# Register all signals driven in the current fragment. This must be done first, as it # Register all signals driven in the current fragment. This must be done first, as it
# affects further codegen; e.g. whether \sig$next signals will be generated and used. # affects further codegen; e.g. whether \sig$next signals will be generated and used.
for domain, signal in fragment.iter_drivers(): for domain, statements in fragment.statements.items():
compiler_state.add_driven(signal, sync=domain is not None) for signal in statements._lhs_signals():
compiler_state.add_driven(signal, sync=domain is not None)
# Transform all signals used as ports in the current fragment eagerly and outside of # Transform all signals used as ports in the current fragment eagerly and outside of
# any hierarchy, to make sure they get sensible (non-prefixed) names. # any hierarchy, to make sure they get sensible (non-prefixed) names.
@ -925,32 +927,32 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# Therefore, we translate the fragment as many times as there are independent groups # Therefore, we translate the fragment as many times as there are independent groups
# of signals (a group is a transitive closure of signals that appear together on LHS), # of signals (a group is a transitive closure of signals that appear together on LHS),
# splitting them into many RTLIL (and thus Verilog) processes. # splitting them into many RTLIL (and thus Verilog) processes.
lhs_grouper = _xfrm.LHSGroupAnalyzer() for domain, statements in fragment.statements.items():
lhs_grouper.on_statements(fragment.statements) lhs_grouper = _xfrm.LHSGroupAnalyzer()
lhs_grouper.on_statements(statements)
for group, group_signals in lhs_grouper.groups().items(): for group, group_signals in lhs_grouper.groups().items():
lhs_group_filter = _xfrm.LHSGroupFilter(group_signals) lhs_group_filter = _xfrm.LHSGroupFilter(group_signals)
group_stmts = lhs_group_filter(fragment.statements) group_stmts = lhs_group_filter(statements)
with module.process(name=f"$group_{group}") as process: with module.process(name=f"$group_{group}") as process:
with process.case() as case: with process.case() as case:
# For every signal in comb domain, assign \sig$next to the reset value. # For every signal in comb domain, assign \sig$next to the reset value.
# For every signal in sync domains, assign \sig$next to the current # For every signal in sync domains, assign \sig$next to the current
# value (\sig). # value (\sig).
for domain, signal in fragment.iter_drivers(): for signal in group_signals:
if signal not in group_signals: if domain is None:
continue prev_value = _ast.Const(signal.reset, signal.width)
if domain is None: else:
prev_value = _ast.Const(signal.reset, signal.width) prev_value = signal
else: case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
prev_value = signal
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
# Convert statements into decision trees. # Convert statements into decision trees.
stmt_compiler._case = case stmt_compiler._domain = domain
stmt_compiler._has_rhs = False stmt_compiler._case = case
stmt_compiler._wrap_assign = False stmt_compiler._has_rhs = False
stmt_compiler(group_stmts) stmt_compiler._wrap_assign = False
stmt_compiler(group_stmts)
# For every driven signal in the sync domain, create a flop of appropriate type. Which type # For every driven signal in the sync domain, create a flop of appropriate type. Which type
# is appropriate depends on the domain: for domains with sync reset, it is a $dff, for # is appropriate depends on the domain: for domains with sync reset, it is a $dff, for
@ -998,8 +1000,8 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# to drive it to reset value arbitrarily) or to replace them with their reset value (which # to drive it to reset value arbitrarily) or to replace them with their reset value (which
# removes valuable source location information). # removes valuable source location information).
driven = _ast.SignalSet() driven = _ast.SignalSet()
for domain, signals in fragment.iter_drivers(): for domain, statements in fragment.statements.items():
driven.update(flatten(signal._lhs_signals() for signal in signals)) driven.update(statements._lhs_signals())
driven.update(fragment.iter_ports(dir="i")) driven.update(fragment.iter_ports(dir="i"))
driven.update(fragment.iter_ports(dir="io")) driven.update(fragment.iter_ports(dir="io"))
for subfragment, sub_name in fragment.subfragments: for subfragment, sub_name in fragment.subfragments:

View file

@ -1718,6 +1718,12 @@ class _StatementList(list):
def __repr__(self): def __repr__(self):
return "({})".format(" ".join(map(repr, self))) return "({})".format(" ".join(map(repr, self)))
def _lhs_signals(self):
return union((s._lhs_signals() for s in self), start=SignalSet())
def _rhs_signals(self):
return union((s._rhs_signals() for s in self), start=SignalSet())
class Statement: class Statement:
def __init__(self, *, src_loc_at=0): def __init__(self, *, src_loc_at=0):
@ -1849,13 +1855,10 @@ class Switch(Statement):
self.case_src_locs[new_keys] = case_src_locs[orig_keys] self.case_src_locs[new_keys] = case_src_locs[orig_keys]
def _lhs_signals(self): def _lhs_signals(self):
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss), return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())
start=SignalSet())
return signals
def _rhs_signals(self): def _rhs_signals(self):
signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss), signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet())
start=SignalSet())
return self.test._rhs_signals() | signals return self.test._rhs_signals() | signals
def __repr__(self): def __repr__(self):

View file

@ -170,7 +170,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self.submodules = _ModuleBuilderSubmodules(self) self.submodules = _ModuleBuilderSubmodules(self)
self.domains = _ModuleBuilderDomainSet(self) self.domains = _ModuleBuilderDomainSet(self)
self._statements = Statement.cast([]) self._statements = {}
self._ctrl_context = None self._ctrl_context = None
self._ctrl_stack = [] self._ctrl_stack = []
@ -234,7 +234,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
"src_locs": [], "src_locs": [],
}) })
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, {}
self.domain._depth += 1 self.domain._depth += 1
yield yield
self._flush_ctrl() self._flush_ctrl()
@ -254,7 +254,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if if_data is None or if_data["depth"] != self.domain._depth: if if_data is None or if_data["depth"] != self.domain._depth:
raise SyntaxError("Elif without preceding If") raise SyntaxError("Elif without preceding If")
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, {}
self.domain._depth += 1 self.domain._depth += 1
yield yield
self._flush_ctrl() self._flush_ctrl()
@ -273,7 +273,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if if_data is None or if_data["depth"] != self.domain._depth: if if_data is None or if_data["depth"] != self.domain._depth:
raise SyntaxError("Else without preceding If/Elif") raise SyntaxError("Else without preceding If/Elif")
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, {}
self.domain._depth += 1 self.domain._depth += 1
yield yield
self._flush_ctrl() self._flush_ctrl()
@ -341,7 +341,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
continue continue
new_patterns = (*new_patterns, pattern.value) new_patterns = (*new_patterns, pattern.value)
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, {}
self._ctrl_context = None self._ctrl_context = None
yield yield
self._flush_ctrl() self._flush_ctrl()
@ -364,7 +364,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
warnings.warn("A case defined after the default case will never be active", warnings.warn("A case defined after the default case will never be active",
SyntaxWarning, stacklevel=3) SyntaxWarning, stacklevel=3)
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, {}
self._ctrl_context = None self._ctrl_context = None
yield yield
self._flush_ctrl() self._flush_ctrl()
@ -416,7 +416,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if name not in fsm_data["encoding"]: if name not in fsm_data["encoding"]:
fsm_data["encoding"][name] = len(fsm_data["encoding"]) fsm_data["encoding"][name] = len(fsm_data["encoding"])
try: try:
_outer_case, self._statements = self._statements, [] _outer_case, self._statements = self._statements, {}
self._ctrl_context = None self._ctrl_context = None
yield yield
self._flush_ctrl() self._flush_ctrl()
@ -453,28 +453,42 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if_tests, if_bodies = data["tests"], data["bodies"] if_tests, if_bodies = data["tests"], data["bodies"]
if_src_locs = data["src_locs"] if_src_locs = data["src_locs"]
tests, cases = [], OrderedDict() domains = set()
for if_test, if_case in zip(if_tests + [None], if_bodies): for if_case in if_bodies:
if if_test is not None: domains |= set(if_case)
if len(if_test) != 1:
if_test = if_test.bool()
tests.append(if_test)
if if_test is not None: for domain in domains:
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-") tests, cases = [], OrderedDict()
else: for if_test, if_case in zip(if_tests + [None], if_bodies):
match = None if if_test is not None:
cases[match] = if_case if len(if_test) != 1:
if_test = if_test.bool()
tests.append(if_test)
self._statements.append(Switch(Cat(tests), cases, if if_test is not None:
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs)))) match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
else:
match = None
cases[match] = if_case.get(domain, [])
self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases,
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
if name == "Switch": if name == "Switch":
switch_test, switch_cases = data["test"], data["cases"] switch_test, switch_cases = data["test"], data["cases"]
switch_case_src_locs = data["case_src_locs"] switch_case_src_locs = data["case_src_locs"]
self._statements.append(Switch(switch_test, switch_cases, domains = set()
src_loc=src_loc, case_src_locs=switch_case_src_locs)) for stmts in switch_cases.values():
domains |= set(stmts)
for domain in domains:
domain_cases = OrderedDict()
for patterns, stmts in switch_cases.items():
domain_cases[patterns] = stmts.get(domain, [])
self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases,
src_loc=src_loc, case_src_locs=switch_case_src_locs))
if name == "FSM": if name == "FSM":
fsm_signal, fsm_reset, fsm_encoding, fsm_decoding, fsm_states = \ fsm_signal, fsm_reset, fsm_encoding, fsm_decoding, fsm_states = \
@ -490,10 +504,20 @@ class Module(_ModuleBuilderRoot, Elaboratable):
# The FSM is encoded such that the state with encoding 0 is always the reset state. # The FSM is encoded such that the state with encoding 0 is always the reset state.
fsm_decoding.update((n, s) for s, n in fsm_encoding.items()) fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}" fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
self._statements.append(Switch(fsm_signal,
OrderedDict((fsm_encoding[name], stmts) for name, stmts in fsm_states.items()), domains = set()
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name] for stmts in fsm_states.values():
for name in fsm_states})) domains |= set(stmts)
for domain in domains:
domain_states = OrderedDict()
for state, stmts in fsm_states.items():
domain_states[state] = stmts.get(domain, [])
self._statements.setdefault(domain, []).append(Switch(fsm_signal,
OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()),
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
for name in fsm_states}))
def _add_statement(self, assigns, domain, depth): def _add_statement(self, assigns, domain, depth):
def domain_name(domain): def domain_name(domain):
@ -523,7 +547,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
"already driven from d.{}" "already driven from d.{}"
.format(signal, domain_name(domain), domain_name(cd_curr))) .format(signal, domain_name(domain), domain_name(cd_curr)))
self._statements.append(stmt) self._statements.setdefault(domain, []).append(stmt)
def _add_submodule(self, submodule, name=None): def _add_submodule(self, submodule, name=None):
if not hasattr(submodule, "elaborate"): if not hasattr(submodule, "elaborate"):
@ -559,7 +583,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
fragment.add_subfragment(Fragment.get(self._named_submodules[name], platform), name) fragment.add_subfragment(Fragment.get(self._named_submodules[name], platform), name)
for submodule in self._anon_submodules: for submodule in self._anon_submodules:
fragment.add_subfragment(Fragment.get(submodule, platform), None) fragment.add_subfragment(Fragment.get(submodule, platform), None)
fragment.add_statements(self._statements) for domain, statements in self._statements.items():
fragment.add_statements(domain, statements)
for signal, domain in self._driving.items(): for signal, domain in self._driving.items():
fragment.add_driver(signal, domain) fragment.add_driver(signal, domain)
fragment.add_domains(self._domains.values()) fragment.add_domains(self._domains.values())

View file

@ -7,6 +7,7 @@ from .. import tracer
from .._utils import * from .._utils import *
from .._unused import * from .._unused import *
from ._ast import * from ._ast import *
from ._ast import _StatementList
from ._cd import * from ._cd import *
@ -65,7 +66,7 @@ class Fragment:
def __init__(self): def __init__(self):
self.ports = SignalDict() self.ports = SignalDict()
self.drivers = OrderedDict() self.drivers = OrderedDict()
self.statements = [] self.statements = {}
self.domains = OrderedDict() self.domains = OrderedDict()
self.subfragments = [] self.subfragments = []
self.attrs = OrderedDict() self.attrs = OrderedDict()
@ -127,10 +128,11 @@ class Fragment:
def iter_domains(self): def iter_domains(self):
yield from self.domains yield from self.domains
def add_statements(self, *stmts): def add_statements(self, domain, *stmts):
assert domain is None or isinstance(domain, str)
for stmt in Statement.cast(stmts): for stmt in Statement.cast(stmts):
stmt._MustUse__used = True stmt._MustUse__used = True
self.statements.append(stmt) self.statements.setdefault(domain, _StatementList()).append(stmt)
def add_subfragment(self, subfragment, name=None): def add_subfragment(self, subfragment, name=None):
assert isinstance(subfragment, Fragment) assert isinstance(subfragment, Fragment)
@ -166,7 +168,8 @@ class Fragment:
self.ports.update(subfragment.ports) self.ports.update(subfragment.ports)
for domain, signal in subfragment.iter_drivers(): for domain, signal in subfragment.iter_drivers():
self.add_driver(signal, domain) self.add_driver(signal, domain)
self.statements += subfragment.statements for domain, statements in subfragment.statements.items():
self.statements.setdefault(domain, []).extend(statements)
self.subfragments += subfragment.subfragments self.subfragments += subfragment.subfragments
# Remove the merged subfragment. # Remove the merged subfragment.
@ -387,9 +390,10 @@ class Fragment:
# Collect all signals we're driving (on LHS of statements), and signals we're using # Collect all signals we're driving (on LHS of statements), and signals we're using
# (on RHS of statements, or in clock domains). # (on RHS of statements, or in clock domains).
for stmt in self.statements: for stmts in self.statements.values():
add_uses(stmt._rhs_signals()) for stmt in stmts:
add_defs(stmt._lhs_signals()) add_uses(stmt._rhs_signals())
add_defs(stmt._lhs_signals())
for domain, _ in self.iter_sync(): for domain, _ in self.iter_sync():
cd = self.domains[domain] cd = self.domains[domain]
@ -572,10 +576,11 @@ class Fragment:
if domain.rst is not None: if domain.rst is not None:
add_signal_name(domain.rst) add_signal_name(domain.rst)
for statement in self.statements: for statements in self.statements.values():
for signal in statement._lhs_signals() | statement._rhs_signals(): for statement in statements:
if not isinstance(signal, (ClockSignal, ResetSignal)): for signal in statement._lhs_signals() | statement._rhs_signals():
add_signal_name(signal) if not isinstance(signal, (ClockSignal, ResetSignal)):
add_signal_name(signal)
return signal_names return signal_names

View file

@ -124,7 +124,7 @@ class Memory(Elaboratable):
port._MustUse__used = True port._MustUse__used = True
if port.domain == "comb": if port.domain == "comb":
# Asynchronous port # Asynchronous port
f.add_statements(port.data.eq(self._array[port.addr])) f.add_statements(None, port.data.eq(self._array[port.addr]))
f.add_driver(port.data) f.add_driver(port.data)
else: else:
# Synchronous port # Synchronous port
@ -143,6 +143,7 @@ class Memory(Elaboratable):
cond = write_port.en & (port.addr == write_port.addr) cond = write_port.en & (port.addr == write_port.addr)
data = Mux(cond, write_port.data, data) data = Mux(cond, write_port.data, data)
f.add_statements( f.add_statements(
port.domain,
Switch(port.en, { Switch(port.en, {
1: port.data.eq(data) 1: port.data.eq(data)
}) })
@ -155,10 +156,10 @@ class Memory(Elaboratable):
offset = index * port.granularity offset = index * port.granularity
bits = slice(offset, offset + port.granularity) bits = slice(offset, offset + port.granularity)
write_data = self._array[port.addr][bits].eq(port.data[bits]) write_data = self._array[port.addr][bits].eq(port.data[bits])
f.add_statements(Switch(en_bit, { 1: write_data })) f.add_statements(port.domain, Switch(en_bit, { 1: write_data }))
else: else:
write_data = self._array[port.addr].eq(port.data) write_data = self._array[port.addr].eq(port.data)
f.add_statements(Switch(port.en, { 1: write_data })) f.add_statements(port.domain, Switch(port.en, { 1: write_data }))
for signal in self._array: for signal in self._array:
f.add_driver(signal, port.domain) f.add_driver(signal, port.domain)
return f return f

View file

@ -228,9 +228,11 @@ class FragmentTransformer:
def map_statements(self, fragment, new_fragment): def map_statements(self, fragment, new_fragment):
if hasattr(self, "on_statement"): if hasattr(self, "on_statement"):
new_fragment.add_statements(map(self.on_statement, fragment.statements)) for domain, statements in fragment.statements.items():
new_fragment.add_statements(domain, map(self.on_statement, statements))
else: else:
new_fragment.add_statements(fragment.statements) for domain, statements in fragment.statements.items():
new_fragment.add_statements(domain, statements)
def map_drivers(self, fragment, new_fragment): def map_drivers(self, fragment, new_fragment):
for domain, signal in fragment.iter_drivers(): for domain, signal in fragment.iter_drivers():
@ -397,9 +399,9 @@ class DomainCollector(ValueVisitor, StatementVisitor):
else: else:
self.defined_domains.add(domain_name) self.defined_domains.add(domain_name)
self.on_statements(fragment.statements) for domain_name, statements in fragment.statements.items():
for domain_name in fragment.drivers:
self._add_used_domain(domain_name) self._add_used_domain(domain_name)
self.on_statements(statements)
for subfragment, name in fragment.subfragments: for subfragment, name in fragment.subfragments:
self.on_fragment(subfragment) self.on_fragment(subfragment)
@ -442,6 +444,13 @@ class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer)
assert cd.name == self.domain_map[domain] assert cd.name == self.domain_map[domain]
new_fragment.add_domains(cd) new_fragment.add_domains(cd)
def map_statements(self, fragment, new_fragment):
for domain, statements in fragment.statements.items():
new_fragment.add_statements(
self.domain_map.get(domain, domain),
map(self.on_statement, statements)
)
def map_drivers(self, fragment, new_fragment): def map_drivers(self, fragment, new_fragment):
for domain, signals in fragment.drivers.items(): for domain, signals in fragment.drivers.items():
if domain in self.domain_map: if domain in self.domain_map:
@ -499,7 +508,7 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
continue continue
stmts = [signal.eq(Const(signal.reset, signal.width)) stmts = [signal.eq(Const(signal.reset, signal.width))
for signal in signals if not signal.reset_less] for signal in signals if not signal.reset_less]
fragment.add_statements(Switch(domain.rst, {1: stmts})) fragment.add_statements(domain_name, Switch(domain.rst, {1: stmts}))
def on_fragment(self, fragment): def on_fragment(self, fragment):
self.domains = fragment.domains self.domains = fragment.domains
@ -571,6 +580,7 @@ class LHSGroupAnalyzer(StatementVisitor):
self.on_statements(case_stmts) self.on_statements(case_stmts)
def on_statements(self, stmts): def on_statements(self, stmts):
assert not isinstance(stmts, str)
for stmt in stmts: for stmt in stmts:
self.on_statement(stmt) self.on_statement(stmt)
@ -624,13 +634,13 @@ class _ControlInserter(FragmentTransformer):
class ResetInserter(_ControlInserter): class ResetInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals): def _insert_control(self, fragment, domain, signals):
stmts = [s.eq(Const(s.reset, s.width)) for s in signals if not s.reset_less] stmts = [s.eq(Const(s.reset, s.width)) for s in signals if not s.reset_less]
fragment.add_statements(Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc)) fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))
class EnableInserter(_ControlInserter): class EnableInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals): def _insert_control(self, fragment, domain, signals):
stmts = [s.eq(s) for s in signals] stmts = [s.eq(s) for s in signals]
fragment.add_statements(Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc)) fragment.add_statements(domain, Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc))
def on_fragment(self, fragment): def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment) new_fragment = super().on_fragment(fragment)

View file

@ -5,7 +5,7 @@ import sys
from ..hdl import * from ..hdl import *
from ..hdl._ast import SignalSet from ..hdl._ast import SignalSet
from ..hdl._xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter from ..hdl._xfrm import ValueVisitor, StatementVisitor
from ._base import BaseProcess from ._base import BaseProcess
@ -409,9 +409,9 @@ class _FragmentCompiler:
def __call__(self, fragment): def __call__(self, fragment):
processes = set() processes = set()
for domain_name, domain_signals in fragment.drivers.items(): for domain_name, domain_stmts in fragment.statements.items():
domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
domain_process = PyRTLProcess(is_comb=domain_name is None) domain_process = PyRTLProcess(is_comb=domain_name is None)
domain_signals = domain_stmts._lhs_signals()
emitter = _PythonEmitter() emitter = _PythonEmitter()
emitter.append(f"def run():") emitter.append(f"def run():")

View file

@ -34,7 +34,7 @@ class DSLTestCase(FHDLTestCase):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
m._flush() m._flush()
self.assertEqual(m._driving[self.c1], None) self.assertEqual(m._driving[self.c1], None)
self.assertRepr(m._statements, """( self.assertRepr(m._statements[None], """(
(eq (sig c1) (const 1'd1)) (eq (sig c1) (const 1'd1))
)""") )""")
@ -43,7 +43,7 @@ class DSLTestCase(FHDLTestCase):
m.d.sync += self.c1.eq(1) m.d.sync += self.c1.eq(1)
m._flush() m._flush()
self.assertEqual(m._driving[self.c1], "sync") self.assertEqual(m._driving[self.c1], "sync")
self.assertRepr(m._statements, """( self.assertRepr(m._statements["sync"], """(
(eq (sig c1) (const 1'd1)) (eq (sig c1) (const 1'd1))
)""") )""")
@ -52,7 +52,7 @@ class DSLTestCase(FHDLTestCase):
m.d.pix += self.c1.eq(1) m.d.pix += self.c1.eq(1)
m._flush() m._flush()
self.assertEqual(m._driving[self.c1], "pix") self.assertEqual(m._driving[self.c1], "pix")
self.assertRepr(m._statements, """( self.assertRepr(m._statements["pix"], """(
(eq (sig c1) (const 1'd1)) (eq (sig c1) (const 1'd1))
)""") )""")
@ -61,7 +61,7 @@ class DSLTestCase(FHDLTestCase):
m.d["pix"] += self.c1.eq(1) m.d["pix"] += self.c1.eq(1)
m._flush() m._flush()
self.assertEqual(m._driving[self.c1], "pix") self.assertEqual(m._driving[self.c1], "pix")
self.assertRepr(m._statements, """( self.assertRepr(m._statements["pix"], """(
(eq (sig c1) (const 1'd1)) (eq (sig c1) (const 1'd1))
)""") )""")
@ -118,7 +118,7 @@ class DSLTestCase(FHDLTestCase):
def test_clock_signal(self): def test_clock_signal(self):
m = Module() m = Module()
m.d.comb += ClockSignal("pix").eq(ClockSignal()) m.d.comb += ClockSignal("pix").eq(ClockSignal())
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(eq (clk pix) (clk sync)) (eq (clk pix) (clk sync))
) )
@ -127,7 +127,7 @@ class DSLTestCase(FHDLTestCase):
def test_reset_signal(self): def test_reset_signal(self):
m = Module() m = Module()
m.d.comb += ResetSignal("pix").eq(1) m.d.comb += ResetSignal("pix").eq(1)
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(eq (rst pix) (const 1'd1)) (eq (rst pix) (const 1'd1))
) )
@ -138,7 +138,7 @@ class DSLTestCase(FHDLTestCase):
with m.If(self.s1): with m.If(self.s1):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (sig s1)) (switch (cat (sig s1))
(case 1 (eq (sig c1) (const 1'd1))) (case 1 (eq (sig c1) (const 1'd1)))
@ -151,9 +151,9 @@ class DSLTestCase(FHDLTestCase):
with m.If(self.s1): with m.If(self.s1):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
with m.Elif(self.s2): with m.Elif(self.s2):
m.d.sync += self.c2.eq(0) m.d.comb += self.c2.eq(0)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (sig s1) (sig s2)) (switch (cat (sig s1) (sig s2))
(case -1 (eq (sig c1) (const 1'd1))) (case -1 (eq (sig c1) (const 1'd1)))
@ -162,6 +162,30 @@ class DSLTestCase(FHDLTestCase):
) )
""") """)
def test_If_Elif_multi(self):
m = Module()
with m.If(self.s1):
m.d.comb += self.c1.eq(1)
with m.Elif(self.s2):
m.d.sync += self.c2.eq(0)
m._flush()
self.assertRepr(m._statements[None], """
(
(switch (cat (sig s1) (sig s2))
(case -1 (eq (sig c1) (const 1'd1)))
(case 1- )
)
)
""")
self.assertRepr(m._statements["sync"], """
(
(switch (cat (sig s1) (sig s2))
(case -1 )
(case 1- (eq (sig c2) (const 1'd0)))
)
)
""")
def test_If_Elif_Else(self): def test_If_Elif_Else(self):
m = Module() m = Module()
with m.If(self.s1): with m.If(self.s1):
@ -171,15 +195,24 @@ class DSLTestCase(FHDLTestCase):
with m.Else(): with m.Else():
m.d.comb += self.c3.eq(1) m.d.comb += self.c3.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (sig s1) (sig s2)) (switch (cat (sig s1) (sig s2))
(case -1 (eq (sig c1) (const 1'd1))) (case -1 (eq (sig c1) (const 1'd1)))
(case 1- (eq (sig c2) (const 1'd0))) (case 1- )
(default (eq (sig c3) (const 1'd1))) (default (eq (sig c3) (const 1'd1)))
) )
) )
""") """)
self.assertRepr(m._statements["sync"], """
(
(switch (cat (sig s1) (sig s2))
(case -1 )
(case 1- (eq (sig c2) (const 1'd0)))
(default )
)
)
""")
def test_If_If(self): def test_If_If(self):
m = Module() m = Module()
@ -188,7 +221,7 @@ class DSLTestCase(FHDLTestCase):
with m.If(self.s2): with m.If(self.s2):
m.d.comb += self.c2.eq(1) m.d.comb += self.c2.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (sig s1)) (switch (cat (sig s1))
(case 1 (eq (sig c1) (const 1'd1))) (case 1 (eq (sig c1) (const 1'd1)))
@ -206,7 +239,7 @@ class DSLTestCase(FHDLTestCase):
with m.If(self.s2): with m.If(self.s2):
m.d.comb += self.c2.eq(1) m.d.comb += self.c2.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (sig s1)) (switch (cat (sig s1))
(case 1 (eq (sig c1) (const 1'd1)) (case 1 (eq (sig c1) (const 1'd1))
@ -227,7 +260,7 @@ class DSLTestCase(FHDLTestCase):
with m.Else(): with m.Else():
m.d.comb += self.c3.eq(1) m.d.comb += self.c3.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (sig s1)) (switch (cat (sig s1))
(case 1 (case 1
@ -298,7 +331,7 @@ class DSLTestCase(FHDLTestCase):
with m.If(self.w1): with m.If(self.w1):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (b (sig w1))) (switch (cat (b (sig w1)))
(case 1 (eq (sig c1) (const 1'd1))) (case 1 (eq (sig c1) (const 1'd1)))
@ -356,7 +389,7 @@ class DSLTestCase(FHDLTestCase):
with m.Case("1 0--"): with m.Case("1 0--"):
m.d.comb += self.c2.eq(1) m.d.comb += self.c2.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (sig w1) (switch (sig w1)
(case 0011 (eq (sig c1) (const 1'd1))) (case 0011 (eq (sig c1) (const 1'd1)))
@ -374,7 +407,7 @@ class DSLTestCase(FHDLTestCase):
with m.Case(): with m.Case():
m.d.comb += self.c2.eq(1) m.d.comb += self.c2.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (sig w1) (switch (sig w1)
(case 0011 (eq (sig c1) (const 1'd1))) (case 0011 (eq (sig c1) (const 1'd1)))
@ -390,7 +423,7 @@ class DSLTestCase(FHDLTestCase):
with m.Default(): with m.Default():
m.d.comb += self.c2.eq(1) m.d.comb += self.c2.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (sig w1) (switch (sig w1)
(case 0011 (eq (sig c1) (const 1'd1))) (case 0011 (eq (sig c1) (const 1'd1)))
@ -405,7 +438,7 @@ class DSLTestCase(FHDLTestCase):
with m.Case(1): with m.Case(1):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (const 1'd1) (switch (const 1'd1)
(case 1 (eq (sig c1) (const 1'd1))) (case 1 (eq (sig c1) (const 1'd1)))
@ -422,7 +455,7 @@ class DSLTestCase(FHDLTestCase):
with m.Switch(se): with m.Switch(se):
with m.Case(Color.RED): with m.Case(Color.RED):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (sig se) (switch (sig se)
(case 01 (eq (sig c1) (const 1'd1))) (case 01 (eq (sig c1) (const 1'd1)))
@ -439,7 +472,7 @@ class DSLTestCase(FHDLTestCase):
with m.Switch(se): with m.Switch(se):
with m.Case(Cat(Color.RED, Color.BLUE)): with m.Case(Cat(Color.RED, Color.BLUE)):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (sig se) (switch (sig se)
(case 10 (eq (sig c1) (const 1'd1))) (case 10 (eq (sig c1) (const 1'd1)))
@ -451,26 +484,23 @@ class DSLTestCase(FHDLTestCase):
class Color(Enum): class Color(Enum):
RED = 0b10101010 RED = 0b10101010
m = Module() m = Module()
dummy = Signal()
with m.Switch(self.w1): with m.Switch(self.w1):
with self.assertRaisesRegex(SyntaxError, with self.assertRaisesRegex(SyntaxError,
r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"): r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"):
with m.Case("--"): with m.Case("--"):
pass m.d.comb += dummy.eq(0)
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has " r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has "
r"width 4\); comparison will never be true$"): r"width 4\); comparison will never be true$"):
with m.Case(0b10110): with m.Case(0b10110):
pass m.d.comb += dummy.eq(0)
with self.assertWarnsRegex(SyntaxWarning, with self.assertWarnsRegex(SyntaxWarning,
r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value " r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value "
r"\(which has width 4\); comparison will never be true$"): r"\(which has width 4\); comparison will never be true$"):
with m.Case(Color.RED): with m.Case(Color.RED):
pass m.d.comb += dummy.eq(0)
self.assertRepr(m._statements, """ self.assertEqual(m._statements, {})
(
(switch (sig w1) )
)
""")
def test_Case_bits_wrong(self): def test_Case_bits_wrong(self):
m = Module() m = Module()
@ -549,11 +579,20 @@ class DSLTestCase(FHDLTestCase):
with m.If(c): with m.If(c):
m.next = "FIRST" m.next = "FIRST"
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
(eq (sig a) (const 1'd1)) (eq (sig a) (const 1'd1))
)
(case 1 )
)
)
""")
self.assertRepr(m._statements["sync"], """
(
(switch (sig fsm_state)
(case 0
(eq (sig fsm_state) (const 1'd1)) (eq (sig fsm_state) (const 1'd1))
) )
(case 1 (case 1
@ -594,11 +633,20 @@ class DSLTestCase(FHDLTestCase):
with m.State("SECOND"): with m.State("SECOND"):
m.next = "FIRST" m.next = "FIRST"
m._flush() m._flush()
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
(eq (sig a) (const 1'd0)) (eq (sig a) (const 1'd0))
)
(case 1 )
)
)
""")
self.assertRepr(m._statements["sync"], """
(
(switch (sig fsm_state)
(case 0
(eq (sig fsm_state) (const 1'd1)) (eq (sig fsm_state) (const 1'd1))
) )
(case 1 (case 1
@ -622,16 +670,10 @@ class DSLTestCase(FHDLTestCase):
m._flush() m._flush()
self.assertEqual(m._generated["fsm"].state.reset, 1) self.assertEqual(m._generated["fsm"].state.reset, 1)
self.maxDiff = 10000 self.maxDiff = 10000
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(eq (sig b) (== (sig fsm_state) (const 1'd0))) (eq (sig b) (== (sig fsm_state) (const 1'd0)))
(eq (sig a) (== (sig fsm_state) (const 1'd1))) (eq (sig a) (== (sig fsm_state) (const 1'd1)))
(switch (sig fsm_state)
(case 1
)
(case 0
)
)
) )
""") """)
@ -639,9 +681,7 @@ class DSLTestCase(FHDLTestCase):
m = Module() m = Module()
with m.FSM(): with m.FSM():
pass pass
self.assertRepr(m._statements, """ self.assertEqual(m._statements, {})
()
""")
def test_FSM_wrong_domain(self): def test_FSM_wrong_domain(self):
m = Module() m = Module()
@ -713,7 +753,7 @@ class DSLTestCase(FHDLTestCase):
with m.If(self.w1): with m.If(self.w1):
m.d.comb += self.c1.eq(1) m.d.comb += self.c1.eq(1)
m.d.comb += self.c2.eq(1) m.d.comb += self.c2.eq(1)
self.assertRepr(m._statements, """ self.assertRepr(m._statements[None], """
( (
(switch (cat (b (sig w1))) (switch (cat (b (sig w1)))
(case 1 (eq (sig c1) (const 1'd1))) (case 1 (eq (sig c1) (const 1'd1)))
@ -830,7 +870,7 @@ class DSLTestCase(FHDLTestCase):
m1.submodules.foo = m2 m1.submodules.foo = m2
f1 = m1.elaborate(platform=None) f1 = m1.elaborate(platform=None)
self.assertRepr(f1.statements, """ self.assertRepr(f1.statements[None], """
( (
(eq (sig c1) (sig s1)) (eq (sig c1) (sig s1))
) )
@ -841,9 +881,13 @@ class DSLTestCase(FHDLTestCase):
self.assertEqual(len(f1.subfragments), 1) self.assertEqual(len(f1.subfragments), 1)
(f2, f2_name), = f1.subfragments (f2, f2_name), = f1.subfragments
self.assertEqual(f2_name, "foo") self.assertEqual(f2_name, "foo")
self.assertRepr(f2.statements, """ self.assertRepr(f2.statements[None], """
( (
(eq (sig c2) (sig s2)) (eq (sig c2) (sig s2))
)
""")
self.assertRepr(f2.statements["sync"], """
(
(eq (sig c3) (sig s3)) (eq (sig c3) (sig s3))
) )
""") """)

View file

@ -100,6 +100,7 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_self_contained(self): def test_self_contained(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
None,
self.c1.eq(self.s1), self.c1.eq(self.s1),
self.s1.eq(self.c1) self.s1.eq(self.c1)
) )
@ -110,6 +111,7 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_infer_input(self): def test_infer_input(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
None,
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
@ -121,6 +123,7 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_request_output(self): def test_request_output(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
None,
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
@ -133,10 +136,12 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_input_in_subfragment(self): def test_input_in_subfragment(self):
f1 = Fragment() f1 = Fragment()
f1.add_statements( f1.add_statements(
None,
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f2 = Fragment() f2 = Fragment()
f2.add_statements( f2.add_statements(
None,
self.s1.eq(0) self.s1.eq(0)
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
@ -150,6 +155,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f1 = Fragment() f1 = Fragment()
f2 = Fragment() f2 = Fragment()
f2.add_statements( f2.add_statements(
None,
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
@ -164,10 +170,12 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_output_from_subfragment(self): def test_output_from_subfragment(self):
f1 = Fragment() f1 = Fragment()
f1.add_statements( f1.add_statements(
None,
self.c1.eq(0) self.c1.eq(0)
) )
f2 = Fragment() f2 = Fragment()
f2.add_statements( f2.add_statements(
None,
self.c2.eq(1) self.c2.eq(1)
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
@ -183,15 +191,18 @@ class FragmentPortsTestCase(FHDLTestCase):
def test_output_from_subfragment_2(self): def test_output_from_subfragment_2(self):
f1 = Fragment() f1 = Fragment()
f1.add_statements( f1.add_statements(
None,
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f2 = Fragment() f2 = Fragment()
f2.add_statements( f2.add_statements(
None,
self.c2.eq(self.s1) self.c2.eq(self.s1)
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
f3 = Fragment() f3 = Fragment()
f3.add_statements( f3.add_statements(
None,
self.s1.eq(0) self.s1.eq(0)
) )
f2.add_subfragment(f3) f2.add_subfragment(f3)
@ -205,11 +216,13 @@ class FragmentPortsTestCase(FHDLTestCase):
f1 = Fragment() f1 = Fragment()
f2 = Fragment() f2 = Fragment()
f2.add_statements( f2.add_statements(
None,
self.c1.eq(self.c2) self.c1.eq(self.c2)
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
f3 = Fragment() f3 = Fragment()
f3.add_statements( f3.add_statements(
None,
self.c2.eq(0) self.c2.eq(0)
) )
f3.add_driver(self.c2) f3.add_driver(self.c2)
@ -222,12 +235,14 @@ class FragmentPortsTestCase(FHDLTestCase):
f1 = Fragment() f1 = Fragment()
f2 = Fragment() f2 = Fragment()
f2.add_statements( f2.add_statements(
None,
self.c2.eq(0) self.c2.eq(0)
) )
f2.add_driver(self.c2) f2.add_driver(self.c2)
f1.add_subfragment(f2) f1.add_subfragment(f2)
f3 = Fragment() f3 = Fragment()
f3.add_statements( f3.add_statements(
None,
self.c1.eq(self.c2) self.c1.eq(self.c2)
) )
f1.add_subfragment(f3) f1.add_subfragment(f3)
@ -239,6 +254,7 @@ class FragmentPortsTestCase(FHDLTestCase):
sync = ClockDomain() sync = ClockDomain()
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
"sync",
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f.add_domains(sync) f.add_domains(sync)
@ -255,6 +271,7 @@ class FragmentPortsTestCase(FHDLTestCase):
sync = ClockDomain(reset_less=True) sync = ClockDomain(reset_less=True)
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
"sync",
self.c1.eq(self.s1) self.c1.eq(self.s1)
) )
f.add_domains(sync) f.add_domains(sync)
@ -490,7 +507,7 @@ class FragmentDomainsTestCase(FHDLTestCase):
def test_propagate_missing(self): def test_propagate_missing(self):
s1 = Signal() s1 = Signal()
f1 = Fragment() f1 = Fragment()
f1.add_driver(s1, "sync") f1.add_statements("sync", s1.eq(1))
with self.assertRaisesRegex(DomainError, with self.assertRaisesRegex(DomainError,
r"^Domain 'sync' is used but not defined$"): r"^Domain 'sync' is used but not defined$"):
@ -499,7 +516,7 @@ class FragmentDomainsTestCase(FHDLTestCase):
def test_propagate_create_missing(self): def test_propagate_create_missing(self):
s1 = Signal() s1 = Signal()
f1 = Fragment() f1 = Fragment()
f1.add_driver(s1, "sync") f1.add_statements("sync", s1.eq(1))
f2 = Fragment() f2 = Fragment()
f1.add_subfragment(f2) f1.add_subfragment(f2)
@ -512,7 +529,7 @@ class FragmentDomainsTestCase(FHDLTestCase):
def test_propagate_create_missing_fragment(self): def test_propagate_create_missing_fragment(self):
s1 = Signal() s1 = Signal()
f1 = Fragment() f1 = Fragment()
f1.add_driver(s1, "sync") f1.add_statements("sync", s1.eq(1))
cd = ClockDomain("sync") cd = ClockDomain("sync")
f2 = Fragment() f2 = Fragment()
@ -529,7 +546,7 @@ class FragmentDomainsTestCase(FHDLTestCase):
def test_propagate_create_missing_fragment_many_domains(self): def test_propagate_create_missing_fragment_many_domains(self):
s1 = Signal() s1 = Signal()
f1 = Fragment() f1 = Fragment()
f1.add_driver(s1, "sync") f1.add_statements("sync", s1.eq(1))
cd_por = ClockDomain("por") cd_por = ClockDomain("por")
cd_sync = ClockDomain("sync") cd_sync = ClockDomain("sync")
@ -548,7 +565,7 @@ class FragmentDomainsTestCase(FHDLTestCase):
def test_propagate_create_missing_fragment_wrong(self): def test_propagate_create_missing_fragment_wrong(self):
s1 = Signal() s1 = Signal()
f1 = Fragment() f1 = Fragment()
f1.add_driver(s1, "sync") f1.add_statements("sync", s1.eq(1))
f2 = Fragment() f2 = Fragment()
f2.add_domains(ClockDomain("foo")) f2.add_domains(ClockDomain("foo"))
@ -566,7 +583,7 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
self.c2 = Signal() self.c2 = Signal()
self.f1 = Fragment() self.f1 = Fragment()
self.f1.add_statements(self.c1.eq(0)) self.f1.add_statements("sync", self.c1.eq(0))
self.f1.add_driver(self.s1) self.f1.add_driver(self.s1)
self.f1.add_driver(self.c1, "sync") self.f1.add_driver(self.c1, "sync")
@ -574,7 +591,7 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
self.f1.add_subfragment(self.f1a, "f1a") self.f1.add_subfragment(self.f1a, "f1a")
self.f2 = Fragment() self.f2 = Fragment()
self.f2.add_statements(self.c2.eq(1)) self.f2.add_statements("sync", self.c2.eq(1))
self.f2.add_driver(self.s1) self.f2.add_driver(self.s1)
self.f2.add_driver(self.c2, "sync") self.f2.add_driver(self.c2, "sync")
self.f1.add_subfragment(self.f2) self.f1.add_subfragment(self.f2)
@ -594,7 +611,7 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
(self.f1b, "f1b"), (self.f1b, "f1b"),
(self.f2a, "f2a"), (self.f2a, "f2a"),
]) ])
self.assertRepr(self.f1.statements, """ self.assertRepr(self.f1.statements["sync"], """
( (
(eq (sig c1) (const 1'd0)) (eq (sig c1) (const 1'd0))
(eq (sig c2) (const 1'd1)) (eq (sig c2) (const 1'd1))
@ -629,12 +646,12 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
self.f2 = Fragment() self.f2 = Fragment()
self.f2.add_driver(self.s1) self.f2.add_driver(self.s1)
self.f2.add_statements(self.c1.eq(0)) self.f2.add_statements(None, self.c1.eq(0))
self.f1.add_subfragment(self.f2) self.f1.add_subfragment(self.f2)
self.f3 = Fragment() self.f3 = Fragment()
self.f3.add_driver(self.s1) self.f3.add_driver(self.s1)
self.f3.add_statements(self.c2.eq(1)) self.f3.add_statements(None, self.c2.eq(1))
self.f1.add_subfragment(self.f3) self.f1.add_subfragment(self.f3)
def test_conflict_sub_sub(self): def test_conflict_sub_sub(self):
@ -642,7 +659,7 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
self.f1._resolve_hierarchy_conflicts(mode="silent") self.f1._resolve_hierarchy_conflicts(mode="silent")
self.assertEqual(self.f1.subfragments, []) self.assertEqual(self.f1.subfragments, [])
self.assertRepr(self.f1.statements, """ self.assertRepr(self.f1.statements[None], """
( (
(eq (sig c1) (const 1'd0)) (eq (sig c1) (const 1'd0))
(eq (sig c2) (const 1'd1)) (eq (sig c2) (const 1'd1))
@ -658,12 +675,12 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
self.f1.add_driver(self.s1) self.f1.add_driver(self.s1)
self.f2 = Fragment() self.f2 = Fragment()
self.f2.add_statements(self.c1.eq(0)) self.f2.add_statements(None, self.c1.eq(0))
self.f1.add_subfragment(self.f2) self.f1.add_subfragment(self.f2)
self.f3 = Fragment() self.f3 = Fragment()
self.f3.add_driver(self.s1) self.f3.add_driver(self.s1)
self.f3.add_statements(self.c2.eq(1)) self.f3.add_statements(None, self.c2.eq(1))
self.f2.add_subfragment(self.f3) self.f2.add_subfragment(self.f3)
def test_conflict_self_subsub(self): def test_conflict_self_subsub(self):
@ -671,7 +688,7 @@ class FragmentHierarchyConflictTestCase(FHDLTestCase):
self.f1._resolve_hierarchy_conflicts(mode="silent") self.f1._resolve_hierarchy_conflicts(mode="silent")
self.assertEqual(self.f1.subfragments, []) self.assertEqual(self.f1.subfragments, [])
self.assertRepr(self.f1.statements, """ self.assertRepr(self.f1.statements[None], """
( (
(eq (sig c1) (const 1'd0)) (eq (sig c1) (const 1'd0))
(eq (sig c2) (const 1'd1)) (eq (sig c2) (const 1'd1))
@ -848,11 +865,11 @@ class InstanceTestCase(FHDLTestCase):
f.add_domains(cd_sync_norst := ClockDomain(reset_less=True)) f.add_domains(cd_sync_norst := ClockDomain(reset_less=True))
f.add_ports((i, rst), dir="i") f.add_ports((i, rst), dir="i")
f.add_ports((o1, o2, o3), dir="o") f.add_ports((o1, o2, o3), dir="o")
f.add_statements([o1.eq(0)]) f.add_statements(None, [o1.eq(0)])
f.add_driver(o1, domain=None) f.add_driver(o1, domain=None)
f.add_statements([o2.eq(i1)]) f.add_statements("sync", [o2.eq(i1)])
f.add_driver(o2, domain="sync") f.add_driver(o2, domain="sync")
f.add_statements([o3.eq(i1)]) f.add_statements("sync_norst", [o3.eq(i1)])
f.add_driver(o3, domain="sync_norst") f.add_driver(o3, domain="sync_norst")
names = f._assign_names_to_signals() names = f._assign_names_to_signals()

View file

@ -26,26 +26,35 @@ class DomainRenamerTestCase(FHDLTestCase):
def test_rename_signals(self): def test_rename_signals(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
None,
self.s1.eq(ClockSignal()), self.s1.eq(ClockSignal()),
ResetSignal().eq(self.s2), ResetSignal().eq(self.s2),
self.s3.eq(0),
self.s4.eq(ClockSignal("other")), self.s4.eq(ClockSignal("other")),
self.s5.eq(ResetSignal("other")), self.s5.eq(ResetSignal("other")),
) )
f.add_statements(
"sync",
self.s3.eq(0),
)
f.add_driver(self.s1, None) f.add_driver(self.s1, None)
f.add_driver(self.s2, None) f.add_driver(self.s2, None)
f.add_driver(self.s3, "sync") f.add_driver(self.s3, "sync")
f = DomainRenamer("pix")(f) f = DomainRenamer("pix")(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements[None], """
( (
(eq (sig s1) (clk pix)) (eq (sig s1) (clk pix))
(eq (rst pix) (sig s2)) (eq (rst pix) (sig s2))
(eq (sig s3) (const 1'd0))
(eq (sig s4) (clk other)) (eq (sig s4) (clk other))
(eq (sig s5) (rst other)) (eq (sig s5) (rst other))
) )
""") """)
self.assertRepr(f.statements["pix"], """
(
(eq (sig s3) (const 1'd0))
)
""")
self.assertFalse("sync" in f.statements)
self.assertEqual(f.drivers, { self.assertEqual(f.drivers, {
None: SignalSet((self.s1, self.s2)), None: SignalSet((self.s1, self.s2)),
"pix": SignalSet((self.s3,)), "pix": SignalSet((self.s3,)),
@ -54,12 +63,13 @@ class DomainRenamerTestCase(FHDLTestCase):
def test_rename_multi(self): def test_rename_multi(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
None,
self.s1.eq(ClockSignal()), self.s1.eq(ClockSignal()),
self.s2.eq(ResetSignal("other")), self.s2.eq(ResetSignal("other")),
) )
f = DomainRenamer({"sync": "pix", "other": "pix2"})(f) f = DomainRenamer({"sync": "pix", "other": "pix2"})(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements[None], """
( (
(eq (sig s1) (clk pix)) (eq (sig s1) (clk pix))
(eq (sig s2) (rst pix2)) (eq (sig s2) (rst pix2))
@ -86,12 +96,13 @@ class DomainRenamerTestCase(FHDLTestCase):
f = Fragment() f = Fragment()
f.add_domains(cd_pix) f.add_domains(cd_pix)
f.add_statements( f.add_statements(
None,
self.s1.eq(ResetSignal(allow_reset_less=True)), self.s1.eq(ResetSignal(allow_reset_less=True)),
) )
f = DomainRenamer("pix")(f) f = DomainRenamer("pix")(f)
f = DomainLowerer()(f) f = DomainLowerer()(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements[None], """
( (
(eq (sig s1) (const 1'd0)) (eq (sig s1) (const 1'd0))
) )
@ -151,11 +162,12 @@ class DomainLowererTestCase(FHDLTestCase):
f = Fragment() f = Fragment()
f.add_domains(sync) f.add_domains(sync)
f.add_statements( f.add_statements(
None,
self.s.eq(ClockSignal("sync")) self.s.eq(ClockSignal("sync"))
) )
f = DomainLowerer()(f) f = DomainLowerer()(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements[None], """
( (
(eq (sig s) (sig clk)) (eq (sig s) (sig clk))
) )
@ -166,11 +178,12 @@ class DomainLowererTestCase(FHDLTestCase):
f = Fragment() f = Fragment()
f.add_domains(sync) f.add_domains(sync)
f.add_statements( f.add_statements(
None,
self.s.eq(ResetSignal("sync")) self.s.eq(ResetSignal("sync"))
) )
f = DomainLowerer()(f) f = DomainLowerer()(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements[None], """
( (
(eq (sig s) (sig rst)) (eq (sig s) (sig rst))
) )
@ -181,11 +194,12 @@ class DomainLowererTestCase(FHDLTestCase):
f = Fragment() f = Fragment()
f.add_domains(sync) f.add_domains(sync)
f.add_statements( f.add_statements(
None,
self.s.eq(ResetSignal("sync", allow_reset_less=True)) self.s.eq(ResetSignal("sync", allow_reset_less=True))
) )
f = DomainLowerer()(f) f = DomainLowerer()(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements[None], """
( (
(eq (sig s) (const 1'd0)) (eq (sig s) (const 1'd0))
) )
@ -208,6 +222,7 @@ class DomainLowererTestCase(FHDLTestCase):
def test_lower_wrong_domain(self): def test_lower_wrong_domain(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
None,
self.s.eq(ClockSignal("xxx")) self.s.eq(ClockSignal("xxx"))
) )
@ -220,6 +235,7 @@ class DomainLowererTestCase(FHDLTestCase):
f = Fragment() f = Fragment()
f.add_domains(sync) f.add_domains(sync)
f.add_statements( f.add_statements(
None,
self.s.eq(ResetSignal("sync")) self.s.eq(ResetSignal("sync"))
) )
@ -368,12 +384,13 @@ class ResetInserterTestCase(FHDLTestCase):
def test_reset_default(self): def test_reset_default(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements(
"sync",
self.s1.eq(1) self.s1.eq(1)
) )
f.add_driver(self.s1, "sync") f.add_driver(self.s1, "sync")
f = ResetInserter(self.c1)(f) f = ResetInserter(self.c1)(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements["sync"], """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
(switch (sig c1) (switch (sig c1)
@ -384,18 +401,20 @@ class ResetInserterTestCase(FHDLTestCase):
def test_reset_cd(self): def test_reset_cd(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements("sync", self.s1.eq(1))
self.s1.eq(1), f.add_statements("pix", self.s2.eq(0))
self.s2.eq(0),
)
f.add_domains(ClockDomain("sync")) f.add_domains(ClockDomain("sync"))
f.add_driver(self.s1, "sync") f.add_driver(self.s1, "sync")
f.add_driver(self.s2, "pix") f.add_driver(self.s2, "pix")
f = ResetInserter({"pix": self.c1})(f) f = ResetInserter({"pix": self.c1})(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements["sync"], """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
)
""")
self.assertRepr(f.statements["pix"], """
(
(eq (sig s2) (const 1'd0)) (eq (sig s2) (const 1'd0))
(switch (sig c1) (switch (sig c1)
(case 1 (eq (sig s2) (const 1'd1))) (case 1 (eq (sig s2) (const 1'd1)))
@ -405,13 +424,11 @@ class ResetInserterTestCase(FHDLTestCase):
def test_reset_value(self): def test_reset_value(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements("sync", self.s2.eq(0))
self.s2.eq(0)
)
f.add_driver(self.s2, "sync") f.add_driver(self.s2, "sync")
f = ResetInserter(self.c1)(f) f = ResetInserter(self.c1)(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements["sync"], """
( (
(eq (sig s2) (const 1'd0)) (eq (sig s2) (const 1'd0))
(switch (sig c1) (switch (sig c1)
@ -422,13 +439,11 @@ class ResetInserterTestCase(FHDLTestCase):
def test_reset_less(self): def test_reset_less(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements("sync", self.s3.eq(0))
self.s3.eq(0)
)
f.add_driver(self.s3, "sync") f.add_driver(self.s3, "sync")
f = ResetInserter(self.c1)(f) f = ResetInserter(self.c1)(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements["sync"], """
( (
(eq (sig s3) (const 1'd0)) (eq (sig s3) (const 1'd0))
(switch (sig c1) (switch (sig c1)
@ -447,13 +462,11 @@ class EnableInserterTestCase(FHDLTestCase):
def test_enable_default(self): def test_enable_default(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements("sync", self.s1.eq(1))
self.s1.eq(1)
)
f.add_driver(self.s1, "sync") f.add_driver(self.s1, "sync")
f = EnableInserter(self.c1)(f) f = EnableInserter(self.c1)(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements["sync"], """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
(switch (sig c1) (switch (sig c1)
@ -464,17 +477,19 @@ class EnableInserterTestCase(FHDLTestCase):
def test_enable_cd(self): def test_enable_cd(self):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements("sync", self.s1.eq(1))
self.s1.eq(1), f.add_statements("pix", self.s2.eq(0))
self.s2.eq(0),
)
f.add_driver(self.s1, "sync") f.add_driver(self.s1, "sync")
f.add_driver(self.s2, "pix") f.add_driver(self.s2, "pix")
f = EnableInserter({"pix": self.c1})(f) f = EnableInserter({"pix": self.c1})(f)
self.assertRepr(f.statements, """ self.assertRepr(f.statements["sync"], """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
)
""")
self.assertRepr(f.statements["pix"], """
(
(eq (sig s2) (const 1'd0)) (eq (sig s2) (const 1'd0))
(switch (sig c1) (switch (sig c1)
(case 0 (eq (sig s2) (sig s2))) (case 0 (eq (sig s2) (sig s2)))
@ -484,21 +499,17 @@ class EnableInserterTestCase(FHDLTestCase):
def test_enable_subfragment(self): def test_enable_subfragment(self):
f1 = Fragment() f1 = Fragment()
f1.add_statements( f1.add_statements("sync", self.s1.eq(1))
self.s1.eq(1)
)
f1.add_driver(self.s1, "sync") f1.add_driver(self.s1, "sync")
f2 = Fragment() f2 = Fragment()
f2.add_statements( f2.add_statements("sync", self.s2.eq(1))
self.s2.eq(1)
)
f2.add_driver(self.s2, "sync") f2.add_driver(self.s2, "sync")
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1 = EnableInserter(self.c1)(f1) f1 = EnableInserter(self.c1)(f1)
(f2, _), = f1.subfragments (f2, _), = f1.subfragments
self.assertRepr(f1.statements, """ self.assertRepr(f1.statements["sync"], """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
(switch (sig c1) (switch (sig c1)
@ -506,7 +517,7 @@ class EnableInserterTestCase(FHDLTestCase):
) )
) )
""") """)
self.assertRepr(f2.statements, """ self.assertRepr(f2.statements["sync"], """
( (
(eq (sig s2) (const 1'd1)) (eq (sig s2) (const 1'd1))
(switch (sig c1) (switch (sig c1)
@ -542,9 +553,7 @@ class _MockElaboratable(Elaboratable):
def elaborate(self, platform): def elaborate(self, platform):
f = Fragment() f = Fragment()
f.add_statements( f.add_statements("sync", self.s1.eq(1))
self.s1.eq(1)
)
f.add_driver(self.s1, "sync") f.add_driver(self.s1, "sync")
return f return f
@ -569,7 +578,7 @@ class TransformedElaboratableTestCase(FHDLTestCase):
self.assertIs(te1, te2) self.assertIs(te1, te2)
f = Fragment.get(te2, None) f = Fragment.get(te2, None)
self.assertRepr(f.statements, """ self.assertRepr(f.statements["sync"], """
( (
(eq (sig s1) (const 1'd1)) (eq (sig s1) (const 1'd1))
(switch (sig c1) (switch (sig c1)

View file

@ -889,7 +889,7 @@ class ConnectTestCase(unittest.TestCase):
m = Module() m = Module()
connect(m, src=src, snk=snk) connect(m, src=src, snk=snk)
self.assertEqual([repr(stmt) for stmt in m._statements], [ self.assertEqual([repr(stmt) for stmt in m._statements[None]], [
'(eq (sig snk__addr) (sig src__addr))', '(eq (sig snk__addr) (sig src__addr))',
'(eq (sig snk__cycle) (sig src__cycle))', '(eq (sig snk__cycle) (sig src__cycle))',
'(eq (sig src__r_data) (sig snk__r_data))', '(eq (sig src__r_data) (sig snk__r_data))',
@ -903,7 +903,7 @@ class ConnectTestCase(unittest.TestCase):
a=Const(1)), a=Const(1)),
q=NS(signature=Signature({"a": In(1)}), q=NS(signature=Signature({"a": In(1)}),
a=Const(1))) a=Const(1)))
self.assertEqual(m._statements, []) self.assertEqual(m._statements, {})
def test_nested(self): def test_nested(self):
m = Module() m = Module()
@ -912,7 +912,7 @@ class ConnectTestCase(unittest.TestCase):
a=NS(signature=Signature({"f": Out(1)}), f=Signal(name='p__a'))), a=NS(signature=Signature({"f": Out(1)}), f=Signal(name='p__a'))),
q=NS(signature=Signature({"a": In(Signature({"f": Out(1)}))}), q=NS(signature=Signature({"a": In(Signature({"f": Out(1)}))}),
a=NS(signature=Signature({"f": Out(1)}).flip(), f=Signal(name='q__a')))) a=NS(signature=Signature({"f": Out(1)}).flip(), f=Signal(name='q__a'))))
self.assertEqual([repr(stmt) for stmt in m._statements], [ self.assertEqual([repr(stmt) for stmt in m._statements[None]], [
'(eq (sig q__a) (sig p__a))' '(eq (sig q__a) (sig p__a))'
]) ])
@ -931,7 +931,7 @@ class ConnectTestCase(unittest.TestCase):
g=Signal(name="q__b__g"), g=Signal(name="q__b__g"),
f=Signal(name="q__b__f")), f=Signal(name="q__b__f")),
a=Signal(name="q__a"))) a=Signal(name="q__a")))
self.assertEqual([repr(stmt) for stmt in m._statements], [ self.assertEqual([repr(stmt) for stmt in m._statements[None]], [
'(eq (sig q__a) (sig p__a))', '(eq (sig q__a) (sig p__a))',
'(eq (sig q__b__f) (sig p__b__f))', '(eq (sig q__b__f) (sig p__b__f))',
'(eq (sig q__b__g) (sig p__b__g))', '(eq (sig q__b__g) (sig p__b__g))',
@ -942,7 +942,7 @@ class ConnectTestCase(unittest.TestCase):
m = Module() m = Module()
connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',))) connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',)))
self.assertEqual([repr(stmt) for stmt in m._statements], [ self.assertEqual([repr(stmt) for stmt in m._statements[None]], [
'(eq (sig q__a__0) (sig p__a__0))', '(eq (sig q__a__0) (sig p__a__0))',
'(eq (sig q__a__1) (sig p__a__1))' '(eq (sig q__a__1) (sig p__a__1))'
]) ])
@ -952,7 +952,7 @@ class ConnectTestCase(unittest.TestCase):
m = Module() m = Module()
connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',))) connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',)))
self.assertEqual([repr(stmt) for stmt in m._statements], [ self.assertEqual([repr(stmt) for stmt in m._statements[None]], [
'(eq (sig q__a__0__0) (sig p__a__0__0))', '(eq (sig q__a__0__0) (sig p__a__0__0))',
]) ])

View file

@ -27,7 +27,7 @@ class SimulatorUnitTestCase(FHDLTestCase):
stmt = stmt(osig, *isigs) stmt = stmt(osig, *isigs)
frag = Fragment() frag = Fragment()
frag.add_statements(stmt) frag.add_statements(None, stmt)
for signal in flatten(s._lhs_signals() for s in Statement.cast(stmt)): for signal in flatten(s._lhs_signals() for s in Statement.cast(stmt)):
frag.add_driver(signal) frag.add_driver(signal)
@ -1045,9 +1045,10 @@ class SimulatorRegressionTestCase(FHDLTestCase):
def test_bug_595(self): def test_bug_595(self):
dut = Module() dut = Module()
dummy = Signal()
with dut.FSM(name="name with space"): with dut.FSM(name="name with space"):
with dut.State(0): with dut.State(0):
pass dut.d.comb += dummy.eq(1)
sim = Simulator(dut) sim = Simulator(dut)
with self.assertRaisesRegex(NameError, with self.assertRaisesRegex(NameError,
r"^Signal 'bench\.top\.name with space_state' contains a whitespace character$"): r"^Signal 'bench\.top\.name with space_state' contains a whitespace character$"):