hdl.ast: implement Initial.

This is the last remaining part for first-class formal support.
This commit is contained in:
whitequark 2019-08-15 02:53:07 +00:00
parent 40abaef858
commit ed7e07c6c1
8 changed files with 68 additions and 16 deletions

View file

@ -86,6 +86,9 @@ class _ValueCompiler(ValueVisitor):
def on_Sample(self, value): def on_Sample(self, value):
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
def on_Initial(self, value):
raise NotImplementedError # :nocov:
def on_Record(self, value): def on_Record(self, value):
return self(Cat(value.fields.values())) return self(Cat(value.fields.values()))

View file

@ -331,6 +331,9 @@ class _ValueCompiler(xfrm.ValueVisitor):
def on_Sample(self, value): def on_Sample(self, value):
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
def on_Initial(self, value):
raise NotImplementedError # :nocov:
def on_Record(self, value): def on_Record(self, value):
return self(ast.Cat(value.fields.values())) return self(ast.Cat(value.fields.values()))

View file

@ -1,2 +1,2 @@
from .hdl.ast import AnyConst, AnySeq, Assert, Assume from .hdl.ast import AnyConst, AnySeq, Assert, Assume
from .hdl.ast import Past, Stable, Rose, Fell from .hdl.ast import Past, Stable, Rose, Fell, Initial

View file

@ -12,9 +12,9 @@ from ..tools import *
__all__ = [ __all__ = [
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
"Array", "ArrayProxy", "Array", "ArrayProxy",
"Sample", "Past", "Stable", "Rose", "Fell",
"Signal", "ClockSignal", "ResetSignal", "Signal", "ClockSignal", "ResetSignal",
"UserValue", "UserValue",
"Sample", "Past", "Stable", "Rose", "Fell", "Initial",
"Statement", "Assign", "Assert", "Assume", "Switch", "Delay", "Tick", "Statement", "Assign", "Assert", "Assume", "Switch", "Delay", "Tick",
"Passive", "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict", "Passive", "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict",
"SignalSet", "SignalSet",
@ -957,7 +957,7 @@ class Sample(Value):
self.value = Value.wrap(expr) self.value = Value.wrap(expr)
self.clocks = int(clocks) self.clocks = int(clocks)
self.domain = domain self.domain = domain
if not isinstance(self.value, (Const, Signal, ClockSignal, ResetSignal)): if not isinstance(self.value, (Const, Signal, ClockSignal, ResetSignal, Initial)):
raise TypeError("Sampled value may only be a signal or a constant, not {!r}" raise TypeError("Sampled value may only be a signal or a constant, not {!r}"
.format(self.value)) .format(self.value))
if self.clocks < 0: if self.clocks < 0:
@ -991,6 +991,25 @@ def Fell(expr, clocks=0, domain=None):
return Sample(expr, clocks + 1, domain) & ~Sample(expr, clocks, domain) return Sample(expr, clocks + 1, domain) & ~Sample(expr, clocks, domain)
@final
class Initial(Value):
"""Start indicator, for formal verification.
An ``Initial`` signal is ``1`` at the first cycle of model checking, and ``0`` at any other.
"""
def __init__(self, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at)
def shape(self):
return (1, False)
def _rhs_signals(self):
return ValueSet((self,))
def __repr__(self):
return "(initial)"
class _StatementList(list): class _StatementList(list):
def __repr__(self): def __repr__(self):
return "({})".format(" ".join(map(repr, self))) return "({})".format(" ".join(map(repr, self)))
@ -1276,6 +1295,8 @@ class ValueKey:
tuple(ValueKey(e) for e in self.value._iter_as_values()))) tuple(ValueKey(e) for e in self.value._iter_as_values())))
elif isinstance(self.value, Sample): elif isinstance(self.value, Sample):
self._hash = hash((ValueKey(self.value.value), self.value.clocks, self.value.domain)) self._hash = hash((ValueKey(self.value.value), self.value.clocks, self.value.domain))
elif isinstance(self.value, Initial):
self._hash = 0
else: # :nocov: else: # :nocov:
raise TypeError("Object '{!r}' cannot be used as a key in value collections" raise TypeError("Object '{!r}' cannot be used as a key in value collections"
.format(self.value)) .format(self.value))
@ -1322,6 +1343,8 @@ class ValueKey:
return (ValueKey(self.value.value) == ValueKey(other.value.value) and return (ValueKey(self.value.value) == ValueKey(other.value.value) and
self.value.clocks == other.value.clocks and self.value.clocks == other.value.clocks and
self.value.domain == self.value.domain) self.value.domain == self.value.domain)
elif isinstance(self.value, Initial):
return True
else: # :nocov: else: # :nocov:
raise TypeError("Object '{!r}' cannot be used as a key in value collections" raise TypeError("Object '{!r}' cannot be used as a key in value collections"
.format(self.value)) .format(self.value))

