compat.genlib.fsm: import/wrap Migen code.

This commit is contained in:
whitequark 2018-12-13 12:40:14 +00:00
parent 9661e897e6
commit 6251c95d4e
5 changed files with 230 additions and 11 deletions

View file

@ -30,7 +30,8 @@ proc_clean
write_verilog
# Make sure there are no undriven wires in generated RTLIL.
proc
select -assert-none w:* i:* %a %d c:* %co* %a %d n:$* %d
write_ilang x.il
select -assert-none w:* i:* %a %d o:* %a %ci* %d c:* %co* %a %d n:$* %d
""".format(il_text))
if popen.returncode:
raise YosysError(error.strip())

View file

@ -8,4 +8,4 @@ from .fhdl.bitcontainer import *
# from .sim import *
# from .genlib.record import *
# from .genlib.fsm import *
from .genlib.fsm import *

View file

@ -45,17 +45,41 @@ class If(ast.Switch):
class Case(ast.Switch):
@deprecated("instead of `Case(test, ...)`, use `with m.Case(test, ...):`")
@deprecated("instead of `Case(test, { value: stmts })`, use `with m.Switch(test):` and "
"`with m.Case(value): stmts`; instead of `\"default\": stmts`, use "
"`with m.Case(): stmts`")
def __init__(self, test, cases):
new_cases = []
for k, v in cases.items():
if k == "default":
if isinstance(k, (bool, int)):
k = Const(k)
if (not isinstance(k, Const)
and not (isinstance(k, str) and k == "default")):
raise TypeError("Case object is not a Migen constant")
if isinstance(k, str) and k == "default":
k = "-" * len(ast.Value.wrap(test))
else:
k = k.value
new_cases.append((k, v))
super().__init__(test, OrderedDict(new_cases))
@deprecated("instead of `Case(...).makedefault()`, use an explicit default case: "
"`with m.Case(): ...`")
def makedefault(self, key=None):
raise NotImplementedError
if key is None:
for choice in self.cases.keys():
if (key is None
or (isinstance(choice, str) and choice == "default")
or choice > key):
key = choice
if isinstance(key, str) and key == "default":
key = "-" * len(self.test)
else:
key = "{:0{}b}".format(wrap(key).value, len(self.test))
stmts = self.cases[key]
del self.cases[key]
self.cases["-" * len(self.test)] = stmts
return self
def Array(*args):

187
nmigen/compat/genlib/fsm.py Normal file
View file

@ -0,0 +1,187 @@
import warnings
from collections import OrderedDict
from ...fhdl.xfrm import ValueTransformer, StatementTransformer
from ...fhdl.ast import *
from ..fhdl.module import CompatModule, CompatFinalizeError
from ..fhdl.structure import If, Case
__all__ = ["AnonymousState", "NextState", "NextValue", "FSM"]
class AnonymousState:
pass
class NextState(Statement):
def __init__(self, state):
self.state = state
class NextValue(Statement):
def __init__(self, target, value):
self.target = target
self.value = value
def _target_eq(a, b):
if type(a) != type(b):
return False
ty = type(a)
if ty == Const:
return a.value == b.value
elif ty == Signal:
return a is b
elif ty == Cat:
return all(_target_eq(x, y) for x, y in zip(a.l, b.l))
elif ty == Slice:
return (_target_eq(a.value, b.value)
and a.start == b.start
and a.stop == b.stop)
elif ty == Part:
return (_target_eq(a.value, b.value)
and _target_eq(a.offset == b.offset)
and a.width == b.width)
elif ty == ArrayProxy:
return (all(_target_eq(x, y) for x, y in zip(a.choices, b.choices))
and _target_eq(a.key, b.key))
else:
raise ValueError("NextValue cannot be used with target type '{}'"
.format(ty))
class _LowerNext(ValueTransformer, StatementTransformer):
def __init__(self, next_state_signal, encoding, aliases):
self.next_state_signal = next_state_signal
self.encoding = encoding
self.aliases = aliases
# (target, next_value_ce, next_value)
self.registers = []
def _get_register_control(self, target):
for x in self.registers:
if _target_eq(target, x[0]):
return x[1], x[2]
raise KeyError
def on_unknown_statement(self, node):
if isinstance(node, NextState):
try:
actual_state = self.aliases[node.state]
except KeyError:
actual_state = node.state
return self.next_state_signal.eq(self.encoding[actual_state])
elif isinstance(node, NextValue):
try:
next_value_ce, next_value = self._get_register_control(node.target)
except KeyError:
related = node.target if isinstance(node.target, Signal) else None
next_value = Signal(node.target.shape())
next_value_ce = Signal()
self.registers.append((node.target, next_value_ce, next_value))
return next_value.eq(node.value), next_value_ce.eq(1)
else:
return node
class FSM(CompatModule):
def __init__(self, reset_state=None):
self.actions = OrderedDict()
self.state_aliases = dict()
self.reset_state = reset_state
self.before_entering_signals = OrderedDict()
self.before_leaving_signals = OrderedDict()
self.after_entering_signals = OrderedDict()
self.after_leaving_signals = OrderedDict()
def act(self, state, *statements):
if self.finalized:
raise CompatFinalizeError
if self.reset_state is None:
self.reset_state = state
if state not in self.actions:
self.actions[state] = []
self.actions[state] += statements
def delayed_enter(self, name, target, delay):
if self.finalized:
raise CompatFinalizeError
if delay > 0:
state = name
for i in range(delay):
if i == delay - 1:
next_state = target
else:
next_state = AnonymousState()
self.act(state, NextState(next_state))
state = next_state
else:
self.state_aliases[name] = target
def ongoing(self, state):
is_ongoing = Signal()
self.act(state, is_ongoing.eq(1))
return is_ongoing
def _get_signal(self, d, state):
if state not in self.actions:
self.actions[state] = []
try:
return d[state]
except KeyError:
is_el = Signal()
d[state] = is_el
return is_el
def before_entering(self, state):
return self._get_signal(self.before_entering_signals, state)
def before_leaving(self, state):
return self._get_signal(self.before_leaving_signals, state)
def after_entering(self, state):
signal = self._get_signal(self.after_entering_signals, state)
self.sync += signal.eq(self.before_entering(state))
return signal
def after_leaving(self, state):
signal = self._get_signal(self.after_leaving_signals, state)
self.sync += signal.eq(self.before_leaving(state))
return signal
def do_finalize(self):
nstates = len(self.actions)
self.encoding = dict((s, n) for n, s in enumerate(self.actions.keys()))
self.decoding = {n: s for s, n in self.encoding.items()}
self.state = Signal(max=nstates, reset=self.encoding[self.reset_state])
self.state._enumeration = self.decoding
self.next_state = Signal(max=nstates)
self.next_state._enumeration = {n: "{}:{}".format(n, s) for n, s in self.decoding.items()}
for state, signal in self.before_leaving_signals.items():
encoded = self.encoding[state]
self.comb += signal.eq((self.state == encoded) & ~(self.next_state == encoded))
if self.reset_state in self.after_entering_signals:
self.after_entering_signals[self.reset_state].reset = 1
for state, signal in self.before_entering_signals.items():
encoded = self.encoding[state]
self.comb += signal.eq(~(self.state == encoded) & (self.next_state == encoded))
self._finalize_sync(self._lower_controls())
def _lower_controls(self):
return _LowerNext(self.next_state, self.encoding, self.state_aliases)
def _finalize_sync(self, ls):
cases = dict((self.encoding[k], ls.on_statement(v)) for k, v in self.actions.items() if v)
with warnings.catch_warnings():
self.comb += [
self.next_state.eq(self.state),
Case(self.state, cases).makedefault(self.encoding[self.reset_state])
]
self.sync += self.state.eq(self.next_state)
for register, next_value_ce, next_value in ls.registers:
self.sync += If(next_value_ce, register.eq(next_value))

View file

@ -1,5 +1,6 @@
from collections import OrderedDict
from collections import OrderedDict, Iterable
from ..tools import flatten
from .ast import *
from .ir import *
@ -36,6 +37,9 @@ class ValueTransformer:
def on_Repl(self, value):
return Repl(self.on_value(value.value), value.count)
def on_unknown_value(self, value):
raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:
def on_value(self, value):
if isinstance(value, Const):
new_value = self.on_Const(value)
@ -56,7 +60,7 @@ class ValueTransformer:
elif isinstance(value, Repl):
new_value = self.on_Repl(value)
else:
raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:
new_value = self.on_unknown_value(value)
if isinstance(new_value, Value):
new_value.src_loc = value.src_loc
return new_value
@ -73,21 +77,24 @@ class StatementTransformer:
return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_value(v)) for k, v in stmt.cases.items())
cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
return Switch(self.on_value(stmt.test), cases)
def on_statements(self, stmt):
return list(flatten(self.on_statement(stmt) for stmt in self.on_statement(stmt)))
return list(flatten(self.on_statement(stmt) for stmt in stmt))
def on_unknown_statement(self, stmt):
raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:
def on_statement(self, stmt):
if isinstance(stmt, Assign):
return self.on_Assign(stmt)
elif isinstance(stmt, Switch):
return self.on_Switch(stmt)
elif isinstance(stmt, (list, tuple)):
elif isinstance(stmt, Iterable):
return self.on_statements(stmt)
else:
raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:
return self.on_unknown_statement(stmt)
def __call__(self, value):
return self.on_statement(value)