From 1fe7bd010f9984407d53daee04556b62194181fd Mon Sep 17 00:00:00 2001 From: Catherine Date: Mon, 5 Feb 2024 05:45:45 +0000 Subject: [PATCH] hdl: remove subclassing of `AnyValue` and `Property`. This subclassing is unnecessary and makes downstream code more complex. In the new IR, they are unified into cells with the same name anyway. Even before that, this change simplifies things. --- amaranth/back/rtlil.py | 31 +++------------ amaranth/hdl/_ast.py | 89 +++++++++++++++++++++++------------------- amaranth/hdl/_dsl.py | 2 +- amaranth/hdl/_xfrm.py | 84 ++++++++++----------------------------- amaranth/sim/_pyrtl.py | 13 +----- 5 files changed, 78 insertions(+), 141 deletions(-) diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 7a7db76..23f6b6c 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -442,27 +442,13 @@ class _RHSValueCompiler(_ValueCompiler): def on_Const(self, value): return _const(value) - def on_AnyConst(self, value): + def on_AnyValue(self, value): if value in self.s.anys: return self.s.anys[value] res_shape = value.shape() res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$anyconst", ports={ - "\\Y": res, - }, params={ - "WIDTH": res_shape.width, - }, src=_src(value.src_loc)) - self.s.anys[value] = res - return res - - def on_AnySeq(self, value): - if value in self.s.anys: - return self.s.anys[value] - - res_shape = value.shape() - res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) - self.s.rtlil.cell("$anyseq", ports={ + self.s.rtlil.cell("$" + value.kind.value, ports={ "\\Y": res, }, params={ "WIDTH": res_shape.width, @@ -619,10 +605,7 @@ class _LHSValueCompiler(_ValueCompiler): def on_Const(self, value): raise TypeError # :nocov: - def on_AnyConst(self, value): - raise TypeError # :nocov: - - def on_AnySeq(self, value): + def on_AnyValue(self, value): raise TypeError # :nocov: def on_Initial(self, value): @@ -721,21 +704,17 @@ class _StatementCompiler(_xfrm.StatementVisitor): else: self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec) - def on_property(self, stmt): + def on_Property(self, stmt): self(stmt._check.eq(stmt.test)) self(stmt._en.eq(1)) en_wire = self.rhs_compiler(stmt._en) check_wire = self.rhs_compiler(stmt._check) - self.state.rtlil.cell("$" + stmt._kind, ports={ + self.state.rtlil.cell("$" + stmt.kind.value, ports={ "\\A": check_wire, "\\EN": en_wire, }, src=_src(stmt.src_loc), name=stmt.name) - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - def on_Switch(self, stmt): self._check_rhs(stmt.test) diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 16f6373..78c16d4 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -918,32 +918,6 @@ class Const(Value): C = Const # shorthand -class AnyValue(Value, DUID): - def __init__(self, shape, *, src_loc_at=0): - super().__init__(src_loc_at=src_loc_at) - shape = Shape.cast(shape, src_loc_at=1 + src_loc_at) - self.width = shape.width - self.signed = shape.signed - - def shape(self): - return Shape(self.width, self.signed) - - def _rhs_signals(self): - return SignalSet() - - -@final -class AnyConst(AnyValue): - def __repr__(self): - return "(anyconst {}'{})".format(self.width, "s" if self.signed else "") - - -@final -class AnySeq(AnyValue): - def __repr__(self): - return "(anyseq {}'{})".format(self.width, "s" if self.signed else "") - - @final class Operator(Value): def __init__(self, operator, operands, *, src_loc_at=0): @@ -1442,6 +1416,37 @@ class ResetSignal(Value): return f"(rst {self.domain})" +@final +class AnyValue(Value, DUID): + class Kind(Enum): + AnyConst = "anyconst" + AnySeq = "anyseq" + + def __init__(self, kind, shape, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) + self.kind = self.Kind(kind) + shape = Shape.cast(shape, src_loc_at=1 + src_loc_at) + self.width = shape.width + self.signed = shape.signed + + def shape(self): + return Shape(self.width, self.signed) + + def _rhs_signals(self): + return SignalSet() + + def __repr__(self): + return "({} {}'{})".format(self.kind.value, self.width, "s" if self.signed else "") + + +def AnyConst(shape, *, src_loc_at=0): + return AnyValue("anyconst", shape, src_loc_at=src_loc_at+1) + + +def AnySeq(shape, *, src_loc_at=0): + return AnyValue("anyseq", shape, src_loc_at=src_loc_at+1) + + class Array(MutableSequence): """Addressable multiplexer. @@ -1729,11 +1734,18 @@ class UnusedProperty(UnusedMustUse): pass +@final class Property(Statement, MustUse): _MustUse__warning = UnusedProperty - def __init__(self, test, *, _check=None, _en=None, name=None, src_loc_at=0): + class Kind(Enum): + Assert = "assert" + Assume = "assume" + Cover = "cover" + + def __init__(self, kind, test, *, _check=None, _en=None, name=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) + self.kind = self.Kind(kind) self.test = Value.cast(test) self._check = _check self._en = _en @@ -1742,10 +1754,10 @@ class Property(Statement, MustUse): raise TypeError("Property name must be a string or None, not {!r}" .format(self.name)) if self._check is None: - self._check = Signal(reset_less=True, name=f"${self._kind}$check") + self._check = Signal(reset_less=True, name=f"${self.kind.value}$check") self._check.src_loc = self.src_loc if _en is None: - self._en = Signal(reset_less=True, name=f"${self._kind}$en") + self._en = Signal(reset_less=True, name=f"${self.kind.value}$en") self._en.src_loc = self.src_loc def _lhs_signals(self): @@ -1756,23 +1768,20 @@ class Property(Statement, MustUse): def __repr__(self): if self.name is not None: - return f"({self.name}: {self._kind} {self.test!r})" - return f"({self._kind} {self.test!r})" + return f"({self.name}: {self.kind.value} {self.test!r})" + return f"({self.kind.value} {self.test!r})" -@final -class Assert(Property): - _kind = "assert" +def Assert(test, *, name=None, src_loc_at=0): + return Property("assert", test, name=name, src_loc_at=src_loc_at+1) -@final -class Assume(Property): - _kind = "assume" +def Assume(test, *, name=None, src_loc_at=0): + return Property("assume", test, name=name, src_loc_at=src_loc_at+1) -@final -class Cover(Property): - _kind = "cover" +def Cover(test, *, name=None, src_loc_at=0): + return Property("cover", test, name=name, src_loc_at=src_loc_at+1) @final diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index 01bb6fa..481f468 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -506,7 +506,7 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._pop_ctrl() for stmt in Statement.cast(assigns): - if not isinstance(stmt, (Assign, Assert, Assume, Cover)): + if not isinstance(stmt, (Assign, Property)): raise SyntaxError( "Only assignments and property checks may be appended to d.{}" .format(domain_name(domain))) diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index 900ff46..f23bcba 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -6,7 +6,7 @@ from copy import copy from .._utils import flatten, _ignore_deprecated from .. import tracer from ._ast import * -from ._ast import _StatementList +from ._ast import _StatementList, AnyValue, Property from ._cd import * from ._ir import * from ._mem import MemoryInstance @@ -26,14 +26,6 @@ class ValueVisitor(metaclass=ABCMeta): def on_Const(self, value): pass # :nocov: - @abstractmethod - def on_AnyConst(self, value): - pass # :nocov: - - @abstractmethod - def on_AnySeq(self, value): - pass # :nocov: - @abstractmethod def on_Signal(self, value): pass # :nocov: @@ -46,6 +38,10 @@ class ValueVisitor(metaclass=ABCMeta): def on_ResetSignal(self, value): pass # :nocov: + @abstractmethod + def on_AnyValue(self, value): + pass # :nocov: + @abstractmethod def on_Operator(self, value): pass # :nocov: @@ -79,16 +75,14 @@ class ValueVisitor(metaclass=ABCMeta): def on_value(self, value): if type(value) is Const: new_value = self.on_Const(value) - elif type(value) is AnyConst: - new_value = self.on_AnyConst(value) - elif type(value) is AnySeq: - new_value = self.on_AnySeq(value) elif type(value) is Signal: new_value = self.on_Signal(value) elif type(value) is ClockSignal: new_value = self.on_ClockSignal(value) elif type(value) is ResetSignal: new_value = self.on_ResetSignal(value) + elif type(value) is AnyValue: + new_value = self.on_AnyValue(value) elif type(value) is Operator: new_value = self.on_Operator(value) elif type(value) is Slice: @@ -115,12 +109,6 @@ class ValueTransformer(ValueVisitor): def on_Const(self, value): return value - def on_AnyConst(self, value): - return value - - def on_AnySeq(self, value): - return value - def on_Signal(self, value): return value @@ -130,6 +118,9 @@ class ValueTransformer(ValueVisitor): def on_ResetSignal(self, value): return value + def on_AnyValue(self, value): + return value + def on_Operator(self, value): return Operator(value.operator, [self.on_value(o) for o in value.operands]) @@ -157,15 +148,7 @@ class StatementVisitor(metaclass=ABCMeta): pass # :nocov: @abstractmethod - def on_Assert(self, stmt): - pass # :nocov: - - @abstractmethod - def on_Assume(self, stmt): - pass # :nocov: - - @abstractmethod - def on_Cover(self, stmt): + def on_Property(self, stmt): pass # :nocov: @abstractmethod @@ -185,12 +168,8 @@ class StatementVisitor(metaclass=ABCMeta): def on_statement(self, stmt): if type(stmt) is Assign: new_stmt = self.on_Assign(stmt) - elif type(stmt) is Assert: - new_stmt = self.on_Assert(stmt) - elif type(stmt) is Assume: - new_stmt = self.on_Assume(stmt) - elif type(stmt) is Cover: - new_stmt = self.on_Cover(stmt) + elif type(stmt) is Property: + new_stmt = self.on_Property(stmt) elif type(stmt) is Switch: new_stmt = self.on_Switch(stmt) elif isinstance(stmt, Iterable): @@ -216,14 +195,8 @@ class StatementTransformer(StatementVisitor): def on_Assign(self, stmt): return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs)) - def on_Assert(self, stmt): - return Assert(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) - - def on_Assume(self, stmt): - return Assume(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) - - def on_Cover(self, stmt): - return Cover(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) + def on_Property(self, stmt): + return Property(stmt.kind, self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) def on_Switch(self, stmt): cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) @@ -351,9 +324,8 @@ class DomainCollector(ValueVisitor, StatementVisitor): pass on_Const = on_ignore - on_AnyConst = on_ignore - on_AnySeq = on_ignore on_Signal = on_ignore + on_AnyValue = on_ignore def on_ClockSignal(self, value): self._add_used_domain(value.domain) @@ -388,13 +360,9 @@ class DomainCollector(ValueVisitor, StatementVisitor): self.on_value(stmt.lhs) self.on_value(stmt.rhs) - def on_property(self, stmt): + def on_Property(self, stmt): self.on_value(stmt.test) - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - def on_Switch(self, stmt): self.on_value(stmt.test) for stmts in stmt.cases.values(): @@ -544,10 +512,8 @@ class SwitchCleaner(StatementVisitor): def on_ignore(self, stmt): return stmt - on_Assign = on_ignore - on_Assert = on_ignore - on_Assume = on_ignore - on_Cover = on_ignore + on_Assign = on_ignore + on_Property = on_ignore def on_Switch(self, stmt): cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) @@ -595,15 +561,11 @@ class LHSGroupAnalyzer(StatementVisitor): if lhs_signals: self.unify(*stmt._lhs_signals()) - def on_property(self, stmt): + def on_Property(self, stmt): lhs_signals = stmt._lhs_signals() if lhs_signals: self.unify(*stmt._lhs_signals()) - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - def on_Switch(self, stmt): for case_stmts in stmt.cases.values(): self.on_statements(case_stmts) @@ -630,15 +592,11 @@ class LHSGroupFilter(SwitchCleaner): if any_lhs_signal in self.signals: return stmt - def on_property(self, stmt): + def on_Property(self, stmt): any_lhs_signal = next(iter(stmt._lhs_signals())) if any_lhs_signal in self.signals: return stmt - on_Assert = on_property - on_Assume = on_property - on_Cover = on_property - class _ControlInserter(FragmentTransformer): def __init__(self, controls): diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 3862593..121804d 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -97,10 +97,7 @@ class _ValueCompiler(ValueVisitor, _Compiler): def on_ResetSignal(self, value): raise NotImplementedError # :nocov: - def on_AnyConst(self, value): - raise NotImplementedError # :nocov: - - def on_AnySeq(self, value): + def on_AnyValue(self, value): raise NotImplementedError # :nocov: def on_Initial(self, value): @@ -389,13 +386,7 @@ class _StatementCompiler(StatementVisitor, _Compiler): with self.emitter.indent(): self(stmts) - def on_Assert(self, stmt): - raise NotImplementedError # :nocov: - - def on_Assume(self, stmt): - raise NotImplementedError # :nocov: - - def on_Cover(self, stmt): + def on_Property(self, stmt): raise NotImplementedError # :nocov: @classmethod