hdl._dsl: Change FSM codegen to avoid mutating AST nodes.

Fixes #1066.
This commit is contained in:
Wanda 2024-02-27 15:03:48 +01:00 committed by Catherine
parent f524dd041a
commit 85bb5ee77c
3 changed files with 108 additions and 42 deletions

View file

@ -2304,6 +2304,11 @@ def Cover(test, *, name=None, src_loc_at=0):
return Property("cover", test, name=name, src_loc_at=src_loc_at+1) return Property("cover", test, name=name, src_loc_at=src_loc_at+1)
class _LateBoundStatement(Statement):
def resolve(self):
raise NotImplementedError # :nocov:
@final @final
class Switch(Statement): class Switch(Statement):
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}): def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):

View file

@ -9,6 +9,7 @@ from .._utils import flatten
from ..utils import bits_for from ..utils import bits_for
from .. import tracer from .. import tracer
from ._ast import * from ._ast import *
from ._ast import _StatementList, _LateBoundStatement, Property
from ._ir import * from ._ir import *
from ._cd import * from ._cd import *
from ._xfrm import * from ._xfrm import *
@ -146,16 +147,51 @@ def _guardedcontextmanager(keyword):
return decorator return decorator
class FSMNextStatement(_LateBoundStatement):
def __init__(self, ctrl_data, state, *, src_loc_at=0):
self.ctrl_data = ctrl_data
self.state = state
super().__init__(src_loc_at=1 + src_loc_at)
def resolve(self):
return self.ctrl_data["signal"].eq(self.ctrl_data["encoding"][self.state])
class FSM: class FSM:
def __init__(self, state, encoding, decoding): def __init__(self, data):
self.state = state self._data = data
self.encoding = encoding self.encoding = data["encoding"]
self.decoding = decoding self.decoding = data["decoding"]
def ongoing(self, name): def ongoing(self, name):
if name not in self.encoding: if name not in self.encoding:
self.encoding[name] = len(self.encoding) self.encoding[name] = len(self.encoding)
return Operator("==", [self.state, self.encoding[name]], src_loc_at=0) fsm_name = self._data["name"]
self._data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
return self._data["ongoing"][name]
def resolve_statement(stmt):
if isinstance(stmt, _LateBoundStatement):
return resolve_statement(stmt.resolve())
elif isinstance(stmt, Switch):
return Switch(
test=stmt.test,
cases=OrderedDict(
(patterns, resolve_statements(stmts))
for patterns, stmts in stmt.cases.items()
),
src_loc=stmt.src_loc,
case_src_locs=stmt.case_src_locs,
)
elif isinstance(stmt, (Assign, Property)):
return stmt
else:
assert False # :nocov:
def resolve_statements(stmts):
return _StatementList(resolve_statement(stmt) for stmt in stmts)
class Module(_ModuleBuilderRoot, Elaboratable): class Module(_ModuleBuilderRoot, Elaboratable):
@ -172,6 +208,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._statements = {} self._statements = {}
self._ctrl_context = None self._ctrl_context = None
self._ctrl_stack = [] self._ctrl_stack = []
self._top_comb_statements = _StatementList()
self._driving = SignalDict() self._driving = SignalDict()
self._named_submodules = {} self._named_submodules = {}
@ -391,17 +428,16 @@ class Module(_ModuleBuilderRoot, Elaboratable):
init = reset init = reset
fsm_data = self._set_ctrl("FSM", { fsm_data = self._set_ctrl("FSM", {
"name": name, "name": name,
"signal": Signal(name=f"{name}_state", src_loc_at=2),
"init": init, "init": init,
"domain": domain, "domain": domain,
"encoding": OrderedDict(), "encoding": OrderedDict(),
"decoding": OrderedDict(), "decoding": OrderedDict(),
"ongoing": {},
"states": OrderedDict(), "states": OrderedDict(),
"src_loc": tracer.get_src_loc(src_loc_at=1), "src_loc": tracer.get_src_loc(src_loc_at=1),
"state_src_locs": {}, "state_src_locs": {},
}) })
self._generated[name] = fsm = \ self._generated[name] = fsm = FSM(fsm_data)
FSM(fsm_data["signal"], fsm_data["encoding"], fsm_data["decoding"])
try: try:
self._ctrl_context = "FSM" self._ctrl_context = "FSM"
self.domain._depth += 1 self.domain._depth += 1
@ -414,6 +450,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self.domain._depth -= 1 self.domain._depth -= 1
self._ctrl_context = None self._ctrl_context = None
self._pop_ctrl() self._pop_ctrl()
fsm.state = fsm_data["signal"]
@contextmanager @contextmanager
def State(self, name): def State(self, name):
@ -423,7 +460,9 @@ class Module(_ModuleBuilderRoot, Elaboratable):
if name in fsm_data["states"]: if name in fsm_data["states"]:
raise NameError(f"FSM state '{name}' is already defined") raise NameError(f"FSM state '{name}' is already defined")
if name not in fsm_data["encoding"]: if name not in fsm_data["encoding"]:
fsm_name = fsm_data["name"]
fsm_data["encoding"][name] = len(fsm_data["encoding"]) fsm_data["encoding"][name] = len(fsm_data["encoding"])
fsm_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
try: try:
_outer_case, self._statements = self._statements, {} _outer_case, self._statements = self._statements, {}
self._ctrl_context = None self._ctrl_context = None
@ -445,9 +484,11 @@ class Module(_ModuleBuilderRoot, Elaboratable):
for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)): for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)):
if ctrl_name == "FSM": if ctrl_name == "FSM":
if name not in ctrl_data["encoding"]: if name not in ctrl_data["encoding"]:
fsm_name = ctrl_data["name"]
ctrl_data["encoding"][name] = len(ctrl_data["encoding"]) ctrl_data["encoding"][name] = len(ctrl_data["encoding"])
ctrl_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
self._add_statement( self._add_statement(
assigns=[ctrl_data["signal"].eq(ctrl_data["encoding"][name])], assigns=[FSMNextStatement(ctrl_data, name)],
domain=ctrl_data["domain"], domain=ctrl_data["domain"],
depth=len(self._ctrl_stack)) depth=len(self._ctrl_stack))
return return
@ -500,19 +541,25 @@ class Module(_ModuleBuilderRoot, Elaboratable):
src_loc=src_loc, case_src_locs=switch_case_src_locs)) src_loc=src_loc, case_src_locs=switch_case_src_locs))
if name == "FSM": if name == "FSM":
fsm_signal, fsm_init, fsm_encoding, fsm_decoding, fsm_states = \ fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
data["signal"], data["init"], data["encoding"], data["decoding"], data["states"] data["name"], data["init"], data["encoding"], data["decoding"], data["states"], data["ongoing"]
fsm_state_src_locs = data["state_src_locs"] fsm_state_src_locs = data["state_src_locs"]
if not fsm_states: if not fsm_states:
data["signal"] = Signal(0, name=f"{fsm_name}_state", src_loc_at=2)
return return
fsm_signal.width = bits_for(len(fsm_encoding) - 1)
if fsm_init is None: if fsm_init is None:
fsm_signal.init = fsm_encoding[next(iter(fsm_states))] init = fsm_encoding[next(iter(fsm_states))]
else: else:
fsm_signal.init = fsm_encoding[fsm_init] init = fsm_encoding[fsm_init]
# The FSM is encoded such that the state with encoding 0 is always the init state. # The FSM is encoded such that the state with encoding 0 is always the init state.
fsm_decoding.update((n, s) for s, n in fsm_encoding.items()) fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}" data["signal"] = fsm_signal = Signal(range(len(fsm_encoding)), init=init,
name=f"{fsm_name}_state", src_loc_at=2,
decoder=lambda n: f"{fsm_decoding[n]}/{n}")
for name, sig in fsm_ongoing.items():
self._top_comb_statements.append(
sig.eq(Operator("==", [fsm_signal, fsm_encoding[name]], src_loc_at=0)))
domains = set() domains = set()
for stmts in fsm_states.values(): for stmts in fsm_states.values():
@ -533,20 +580,21 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._pop_ctrl() self._pop_ctrl()
for stmt in Statement.cast(assigns): for stmt in Statement.cast(assigns):
if not isinstance(stmt, (Assign, Property)): if not isinstance(stmt, (Assign, Property, _LateBoundStatement)):
raise SyntaxError( raise SyntaxError(
f"Only assignments and property checks may be appended to d.{domain}") f"Only assignments and property checks may be appended to d.{domain}")
stmt._MustUse__used = True stmt._MustUse__used = True
for signal in stmt._lhs_signals(): if isinstance(stmt, Assign):
if signal not in self._driving: for signal in stmt._lhs_signals():
self._driving[signal] = domain if signal not in self._driving:
elif self._driving[signal] != domain: self._driving[signal] = domain
cd_curr = self._driving[signal] elif self._driving[signal] != domain:
raise SyntaxError( cd_curr = self._driving[signal]
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is " raise SyntaxError(
f"already driven from d.{cd_curr}") f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")
self._statements.setdefault(domain, []).append(stmt) self._statements.setdefault(domain, []).append(stmt)
@ -586,9 +634,13 @@ class Module(_ModuleBuilderRoot, Elaboratable):
for submodule, src_loc in self._anon_submodules: for submodule, src_loc in self._anon_submodules:
fragment.add_subfragment(Fragment.get(submodule, platform), None, src_loc=src_loc) fragment.add_subfragment(Fragment.get(submodule, platform), None, src_loc=src_loc)
for domain, statements in self._statements.items(): for domain, statements in self._statements.items():
statements = resolve_statements(statements)
fragment.add_statements(domain, statements) fragment.add_statements(domain, statements)
for signal, domain in self._driving.items(): for signal in statements._lhs_signals():
fragment.add_driver(signal, domain) fragment.add_driver(signal, domain)
fragment.add_statements("comb", self._top_comb_statements)
for signal in self._top_comb_statements._lhs_signals():
fragment.add_driver(signal, "comb")
fragment.add_domains(self._domains.values()) fragment.add_domains(self._domains.values())
fragment.generated.update(self._generated) fragment.generated.update(self._generated)
return fragment return fragment