View file

@ -78,6 +78,10 @@ class ValueVisitor(metaclass=ABCMeta):
def on_Sample(self, value): def on_Sample(self, value):
pass # :nocov: pass # :nocov:
@abstractmethod
def on_Initial(self, value):
pass # :nocov:
def on_unknown_value(self, value): def on_unknown_value(self, value):
raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov: raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov:
@ -115,6 +119,8 @@ class ValueVisitor(metaclass=ABCMeta):
new_value = self.on_ArrayProxy(value) new_value = self.on_ArrayProxy(value)
elif type(value) is Sample: elif type(value) is Sample:
new_value = self.on_Sample(value) new_value = self.on_Sample(value)
elif type(value) is Initial:
new_value = self.on_Initial(value)
elif isinstance(value, UserValue): elif isinstance(value, UserValue):
# Uses `isinstance()` and not `type() is` to allow inheriting. # Uses `isinstance()` and not `type() is` to allow inheriting.
new_value = self.on_value(value._lazy_lower()) new_value = self.on_value(value._lazy_lower())
@ -173,6 +179,9 @@ class ValueTransformer(ValueVisitor):
def on_Sample(self, value): def on_Sample(self, value):
return Sample(self.on_value(value.value), value.clocks, value.domain) return Sample(self.on_value(value.value), value.clocks, value.domain)
def on_Initial(self, value):
return value
class StatementVisitor(metaclass=ABCMeta): class StatementVisitor(metaclass=ABCMeta):
@abstractmethod @abstractmethod
@ -371,6 +380,9 @@ class DomainCollector(ValueVisitor, StatementVisitor):
def on_Sample(self, value): def on_Sample(self, value):
self.on_value(value.value) self.on_value(value.value)
def on_Initial(self, value):
pass
def on_Assign(self, stmt): def on_Assign(self, stmt):
self.on_value(stmt.lhs) self.on_value(stmt.lhs)
self.on_value(stmt.rhs) self.on_value(stmt.rhs)
@ -491,6 +503,7 @@ class SampleDomainInjector(ValueTransformer, StatementTransformer):
class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer): class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
def __init__(self): def __init__(self):
self.initial = None
self.sample_cache = None self.sample_cache = None
self.sample_stmts = None self.sample_stmts = None
@ -503,6 +516,8 @@ class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
return "clk", 0 return "clk", 0
elif isinstance(value, ResetSignal): elif isinstance(value, ResetSignal):
return "rst", 1 return "rst", 1
elif isinstance(value, Initial):
return "init", 0 # Past(Initial()) produces 0, 1, 0, 0, ...
else: else:
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
@ -510,8 +525,9 @@ class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
if value in self.sample_cache: if value in self.sample_cache:
return self.sample_cache[value] return self.sample_cache[value]
sampled_value = self.on_value(value.value)
if value.clocks == 0: if value.clocks == 0:
sample = value.value sample = sampled_value
else: else:
assert value.domain is not None assert value.domain is not None
sampled_name, sampled_reset = self._name_reset(value.value) sampled_name, sampled_reset = self._name_reset(value.value)
@ -519,7 +535,7 @@ class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset) sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset)
sample.attrs["nmigen.sample_reg"] = True sample.attrs["nmigen.sample_reg"] = True
prev_sample = self.on_Sample(Sample(value.value, value.clocks - 1, value.domain)) prev_sample = self.on_Sample(Sample(sampled_value, value.clocks - 1, value.domain))
if value.domain not in self.sample_stmts: if value.domain not in self.sample_stmts:
self.sample_stmts[value.domain] = [] self.sample_stmts[value.domain] = []
self.sample_stmts[value.domain].append(sample.eq(prev_sample)) self.sample_stmts[value.domain].append(sample.eq(prev_sample))
@ -527,7 +543,13 @@ class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
self.sample_cache[value] = sample self.sample_cache[value] = sample
return sample return sample
def on_Initial(self, value):
if self.initial is None:
self.initial = Signal(name="init")
return self.initial
def map_statements(self, fragment, new_fragment): def map_statements(self, fragment, new_fragment):
self.initial = None
self.sample_cache = ValueDict() self.sample_cache = ValueDict()
self.sample_stmts = OrderedDict() self.sample_stmts = OrderedDict()
new_fragment.add_statements(map(self.on_statement, fragment.statements)) new_fragment.add_statements(map(self.on_statement, fragment.statements))
@ -535,6 +557,8 @@ class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
new_fragment.add_statements(stmts) new_fragment.add_statements(stmts)
for stmt in stmts: for stmt in stmts:
new_fragment.add_driver(stmt.lhs, domain) new_fragment.add_driver(stmt.lhs, domain)
if self.initial is not None:
new_fragment.add_subfragment(Instance("$initstate", o_Y=self.initial))
class SwitchCleaner(StatementVisitor): class SwitchCleaner(StatementVisitor):

