hdl.rec: migrate Record from UserValue to ValueCastable.

Closes #528.
This commit is contained in:
awygle 2020-11-05 17:10:39 -08:00 committed by GitHub
parent 06c734992f
commit abbebf8efe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 10 deletions

View file

@ -1,6 +1,6 @@
from enum import Enum
from collections import OrderedDict
from functools import reduce
from functools import reduce, wraps
from .. import tracer
from .._utils import union, deprecated
@ -85,8 +85,7 @@ class Layout:
return "Layout([{}])".format(", ".join(field_reprs))
# Unlike most Values, Record *can* be subclassed.
class Record(UserValue):
class Record(ValueCastable):
@staticmethod
def like(other, *, name=None, name_suffix=None, src_loc_at=0):
if name is not None:
@ -114,8 +113,6 @@ class Record(UserValue):
return Record(other.layout, name=new_name, fields=fields, src_loc_at=1)
def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
if name is None:
name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
@ -146,7 +143,17 @@ class Record(UserValue):
src_loc_at=1 + src_loc_at)
def __getattr__(self, name):
return self[name]
# must check `getattr` before `self` - we need to hit Value methods before fields
try:
value_attr = getattr(Value, name)
if callable(value_attr):
@wraps(value_attr)
def _wrapper(*args, **kwargs):
return value_attr(self, *args, **kwargs)
return _wrapper
return value_attr
except AttributeError:
return self[name]
def __getitem__(self, item):
if isinstance(item, str):
@ -166,11 +173,23 @@ class Record(UserValue):
if field_name in item
})
else:
return super().__getitem__(item)
try:
return Value.__getitem__(self, item)
except KeyError:
if self.name is None:
reference = "Unnamed record"
else:
reference = "Record '{}'".format(self.name)
raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
.format(reference, item, ", ".join(self.fields))) from None
def lower(self):
@ValueCastable.lowermethod
def as_value(self):
return Cat(self.fields.values())
def __len__(self):
return len(self.as_value())
def _lhs_signals(self):
return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

View file

@ -135,8 +135,8 @@ class RecordTestCase(FHDLTestCase):
("stb", 1),
])
self.assertEqual(repr(r[0]), "(slice (rec r data stb) 0:1)")
self.assertEqual(repr(r[0:3]), "(slice (rec r data stb) 0:3)")
self.assertEqual(repr(r[0]), "(slice (cat (sig r__data) (sig r__stb)) 0:1)")
self.assertEqual(repr(r[0:3]), "(slice (cat (sig r__data) (sig r__stb)) 0:3)")
def test_wrong_field(self):
r = Record([