View file

@ -593,8 +593,9 @@ class DSLTestCase(FHDLTestCase):
m.d.sync += b.eq(~b) m.d.sync += b.eq(~b)
with m.If(c): with m.If(c):
m.next = "FIRST" m.next = "FIRST"
m._flush()
self.assertRepr(m._statements["comb"], """ frag = m.elaborate(platform=None)
self.assertRepr(frag.statements["comb"], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
@ -602,9 +603,11 @@ class DSLTestCase(FHDLTestCase):
) )
(case 1 ) (case 1 )
) )
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
) )
""") """)
self.assertRepr(m._statements["sync"], """ self.assertRepr(frag.statements["sync"], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
@ -620,13 +623,13 @@ class DSLTestCase(FHDLTestCase):
) )
) )
""") """)
self.assertEqual({repr(k): v for k, v in m._driving.items()}, { self.assertEqual({repr(sig): k for k, v in frag.drivers.items() for sig in v}, {
"(sig a)": "comb", "(sig a)": "comb",
"(sig fsm_state)": "sync", "(sig fsm_state)": "sync",
"(sig b)": "sync", "(sig b)": "sync",
"(sig fsm_ongoing_FIRST)": "comb",
"(sig fsm_ongoing_SECOND)": "comb",
}) })
frag = m.elaborate(platform=None)
fsm = frag.find_generated("fsm") fsm = frag.find_generated("fsm")
self.assertIsInstance(fsm.state, Signal) self.assertIsInstance(fsm.state, Signal)
self.assertEqual(fsm.encoding, OrderedDict({ self.assertEqual(fsm.encoding, OrderedDict({
@ -647,8 +650,8 @@ class DSLTestCase(FHDLTestCase):
m.next = "SECOND" m.next = "SECOND"
with m.State("SECOND"): with m.State("SECOND"):
m.next = "FIRST" m.next = "FIRST"
m._flush() frag = m.elaborate(platform=None)
self.assertRepr(m._statements["comb"], """ self.assertRepr(frag.statements["comb"], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
@ -656,9 +659,11 @@ class DSLTestCase(FHDLTestCase):
) )
(case 1 ) (case 1 )
) )
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
) )
""") """)
self.assertRepr(m._statements["sync"], """ self.assertRepr(frag.statements["sync"], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
@ -683,8 +688,8 @@ class DSLTestCase(FHDLTestCase):
m.next = "SECOND" m.next = "SECOND"
with m.State("SECOND"): with m.State("SECOND"):
m.next = "FIRST" m.next = "FIRST"
m._flush() frag = m.elaborate(platform=None)
self.assertRepr(m._statements["comb"], """ self.assertRepr(frag.statements["comb"], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
@ -692,9 +697,11 @@ class DSLTestCase(FHDLTestCase):
) )
(case 1 ) (case 1 )
) )
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
) )
""") """)
self.assertRepr(m._statements["sync"], """ self.assertRepr(frag.statements["sync"], """
( (
(switch (sig fsm_state) (switch (sig fsm_state)
(case 0 (case 0
@ -731,13 +738,15 @@ class DSLTestCase(FHDLTestCase):
m.d.comb += a.eq(fsm.ongoing("FIRST")) m.d.comb += a.eq(fsm.ongoing("FIRST"))
with m.State("SECOND"): with m.State("SECOND"):
pass pass
m._flush() frag = m.elaborate(platform=None)
self.assertEqual(m._generated["fsm"].state.init, 1) self.assertEqual(m._generated["fsm"].state.init, 1)
self.maxDiff = 10000 self.maxDiff = 10000
self.assertRepr(m._statements["comb"], """ self.assertRepr(frag.statements["comb"], """
( (
(eq (sig b) (== (sig fsm_state) (const 1'd0))) (eq (sig b) (sig fsm_ongoing_SECOND))
(eq (sig a) (== (sig fsm_state) (const 1'd1))) (eq (sig a) (sig fsm_ongoing_FIRST))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd1)))
) )
""") """)