hdl._dsl: Change FSM codegen to avoid mutating AST nodes.

Fixes #1066.
This commit is contained in:
Wanda 2024-02-27 15:03:48 +01:00 committed by Catherine
parent f524dd041a
commit 85bb5ee77c
3 changed files with 108 additions and 42 deletions

View file

@ -2304,6 +2304,11 @@ def Cover(test, *, name=None, src_loc_at=0):
return Property("cover", test, name=name, src_loc_at=src_loc_at+1)
class _LateBoundStatement(Statement):
def resolve(self):
raise NotImplementedError # :nocov:
@final
class Switch(Statement):
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):

View file

@ -9,6 +9,7 @@ from .._utils import flatten
from ..utils import bits_for
from .. import tracer
from ._ast import *
from ._ast import _StatementList, _LateBoundStatement, Property
from ._ir import *
from ._cd import *
from ._xfrm import *
@ -146,16 +147,51 @@ def _guardedcontextmanager(keyword):
return decorator
class FSMNextStatement(_LateBoundStatement):
def __init__(self, ctrl_data, state, *, src_loc_at=0):
self.ctrl_data = ctrl_data
self.state = state
super().__init__(src_loc_at=1 + src_loc_at)
def resolve(self):
return self.ctrl_data["signal"].eq(self.ctrl_data["encoding"][self.state])
class FSM:
def __init__(self, state, encoding, decoding):
self.state = state
self.encoding = encoding
self.decoding = decoding
def __init__(self, data):
self._data = data
self.encoding = data["encoding"]
self.decoding = data["decoding"]
def ongoing(self, name):
if name not in self.encoding:
self.encoding[name] = len(self.encoding)
return Operator("==", [self.state, self.encoding[name]], src_loc_at=0)
fsm_name = self._data["name"]
self._data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
return self._data["ongoing"][name]
def resolve_statement(stmt):
if isinstance(stmt, _LateBoundStatement):
return resolve_statement(stmt.resolve())
elif isinstance(stmt, Switch):
return Switch(
test=stmt.test,
cases=OrderedDict(
(patterns, resolve_statements(stmts))
for patterns, stmts in stmt.cases.items()
),
src_loc=stmt.src_loc,
case_src_locs=stmt.case_src_locs,
)
elif isinstance(stmt, (Assign, Property)):
return stmt
else:
assert False # :nocov:
def resolve_statements(stmts):
return _StatementList(resolve_statement(stmt) for stmt in stmts)
class Module(_ModuleBuilderRoot, Elaboratable):
@ -172,6 +208,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._statements = {}
self._ctrl_context = None
self._ctrl_stack = []
self._top_comb_statements = _StatementList()
self._driving = SignalDict()
self._named_submodules = {}
@ -391,17 +428,16 @@ class Module(_ModuleBuilderRoot, Elaboratable):
init = reset
fsm_data = self._set_ctrl("FSM", {
"name": name,
"signal": Signal(name=f"{name}_state", src_loc_at=2),
"init": init,
"domain": domain,
"encoding": OrderedDict(),
"decoding": OrderedDict(),
"ongoing": {},
"states": OrderedDict(),
"src_loc": tracer.get_src_loc(src_loc_at=1),
"state_src_locs": {},
})
self._generated[name] = fsm = \
FSM(fsm_data["signal"], fsm_data["encoding"], fsm_data["decoding"])
self._generated[name] = fsm = FSM(fsm_data)
try:
self._ctrl_context = "FSM"
self.domain._depth += 1
@ -414,6 +450,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self.domain._depth -= 1
self._ctrl_context = None
self._pop_ctrl()
fsm.state = fsm_data["signal"]
@contextmanager
def State(self, name):
@ -423,7 +460,9 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if name in fsm_data["states"]:
raise NameError(f"FSM state '{name}' is already defined")
if name not in fsm_data["encoding"]:
fsm_name = fsm_data["name"]
fsm_data["encoding"][name] = len(fsm_data["encoding"])
fsm_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
try:
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
@ -445,9 +484,11 @@ class Module(_ModuleBuilderRoot, Elaboratable):
for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)):
if ctrl_name == "FSM":
if name not in ctrl_data["encoding"]:
fsm_name = ctrl_data["name"]
ctrl_data["encoding"][name] = len(ctrl_data["encoding"])
ctrl_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
self._add_statement(
assigns=[ctrl_data["signal"].eq(ctrl_data["encoding"][name])],
assigns=[FSMNextStatement(ctrl_data, name)],
domain=ctrl_data["domain"],
depth=len(self._ctrl_stack))
return
@ -500,19 +541,25 @@ class Module(_ModuleBuilderRoot, Elaboratable):
src_loc=src_loc, case_src_locs=switch_case_src_locs))
if name == "FSM":
fsm_signal, fsm_init, fsm_encoding, fsm_decoding, fsm_states = \
data["signal"], data["init"], data["encoding"], data["decoding"], data["states"]
fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
data["name"], data["init"], data["encoding"], data["decoding"], data["states"], data["ongoing"]
fsm_state_src_locs = data["state_src_locs"]
if not fsm_states:
data["signal"] = Signal(0, name=f"{fsm_name}_state", src_loc_at=2)
return
fsm_signal.width = bits_for(len(fsm_encoding) - 1)
if fsm_init is None:
fsm_signal.init = fsm_encoding[next(iter(fsm_states))]
init = fsm_encoding[next(iter(fsm_states))]
else:
fsm_signal.init = fsm_encoding[fsm_init]
init = fsm_encoding[fsm_init]
# The FSM is encoded such that the state with encoding 0 is always the init state.
fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
data["signal"] = fsm_signal = Signal(range(len(fsm_encoding)), init=init,
name=f"{fsm_name}_state", src_loc_at=2,
decoder=lambda n: f"{fsm_decoding[n]}/{n}")
for name, sig in fsm_ongoing.items():
self._top_comb_statements.append(
sig.eq(Operator("==", [fsm_signal, fsm_encoding[name]], src_loc_at=0)))
domains = set()
for stmts in fsm_states.values():
@ -533,20 +580,21 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._pop_ctrl()
for stmt in Statement.cast(assigns):
if not isinstance(stmt, (Assign, Property)):
if not isinstance(stmt, (Assign, Property, _LateBoundStatement)):
raise SyntaxError(
f"Only assignments and property checks may be appended to d.{domain}")
stmt._MustUse__used = True
for signal in stmt._lhs_signals():
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")
if isinstance(stmt, Assign):
for signal in stmt._lhs_signals():
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")
self._statements.setdefault(domain, []).append(stmt)
@ -586,9 +634,13 @@ class Module(_ModuleBuilderRoot, Elaboratable):
for submodule, src_loc in self._anon_submodules:
fragment.add_subfragment(Fragment.get(submodule, platform), None, src_loc=src_loc)
for domain, statements in self._statements.items():
statements = resolve_statements(statements)
fragment.add_statements(domain, statements)
for signal, domain in self._driving.items():
fragment.add_driver(signal, domain)
for signal in statements._lhs_signals():
fragment.add_driver(signal, domain)
fragment.add_statements("comb", self._top_comb_statements)
for signal in self._top_comb_statements._lhs_signals():
fragment.add_driver(signal, "comb")
fragment.add_domains(self._domains.values())
fragment.generated.update(self._generated)
return fragment