View file

@ -181,9 +181,7 @@ class SyncFIFO(Elaboratable, FIFOInterface):
if platform == "formal": if platform == "formal":
# TODO: move this logic to SymbiYosys # TODO: move this logic to SymbiYosys
initstate = Signal() with m.If(Initial()):
m.submodules += Instance("$initstate", o_Y=initstate)
with m.If(initstate):
m.d.comb += [ m.d.comb += [
Assume(produce < self.depth), Assume(produce < self.depth),
Assume(consume < self.depth), Assume(consume < self.depth),
@ -351,10 +349,7 @@ class AsyncFIFO(Elaboratable, FIFOInterface):
] ]
if platform == "formal": if platform == "formal":
# TODO: move this logic elsewhere with m.If(Initial()):
initstate = Signal()
m.submodules += Instance("$initstate", o_Y=initstate)
with m.If(initstate):
m.d.comb += Assume(produce_w_gry == (produce_w_bin ^ produce_w_bin[1:])) m.d.comb += Assume(produce_w_gry == (produce_w_bin ^ produce_w_bin[1:]))
m.d.comb += Assume(consume_r_gry == (consume_r_bin ^ consume_r_bin[1:])) m.d.comb += Assume(consume_r_gry == (consume_r_bin ^ consume_r_bin[1:]))

View file

@ -614,3 +614,9 @@ class SampleTestCase(FHDLTestCase):
with self.assertRaises(ValueError, with self.assertRaises(ValueError,
"Cannot sample a value 1 cycles in the future"): "Cannot sample a value 1 cycles in the future"):
Sample(Signal(), -1, "sync") Sample(Signal(), -1, "sync")
class InitialTestCase(FHDLTestCase):
def test_initial(self):
i = Initial()
self.assertEqual(i.shape(), (1, False))

View file

@ -208,12 +208,10 @@ class FIFOContractSpec(Elaboratable):
with m.If((read_1 == entry_1) & (read_2 == entry_2)): with m.If((read_1 == entry_1) & (read_2 == entry_2)):
m.next = "DONE" m.next = "DONE"
initstate = Signal() with m.If(Initial()):
m.submodules += Instance("$initstate", o_Y=initstate)
with m.If(initstate):
m.d.comb += Assume(write_fsm.ongoing("WRITE-1")) m.d.comb += Assume(write_fsm.ongoing("WRITE-1"))
m.d.comb += Assume(read_fsm.ongoing("READ")) m.d.comb += Assume(read_fsm.ongoing("READ"))
with m.If(Past(initstate, self.bound - 1)): with m.If(Past(Initial(), self.bound - 1)):
m.d.comb += Assert(read_fsm.ongoing("DONE")) m.d.comb += Assert(read_fsm.ongoing("DONE"))
if self.wdomain != "sync" or self.rdomain != "sync": if self.wdomain != "sync" or self.rdomain != "sync":