hdl: remove subclassing of AnyValue and Property.

This subclassing is unnecessary and makes downstream code more complex.
In the new IR, they are unified into cells with the same name anyway.
Even before that, this change simplifies things.
This commit is contained in:
Catherine 2024-02-05 05:45:45 +00:00
parent 115954b4d9
commit 1fe7bd010f
5 changed files with 78 additions and 141 deletions

View file

@ -442,27 +442,13 @@ class _RHSValueCompiler(_ValueCompiler):
def on_Const(self, value): def on_Const(self, value):
return _const(value) return _const(value)
def on_AnyConst(self, value): def on_AnyValue(self, value):
if value in self.s.anys: if value in self.s.anys:
return self.s.anys[value] return self.s.anys[value]
res_shape = value.shape() res_shape = value.shape()
res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc)) res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc))
self.s.rtlil.cell("$anyconst", ports={ self.s.rtlil.cell("$" + value.kind.value, ports={
"\\Y": res,
}, params={
"WIDTH": res_shape.width,
}, src=_src(value.src_loc))
self.s.anys[value] = res
return res
def on_AnySeq(self, value):
if value in self.s.anys:
return self.s.anys[value]
res_shape = value.shape()
res = self.s.rtlil.wire(width=res_shape.width, src=_src(value.src_loc))
self.s.rtlil.cell("$anyseq", ports={
"\\Y": res, "\\Y": res,
}, params={ }, params={
"WIDTH": res_shape.width, "WIDTH": res_shape.width,
@ -619,10 +605,7 @@ class _LHSValueCompiler(_ValueCompiler):
def on_Const(self, value): def on_Const(self, value):
raise TypeError # :nocov: raise TypeError # :nocov:
def on_AnyConst(self, value): def on_AnyValue(self, value):
raise TypeError # :nocov:
def on_AnySeq(self, value):
raise TypeError # :nocov: raise TypeError # :nocov:
def on_Initial(self, value): def on_Initial(self, value):
@ -721,21 +704,17 @@ class _StatementCompiler(_xfrm.StatementVisitor):
else: else:
self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec) self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec)
def on_property(self, stmt): def on_Property(self, stmt):
self(stmt._check.eq(stmt.test)) self(stmt._check.eq(stmt.test))
self(stmt._en.eq(1)) self(stmt._en.eq(1))
en_wire = self.rhs_compiler(stmt._en) en_wire = self.rhs_compiler(stmt._en)
check_wire = self.rhs_compiler(stmt._check) check_wire = self.rhs_compiler(stmt._check)
self.state.rtlil.cell("$" + stmt._kind, ports={ self.state.rtlil.cell("$" + stmt.kind.value, ports={
"\\A": check_wire, "\\A": check_wire,
"\\EN": en_wire, "\\EN": en_wire,
}, src=_src(stmt.src_loc), name=stmt.name) }, src=_src(stmt.src_loc), name=stmt.name)
on_Assert = on_property
on_Assume = on_property
on_Cover = on_property
def on_Switch(self, stmt): def on_Switch(self, stmt):
self._check_rhs(stmt.test) self._check_rhs(stmt.test)

View file

@ -918,32 +918,6 @@ class Const(Value):
C = Const # shorthand C = Const # shorthand
class AnyValue(Value, DUID):
def __init__(self, shape, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
def shape(self):
return Shape(self.width, self.signed)
def _rhs_signals(self):
return SignalSet()
@final
class AnyConst(AnyValue):
def __repr__(self):
return "(anyconst {}'{})".format(self.width, "s" if self.signed else "")
@final
class AnySeq(AnyValue):
def __repr__(self):
return "(anyseq {}'{})".format(self.width, "s" if self.signed else "")
@final @final
class Operator(Value): class Operator(Value):
def __init__(self, operator, operands, *, src_loc_at=0): def __init__(self, operator, operands, *, src_loc_at=0):
@ -1442,6 +1416,37 @@ class ResetSignal(Value):
return f"(rst {self.domain})" return f"(rst {self.domain})"
@final
class AnyValue(Value, DUID):
class Kind(Enum):
AnyConst = "anyconst"
AnySeq = "anyseq"
def __init__(self, kind, shape, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.kind = self.Kind(kind)
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
def shape(self):
return Shape(self.width, self.signed)
def _rhs_signals(self):
return SignalSet()
def __repr__(self):
return "({} {}'{})".format(self.kind.value, self.width, "s" if self.signed else "")
def AnyConst(shape, *, src_loc_at=0):
return AnyValue("anyconst", shape, src_loc_at=src_loc_at+1)
def AnySeq(shape, *, src_loc_at=0):
return AnyValue("anyseq", shape, src_loc_at=src_loc_at+1)
class Array(MutableSequence): class Array(MutableSequence):
"""Addressable multiplexer. """Addressable multiplexer.
@ -1729,11 +1734,18 @@ class UnusedProperty(UnusedMustUse):
pass pass
@final
class Property(Statement, MustUse): class Property(Statement, MustUse):
_MustUse__warning = UnusedProperty _MustUse__warning = UnusedProperty
def __init__(self, test, *, _check=None, _en=None, name=None, src_loc_at=0): class Kind(Enum):
Assert = "assert"
Assume = "assume"
Cover = "cover"
def __init__(self, kind, test, *, _check=None, _en=None, name=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at) super().__init__(src_loc_at=src_loc_at)
self.kind = self.Kind(kind)
self.test = Value.cast(test) self.test = Value.cast(test)
self._check = _check self._check = _check
self._en = _en self._en = _en
@ -1742,10 +1754,10 @@ class Property(Statement, MustUse):
raise TypeError("Property name must be a string or None, not {!r}" raise TypeError("Property name must be a string or None, not {!r}"
.format(self.name)) .format(self.name))
if self._check is None: if self._check is None:
self._check = Signal(reset_less=True, name=f"${self._kind}$check") self._check = Signal(reset_less=True, name=f"${self.kind.value}$check")
self._check.src_loc = self.src_loc self._check.src_loc = self.src_loc
if _en is None: if _en is None:
self._en = Signal(reset_less=True, name=f"${self._kind}$en") self._en = Signal(reset_less=True, name=f"${self.kind.value}$en")
self._en.src_loc = self.src_loc self._en.src_loc = self.src_loc
def _lhs_signals(self): def _lhs_signals(self):
@ -1756,23 +1768,20 @@ class Property(Statement, MustUse):
def __repr__(self): def __repr__(self):
if self.name is not None: if self.name is not None:
return f"({self.name}: {self._kind} {self.test!r})" return f"({self.name}: {self.kind.value} {self.test!r})"
return f"({self._kind} {self.test!r})" return f"({self.kind.value} {self.test!r})"
@final def Assert(test, *, name=None, src_loc_at=0):
class Assert(Property): return Property("assert", test, name=name, src_loc_at=src_loc_at+1)
_kind = "assert"
@final def Assume(test, *, name=None, src_loc_at=0):
class Assume(Property): return Property("assume", test, name=name, src_loc_at=src_loc_at+1)
_kind = "assume"
@final def Cover(test, *, name=None, src_loc_at=0):
class Cover(Property): return Property("cover", test, name=name, src_loc_at=src_loc_at+1)
_kind = "cover"
@final @final

View file

@ -506,7 +506,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
self._pop_ctrl() self._pop_ctrl()
for stmt in Statement.cast(assigns): for stmt in Statement.cast(assigns):
if not isinstance(stmt, (Assign, Assert, Assume, Cover)): if not isinstance(stmt, (Assign, Property)):
raise SyntaxError( raise SyntaxError(
"Only assignments and property checks may be appended to d.{}" "Only assignments and property checks may be appended to d.{}"
.format(domain_name(domain))) .format(domain_name(domain)))

View file

@ -6,7 +6,7 @@ from copy import copy
from .._utils import flatten, _ignore_deprecated from .._utils import flatten, _ignore_deprecated
from .. import tracer from .. import tracer
from ._ast import * from ._ast import *
from ._ast import _StatementList from ._ast import _StatementList, AnyValue, Property
from ._cd import * from ._cd import *
from ._ir import * from ._ir import *
from ._mem import MemoryInstance from ._mem import MemoryInstance
@ -26,14 +26,6 @@ class ValueVisitor(metaclass=ABCMeta):
def on_Const(self, value): def on_Const(self, value):
pass # :nocov: pass # :nocov:
@abstractmethod
def on_AnyConst(self, value):
pass # :nocov:
@abstractmethod
def on_AnySeq(self, value):
pass # :nocov:
@abstractmethod @abstractmethod
def on_Signal(self, value): def on_Signal(self, value):
pass # :nocov: pass # :nocov:
@ -46,6 +38,10 @@ class ValueVisitor(metaclass=ABCMeta):
def on_ResetSignal(self, value): def on_ResetSignal(self, value):
pass # :nocov: pass # :nocov:
@abstractmethod
def on_AnyValue(self, value):
pass # :nocov:
@abstractmethod @abstractmethod
def on_Operator(self, value): def on_Operator(self, value):
pass # :nocov: pass # :nocov:
@ -79,16 +75,14 @@ class ValueVisitor(metaclass=ABCMeta):
def on_value(self, value): def on_value(self, value):
if type(value) is Const: if type(value) is Const:
new_value = self.on_Const(value) new_value = self.on_Const(value)
elif type(value) is AnyConst:
new_value = self.on_AnyConst(value)
elif type(value) is AnySeq:
new_value = self.on_AnySeq(value)
elif type(value) is Signal: elif type(value) is Signal:
new_value = self.on_Signal(value) new_value = self.on_Signal(value)
elif type(value) is ClockSignal: elif type(value) is ClockSignal:
new_value = self.on_ClockSignal(value) new_value = self.on_ClockSignal(value)
elif type(value) is ResetSignal: elif type(value) is ResetSignal:
new_value = self.on_ResetSignal(value) new_value = self.on_ResetSignal(value)
elif type(value) is AnyValue:
new_value = self.on_AnyValue(value)
elif type(value) is Operator: elif type(value) is Operator:
new_value = self.on_Operator(value) new_value = self.on_Operator(value)
elif type(value) is Slice: elif type(value) is Slice:
@ -115,12 +109,6 @@ class ValueTransformer(ValueVisitor):
def on_Const(self, value): def on_Const(self, value):
return value return value
def on_AnyConst(self, value):
return value
def on_AnySeq(self, value):
return value
def on_Signal(self, value): def on_Signal(self, value):
return value return value
@ -130,6 +118,9 @@ class ValueTransformer(ValueVisitor):
def on_ResetSignal(self, value): def on_ResetSignal(self, value):
return value return value
def on_AnyValue(self, value):
return value
def on_Operator(self, value): def on_Operator(self, value):
return Operator(value.operator, [self.on_value(o) for o in value.operands]) return Operator(value.operator, [self.on_value(o) for o in value.operands])
@ -157,15 +148,7 @@ class StatementVisitor(metaclass=ABCMeta):
pass # :nocov: pass # :nocov:
@abstractmethod @abstractmethod
def on_Assert(self, stmt): def on_Property(self, stmt):
pass # :nocov:
@abstractmethod
def on_Assume(self, stmt):
pass # :nocov:
@abstractmethod
def on_Cover(self, stmt):
pass # :nocov: pass # :nocov:
@abstractmethod @abstractmethod
@ -185,12 +168,8 @@ class StatementVisitor(metaclass=ABCMeta):
def on_statement(self, stmt): def on_statement(self, stmt):
if type(stmt) is Assign: if type(stmt) is Assign:
new_stmt = self.on_Assign(stmt) new_stmt = self.on_Assign(stmt)
elif type(stmt) is Assert: elif type(stmt) is Property:
new_stmt = self.on_Assert(stmt) new_stmt = self.on_Property(stmt)
elif type(stmt) is Assume:
new_stmt = self.on_Assume(stmt)
elif type(stmt) is Cover:
new_stmt = self.on_Cover(stmt)
elif type(stmt) is Switch: elif type(stmt) is Switch:
new_stmt = self.on_Switch(stmt) new_stmt = self.on_Switch(stmt)
elif isinstance(stmt, Iterable): elif isinstance(stmt, Iterable):
@ -216,14 +195,8 @@ class StatementTransformer(StatementVisitor):
def on_Assign(self, stmt): def on_Assign(self, stmt):
return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs)) return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
def on_Assert(self, stmt): def on_Property(self, stmt):
return Assert(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name) return Property(stmt.kind, self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name)
def on_Assume(self, stmt):
return Assume(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name)
def on_Cover(self, stmt):
return Cover(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en, name=stmt.name)
def on_Switch(self, stmt): def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
@ -351,9 +324,8 @@ class DomainCollector(ValueVisitor, StatementVisitor):
pass pass
on_Const = on_ignore on_Const = on_ignore
on_AnyConst = on_ignore
on_AnySeq = on_ignore
on_Signal = on_ignore on_Signal = on_ignore
on_AnyValue = on_ignore
def on_ClockSignal(self, value): def on_ClockSignal(self, value):
self._add_used_domain(value.domain) self._add_used_domain(value.domain)
@ -388,13 +360,9 @@ class DomainCollector(ValueVisitor, StatementVisitor):
self.on_value(stmt.lhs) self.on_value(stmt.lhs)
self.on_value(stmt.rhs) self.on_value(stmt.rhs)
def on_property(self, stmt): def on_Property(self, stmt):
self.on_value(stmt.test) self.on_value(stmt.test)
on_Assert = on_property
on_Assume = on_property
on_Cover = on_property
def on_Switch(self, stmt): def on_Switch(self, stmt):
self.on_value(stmt.test) self.on_value(stmt.test)
for stmts in stmt.cases.values(): for stmts in stmt.cases.values():
@ -544,10 +512,8 @@ class SwitchCleaner(StatementVisitor):
def on_ignore(self, stmt): def on_ignore(self, stmt):
return stmt return stmt
on_Assign = on_ignore on_Assign = on_ignore
on_Assert = on_ignore on_Property = on_ignore
on_Assume = on_ignore
on_Cover = on_ignore
def on_Switch(self, stmt): def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items()) cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
@ -595,15 +561,11 @@ class LHSGroupAnalyzer(StatementVisitor):
if lhs_signals: if lhs_signals:
self.unify(*stmt._lhs_signals()) self.unify(*stmt._lhs_signals())
def on_property(self, stmt): def on_Property(self, stmt):
lhs_signals = stmt._lhs_signals() lhs_signals = stmt._lhs_signals()
if lhs_signals: if lhs_signals:
self.unify(*stmt._lhs_signals()) self.unify(*stmt._lhs_signals())
on_Assert = on_property
on_Assume = on_property
on_Cover = on_property
def on_Switch(self, stmt): def on_Switch(self, stmt):
for case_stmts in stmt.cases.values(): for case_stmts in stmt.cases.values():
self.on_statements(case_stmts) self.on_statements(case_stmts)
@ -630,15 +592,11 @@ class LHSGroupFilter(SwitchCleaner):
if any_lhs_signal in self.signals: if any_lhs_signal in self.signals:
return stmt return stmt
def on_property(self, stmt): def on_Property(self, stmt):
any_lhs_signal = next(iter(stmt._lhs_signals())) any_lhs_signal = next(iter(stmt._lhs_signals()))
if any_lhs_signal in self.signals: if any_lhs_signal in self.signals:
return stmt return stmt
on_Assert = on_property
on_Assume = on_property
on_Cover = on_property
class _ControlInserter(FragmentTransformer): class _ControlInserter(FragmentTransformer):
def __init__(self, controls): def __init__(self, controls):

View file

@ -97,10 +97,7 @@ class _ValueCompiler(ValueVisitor, _Compiler):
def on_ResetSignal(self, value): def on_ResetSignal(self, value):
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
def on_AnyConst(self, value): def on_AnyValue(self, value):
raise NotImplementedError # :nocov:
def on_AnySeq(self, value):
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
def on_Initial(self, value): def on_Initial(self, value):
@ -389,13 +386,7 @@ class _StatementCompiler(StatementVisitor, _Compiler):
with self.emitter.indent(): with self.emitter.indent():
self(stmts) self(stmts)
def on_Assert(self, stmt): def on_Property(self, stmt):
raise NotImplementedError # :nocov:
def on_Assume(self, stmt):
raise NotImplementedError # :nocov:
def on_Cover(self, stmt):
raise NotImplementedError # :nocov: raise NotImplementedError # :nocov:
@classmethod @classmethod