hdl.rec: migrate Record from UserValue to ValueCastable.

Closes #528.
This commit is contained in:
awygle 2020-11-05 17:10:39 -08:00 committed by GitHub
parent 06c734992f
commit abbebf8efe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 10 deletions

View file

@ -1,6 +1,6 @@
from enum import Enum from enum import Enum
from collections import OrderedDict from collections import OrderedDict
from functools import reduce from functools import reduce, wraps
from .. import tracer from .. import tracer
from .._utils import union, deprecated from .._utils import union, deprecated
@ -85,8 +85,7 @@ class Layout:
return "Layout([{}])".format(", ".join(field_reprs)) return "Layout([{}])".format(", ".join(field_reprs))
# Unlike most Values, Record *can* be subclassed. class Record(ValueCastable):
class Record(UserValue):
@staticmethod @staticmethod
def like(other, *, name=None, name_suffix=None, src_loc_at=0): def like(other, *, name=None, name_suffix=None, src_loc_at=0):
if name is not None: if name is not None:
@ -114,8 +113,6 @@ class Record(UserValue):
return Record(other.layout, name=new_name, fields=fields, src_loc_at=1) return Record(other.layout, name=new_name, fields=fields, src_loc_at=1)
def __init__(self, layout, *, name=None, fields=None, src_loc_at=0): def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
if name is None: if name is None:
name = tracer.get_var_name(depth=2 + src_loc_at, default=None) name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
@ -146,7 +143,17 @@ class Record(UserValue):
src_loc_at=1 + src_loc_at) src_loc_at=1 + src_loc_at)
def __getattr__(self, name): def __getattr__(self, name):
return self[name] # must check `getattr` before `self` - we need to hit Value methods before fields
try:
value_attr = getattr(Value, name)
if callable(value_attr):
@wraps(value_attr)
def _wrapper(*args, **kwargs):
return value_attr(self, *args, **kwargs)
return _wrapper
return value_attr
except AttributeError:
return self[name]
def __getitem__(self, item): def __getitem__(self, item):
if isinstance(item, str): if isinstance(item, str):
@ -166,11 +173,23 @@ class Record(UserValue):
if field_name in item if field_name in item
}) })
else: else:
return super().__getitem__(item) try:
return Value.__getitem__(self, item)
except KeyError:
if self.name is None:
reference = "Unnamed record"
else:
reference = "Record '{}'".format(self.name)
raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
.format(reference, item, ", ".join(self.fields))) from None
def lower(self): @ValueCastable.lowermethod
def as_value(self):
return Cat(self.fields.values()) return Cat(self.fields.values())
def __len__(self):
return len(self.as_value())
def _lhs_signals(self): def _lhs_signals(self):
return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet()) return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

View file

@ -135,8 +135,8 @@ class RecordTestCase(FHDLTestCase):
("stb", 1), ("stb", 1),
]) ])
self.assertEqual(repr(r[0]), "(slice (rec r data stb) 0:1)") self.assertEqual(repr(r[0]), "(slice (cat (sig r__data) (sig r__stb)) 0:1)")
self.assertEqual(repr(r[0:3]), "(slice (rec r data stb) 0:3)") self.assertEqual(repr(r[0:3]), "(slice (cat (sig r__data) (sig r__stb)) 0:3)")
def test_wrong_field(self): def test_wrong_field(self):
r = Record([ r = Record([