diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 7a7db76..23f6b6c 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -442,27 +442,13 @@ class _RHSValueCompiler(_ValueCompiler): def on_Const(self, value): return _const(value) - def on_AnyConst(self, value): + def on_AnyValue(self, value): if value in self.s.anys: return self.s.anys[value] res_shape = value.shape() res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$anyconst", ports={ - "\\Y": res, - }, params={ - "WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - self.s.anys[value] = res - return res - - def on_AnySeq(self, value): - if value in self.s.anys: - return self.s.anys[value] - - res_shape = value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$anyseq", ports={ + self.s.rtlil.cell("$" + value.kind.value, ports={ "\\Y": res, }, params={ "WIDTH": res_shape.width, @@ -619,10 +605,7 @@ class _LHSValueCompiler(_ValueCompiler): def on_Const(self, value): raise TypeError # :nocov: - def on_AnyConst(self, value): - raise TypeError # :nocov: - - def on_AnySeq(self, value): + def on_AnyValue(self, value): raise TypeError # :nocov: def on_Initial(self, value): @@ -721,21 +704,17 @@ class _StatementCompiler(_xfrm.StatementVisitor): else: self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec) - def on_property(self, stmt): + def on_Property(self, stmt): self(stmt._check.eq(stmt.test)) self(stmt._en.eq(1)) en_wire = self.rhs_compiler(stmt._en) check_wire = self.rhs_compiler(stmt._check) - self.state.rtlil.cell("$" + stmt._kind, ports={ + self.state.rtlil.cell("$" + stmt.kind.value, ports={ "\\A": check_wire, "\\EN": en_wire, }, src=_src(stmt.src_loc), name=stmt.name) - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - def on_Switch(self, stmt): self._check_rhs(stmt.test) diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 16f6373..78c16d4 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -918,32 +918,6 @@ class Const(Value): C = Const # shorthand -class AnyValue(Value, DUID): - def __init__(self, shape, *, src_loc_at=0): - super().__init__(src_loc_at=src_loc_at) - shape = Shape.cast(shape, src_loc_at=1 + src_loc_at) - self.width = shape.width - self.signed = shape.signed - - def shape(self): - return Shape(self.width, self.signed) - - def _rhs_signals(self): - return SignalSet() - - -@final -class AnyConst(AnyValue): - def __repr__(self): - return "(anyconst {}'{})".format(self.width, "s" if self.signed else "") - - -@final -class AnySeq(AnyValue): - def __repr__(self): - return "(anyseq {}'{})".format(self.width, "s" if self.signed else "") - - @final class Operator(Value): def __init__(self, operator, operands, *, src_loc_at=0): @@ -1442,6 +1416,37 @@ class ResetSignal(Value): return f"(rst {self.domain})" +@final +class AnyValue(Value, DUID): + class Kind(Enum): + AnyConst = "anyconst" + AnySeq = "anyseq" + + def __init__(self, kind, shape, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) + self.kind = self.Kind(kind) + shape = Shape.cast(shape, src_loc_at=1 + src_loc_at) + self.width = shape.width + self.signed = shape.signed + + def shape(self): + return Shape(self.width, self.signed) + + def _rhs_signals(self): + return SignalSet() + + def __repr__(self): + return "({} {}'{})".format(self.kind.value, self.width, "s" if self.signed else "") + + +def AnyConst(shape, *, src_loc_at=0): + return AnyValue("anyconst", shape, src_loc_at=src_loc_at+1) + + +def AnySeq(shape, *, src_loc_at=0): + return AnyValue("anyseq", shape, src_loc_at=src_loc_at+1) + + class Array(MutableSequence): """Addressable multiplexer. @@ -1729,11 +1734,18 @@ class UnusedProperty(UnusedMustUse): pass +@final class Property(Statement, MustUse): _MustUse__warning = UnusedProperty - def __init__(self, test, *, _check=None, _en=None, name=None, src_loc_at=0): + class Kind(Enum): + Assert = "assert" + Assume = "assume" + Cover = "cover" + + def __init__(self, kind, test, *, _check=None, _en=None, name=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) + self.kind = self.Kind(kind) self.test = Value.cast(test) self._check = _check self._en = _en @@ -1742,10 +1754,10 @@ class Property(Statement, MustUse): raise TypeError("Property name must be a string or None, not {!r}" .format(self.name)) if self._check is None: - self._check = Signal(reset_less=True, name=f"${self._kind}$check") + self._check = Signal(reset_less=True, name=f"${self.kind.value}$check") self._check.src_loc = self.src_loc if _en is None: - self._en = Signal(reset_less=True, name=f"${self._kind}$en") + self._en = Signal(reset_less=True, name=f"${self.kind.value}$en") self._en.src_loc = self.src_loc def _lhs_signals(self): @@ -1756,23 +1768,20 @@ class Property(Statement, MustUse): def __repr__(self): if self.name is not None: - return f"({self.name}: {self._kind} {self.test!r})" - return f"({self._kind} {self.test!r})" + return f"({self.name}: {self.kind.value} {self.test!r})" + return f"({self.kind.value} {self.test!r})" -@final -class Assert(Property): - _kind = "assert" +def Assert(test, *, name=None, src_loc_at=0): + return Property("assert", test, name=name, src_loc_at=src_loc_at+1) -@final -class Assume(Property): - _kind = "assume" +def Assume(test, *, name=None, src_loc_at=0): + return Property("assume", test, name=name, src_loc_at=src_loc_at+1) -@final -class Cover(Property): - _kind = "cover" +def Cover(test, *, name=None, src_loc_at=0): + return Property("cover", test, name=name, src_loc_at=src_loc_at+1) @final diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index 01bb6fa..481f468 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -506,7 +506,7 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._pop_ctrl() for stmt in Statement.cast(assigns): - if not isinstance(stmt, (Assign, Assert, Assume, Cover)): + if not isinstance(stmt, (Assign, Property)): raise SyntaxError( "Only assignments and property checks may be appended to d.{}" .format(domain_name(domain))) diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index 900ff46..f23bcba 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -6,7 +6,7 @@ from copy import copy from .._utils import flatten, _ignore_deprecated from .. import tracer from ._ast import * -from ._ast import _StatementList +from ._ast import _StatementList, AnyValue, Property from ._cd import * from ._ir import * from ._mem import MemoryInstance @@ -26,14 +26,6 @@ class ValueVisitor(metaclass=ABCMeta): def on_Const(self, value): pass # :nocov: - @abstractmethod - def on_AnyConst(self, value): - pass # :nocov: - - @abstractmethod - def on_AnySeq(self, value): - pass # :nocov: - @abstractmethod def on_Signal(self, value): pass # :nocov: @@ -46,6 +38,10 @@ class ValueVisitor(metaclass=ABCMeta): def on_ResetSignal(self, value): pass # :nocov: + @abstractmethod + def on_AnyValue(self, value): + pass # :nocov: + @abstractmethod def on_Operator(self, value): pass # :nocov: @@ -79,16 +75,14 @@ class ValueVisitor(metaclass=ABCMeta): def on_value(self, value): if type(value) is Const: new_value = self.on_Const(value) - elif type(value) is AnyConst: - new_value = self.on_AnyConst(value) - elif type(value) is AnySeq: - new_value = self.on_AnySeq(value) elif type(value) is Signal: new_value = self.on_Signal(value) elif type(value) is ClockSignal: new_value = self.on_ClockSignal(value) elif type(value) is ResetSignal: new_value = self.on_ResetSignal(value) + elif type(value) is AnyValue: + new_value = self.on_AnyValue(value) elif type(value) is Operator: new_value = self.on_Operator(value) elif type(value) is Slice: @@ -115,12 +109,6 @@ class ValueTransformer(ValueVisitor): def on_Const(self, value): return value - def on_AnyConst(self, value): - return value - - def on_AnySeq(self, value): - return value - def on_Signal(self, value): return value @@ -130,6 +118,9 @@ class ValueTransformer(ValueVisitor): def on_ResetSignal(self, value): return value + def on_AnyValue(self, value): + return value + def on_Operator(self, value): return Operator(value.operator, [self.on_value(o) for o in value.operands]) @@ -157,15 +148,7 @@ class StatementVisitor(metaclass=ABCMeta): pass # :nocov: @abstractmethod - def on_Assert(self, stmt): - pass # :nocov: - - @abstractmethod - def on_Assume(self, stmt): - pass # :nocov: - - @abstractmethod - def on_Cover(self, stmt): + def on_Property(self, stmt): pass # :nocov: @abstractmethod @@ -185,12 +168,8 @@ class StatementVisitor(metaclass=ABCMeta): def on_statement(self, stmt): if type(stmt) is Assign: new_stmt = self.on_Assign(stmt) - elif type(stmt) is Assert: - new_stmt = self.on_Assert(stmt) - elif type(stmt) is Assume: - new_stmt = self.on_Assume(stmt) - elif type(stmt) is Cover: - new_stmt = self.on_Cover(stmt) + elif type(stmt) is Property: + new_stmt = self.on_Property(stmt) elif type(stmt) is Switch: new_stmt = self.on_Switch(stmt) elif isinstance(stmt, Iterable): @@ -216,14 +195,8 @@ class StatementTransformer(StatementVisitor): def on_Assign(self, stmt): return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs)) - def on_Assert(self, stmt): - return Assert(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) - - def on_Assume(self, stmt): - return Assume(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) - - def on_Cover(self, stmt): - return Cover(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) + def on_Property(self, stmt): + return Property(stmt.kind, self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) def on_Switch(self, stmt): cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) @@ -351,9 +324,8 @@ class DomainCollector(ValueVisitor, StatementVisitor): pass on_Const = on_ignore - on_AnyConst = on_ignore - on_AnySeq = on_ignore on_Signal = on_ignore + on_AnyValue = on_ignore def on_ClockSignal(self, value): self._add_used_domain(value.domain) @@ -388,13 +360,9 @@ class DomainCollector(ValueVisitor, StatementVisitor): self.on_value(stmt.lhs) self.on_value(stmt.rhs) - def on_property(self, stmt): + def on_Property(self, stmt): self.on_value(stmt.test) - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - def on_Switch(self, stmt): self.on_value(stmt.test) for stmts in stmt.cases.values(): @@ -544,10 +512,8 @@ class SwitchCleaner(StatementVisitor): def on_ignore(self, stmt): return stmt - on_Assign = on_ignore - on_Assert = on_ignore - on_Assume = on_ignore - on_Cover = on_ignore + on_Assign = on_ignore + on_Property = on_ignore def on_Switch(self, stmt): cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) @@ -595,15 +561,11 @@ class LHSGroupAnalyzer(StatementVisitor): if lhs_signals: self.unify(*stmt._lhs_signals()) - def on_property(self, stmt): + def on_Property(self, stmt): lhs_signals = stmt._lhs_signals() if lhs_signals: self.unify(*stmt._lhs_signals()) - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - def on_Switch(self, stmt): for case_stmts in stmt.cases.values(): self.on_statements(case_stmts) @@ -630,15 +592,11 @@ class LHSGroupFilter(SwitchCleaner): if any_lhs_signal in self.signals: return stmt - def on_property(self, stmt): + def on_Property(self, stmt): any_lhs_signal = next(iter(stmt._lhs_signals())) if any_lhs_signal in self.signals: return stmt - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - class _ControlInserter(FragmentTransformer): def __init__(self, controls): diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 3862593..121804d 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -97,10 +97,7 @@ class _ValueCompiler(ValueVisitor, _Compiler): def on_ResetSignal(self, value): raise NotImplementedError # :nocov: - def on_AnyConst(self, value): - raise NotImplementedError # :nocov: - - def on_AnySeq(self, value): + def on_AnyValue(self, value): raise NotImplementedError # :nocov: def on_Initial(self, value): @@ -389,13 +386,7 @@ class _StatementCompiler(StatementVisitor, _Compiler): with self.emitter.indent(): self(stmts) - def on_Assert(self, stmt): - raise NotImplementedError # :nocov: - - def on_Assume(self, stmt): - raise NotImplementedError # :nocov: - - def on_Cover(self, stmt): + def on_Property(self, stmt): raise NotImplementedError # :nocov: @classmethod