View file

@ -593,8 +593,9 @@ class DSLTestCase(FHDLTestCase):
m.d.sync += b.eq(~b)
with m.If(c):
m.next = "FIRST"
m._flush()
self.assertRepr(m._statements["comb"], """
frag = m.elaborate(platform=None)
self.assertRepr(frag.statements["comb"], """
(
(switch (sig fsm_state)
(case 0
@ -602,9 +603,11 @@ class DSLTestCase(FHDLTestCase):
)
(case 1 )
)
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
)
""")
self.assertRepr(m._statements["sync"], """
self.assertRepr(frag.statements["sync"], """
(
(switch (sig fsm_state)
(case 0
@ -620,13 +623,13 @@ class DSLTestCase(FHDLTestCase):
)
)
""")
self.assertEqual({repr(k): v for k, v in m._driving.items()}, {
self.assertEqual({repr(sig): k for k, v in frag.drivers.items() for sig in v}, {
"(sig a)": "comb",
"(sig fsm_state)": "sync",
"(sig b)": "sync",
"(sig fsm_ongoing_FIRST)": "comb",
"(sig fsm_ongoing_SECOND)": "comb",
})
frag = m.elaborate(platform=None)
fsm = frag.find_generated("fsm")
self.assertIsInstance(fsm.state, Signal)
self.assertEqual(fsm.encoding, OrderedDict({
@ -647,8 +650,8 @@ class DSLTestCase(FHDLTestCase):
m.next = "SECOND"
with m.State("SECOND"):
m.next = "FIRST"
m._flush()
self.assertRepr(m._statements["comb"], """
frag = m.elaborate(platform=None)
self.assertRepr(frag.statements["comb"], """
(
(switch (sig fsm_state)
(case 0
@ -656,9 +659,11 @@ class DSLTestCase(FHDLTestCase):
)
(case 1 )
)
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
)
""")
self.assertRepr(m._statements["sync"], """
self.assertRepr(frag.statements["sync"], """
(
(switch (sig fsm_state)
(case 0
@ -683,8 +688,8 @@ class DSLTestCase(FHDLTestCase):
m.next = "SECOND"
with m.State("SECOND"):
m.next = "FIRST"
m._flush()
self.assertRepr(m._statements["comb"], """
frag = m.elaborate(platform=None)
self.assertRepr(frag.statements["comb"], """
(
(switch (sig fsm_state)
(case 0
@ -692,9 +697,11 @@ class DSLTestCase(FHDLTestCase):
)
(case 1 )
)
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
)
""")
self.assertRepr(m._statements["sync"], """
self.assertRepr(frag.statements["sync"], """
(
(switch (sig fsm_state)
(case 0
@ -731,13 +738,15 @@ class DSLTestCase(FHDLTestCase):
m.d.comb += a.eq(fsm.ongoing("FIRST"))
with m.State("SECOND"):
pass
m._flush()
frag = m.elaborate(platform=None)
self.assertEqual(m._generated["fsm"].state.init, 1)
self.maxDiff = 10000
self.assertRepr(m._statements["comb"], """
self.assertRepr(frag.statements["comb"], """
(
(eq (sig b) (== (sig fsm_state) (const 1'd0)))
(eq (sig a) (== (sig fsm_state) (const 1'd1)))
(eq (sig b) (sig fsm_ongoing_SECOND))
(eq (sig a) (sig fsm_ongoing_FIRST))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd1)))
)
""")