hdl.rec: make Record inherit from UserValue.

Closes #354.
This commit is contained in:
anuejn 2020-04-16 18:46:55 +02:00 committed by GitHub
parent b4af217ed0
commit ff6c0327a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 32 additions and 24 deletions

View file

@ -374,9 +374,6 @@ class _ValueCompiler(ValueVisitor, _Compiler):
def on_ResetSignal(self, value):
raise NotImplementedError # :nocov:
def on_Record(self, value):
return self(Cat(value.fields.values()))
def on_AnyConst(self, value):
raise NotImplementedError # :nocov:

View file

@ -365,9 +365,6 @@ class _ValueCompiler(xfrm.ValueVisitor):
def on_Initial(self, value):
raise NotImplementedError # :nocov:
def on_Record(self, value):
return self(ast.Cat(value.fields.values()))
def on_Cat(self, value):
return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.parts])))
@ -378,7 +375,11 @@ class _ValueCompiler(xfrm.ValueVisitor):
if value.start == 0 and value.stop == len(value.value):
return self(value.value)
sigspec = self._prepare_value_for_Slice(value.value)
if isinstance(value.value, ast.UserValue):
sigspec = self._prepare_value_for_Slice(value.value._lazy_lower())
else:
sigspec = self._prepare_value_for_Slice(value.value)
if value.start == value.stop:
return "{}"
elif value.start + 1 == value.stop:
@ -644,7 +645,7 @@ class _LHSValueCompiler(_ValueCompiler):
return wire_next or wire_curr
def _prepare_value_for_Slice(self, value):
assert isinstance(value, (ast.Signal, ast.Slice, ast.Cat, rec.Record))
assert isinstance(value, (ast.Signal, ast.Slice, ast.Cat))
return self(value)
def on_Part(self, value):

View file

@ -1188,7 +1188,10 @@ class UserValue(Value):
def _lazy_lower(self):
if self.__lowered is None:
self.__lowered = Value.cast(self.lower())
lowered = self.lower()
if isinstance(lowered, UserValue):
lowered = lowered._lazy_lower()
self.__lowered = Value.cast(lowered)
return self.__lowered
def shape(self):

View file

@ -85,7 +85,7 @@ class Layout:
# Unlike most Values, Record *can* be subclassed.
class Record(Value):
class Record(UserValue):
@staticmethod
def like(other, *, name=None, name_suffix=None, src_loc_at=0):
if name is not None:
@ -113,6 +113,8 @@ class Record(Value):
return Record(other.layout, name=new_name, fields=fields, src_loc_at=1)
def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
if name is None:
name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
@ -165,8 +167,8 @@ class Record(Value):
else:
return super().__getitem__(item)
def shape(self):
return Shape(sum(len(f) for f in self.fields.values()))
def lower(self):
return Cat(self.fields.values())
def _lhs_signals(self):
return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

View file

@ -38,10 +38,6 @@ class ValueVisitor(metaclass=ABCMeta):
def on_Signal(self, value):
pass # :nocov:
@abstractmethod
def on_Record(self, value):
pass # :nocov:
@abstractmethod
def on_ClockSignal(self, value):
pass # :nocov:
@ -98,9 +94,6 @@ class ValueVisitor(metaclass=ABCMeta):
elif isinstance(value, Signal):
# Uses `isinstance()` and not `type() is` because nmigen.compat requires it.
new_value = self.on_Signal(value)
elif isinstance(value, Record):
# Uses `isinstance()` and not `type() is` to allow inheriting from Record.
new_value = self.on_Record(value)
elif type(value) is ClockSignal:
new_value = self.on_ClockSignal(value)
elif type(value) is ResetSignal:
@ -147,9 +140,6 @@ class ValueTransformer(ValueVisitor):
def on_Signal(self, value):
return value
def on_Record(self, value):
return value
def on_ClockSignal(self, value):
return value
@ -372,8 +362,6 @@ class DomainCollector(ValueVisitor, StatementVisitor):
def on_ResetSignal(self, value):
self._add_used_domain(value.domain)
on_Record = on_ignore
def on_Operator(self, value):
for o in value.operands:
self.on_value(o)

View file

@ -916,6 +916,14 @@ class UserValueTestCase(FHDLTestCase):
self.assertEqual(uv.shape(), unsigned(1))
self.assertEqual(uv.lower_count, 1)
def test_lower_to_user_value(self):
uv = MockUserValue(MockUserValue(1))
self.assertEqual(uv.shape(), unsigned(1))
self.assertIsInstance(uv.shape(), Shape)
uv.lowered = MockUserValue(2)
self.assertEqual(uv.shape(), unsigned(1))
self.assertEqual(uv.lower_count, 1)
class SampleTestCase(FHDLTestCase):
def test_const(self):

View file

@ -620,3 +620,12 @@ class UserValueTestCase(FHDLTestCase):
)
)
""")
class UserValueRecursiveTestCase(UserValueTestCase):
def setUp(self):
self.s = Signal()
self.c = Signal()
self.uv = MockUserValue(MockUserValue(self.s))
# inherit the test_lower method from UserValueTestCase because the checks are the same