hdl.{ast,dsl}: add Signal.enum; coerce Enum to Value; accept Enum patterns.

Fixes #207.
This commit is contained in:
whitequark 2019-09-16 18:59:28 +00:00
parent e8f79c5539
commit 4777a7b3a2
4 changed files with 107 additions and 9 deletions

View file

@ -30,6 +30,21 @@ class DUID:
DUID.__next_uid += 1
def _enum_shape(enum_type):
min_value = min(member.value for member in enum_type)
max_value = max(member.value for member in enum_type)
if not isinstance(min_value, int) or not isinstance(max_value, int):
raise TypeError("Only enumerations with integer values can be converted to nMigen values")
sign = min_value < 0 or max_value < 0
bits = max(bits_for(min_value, sign), bits_for(max_value, sign))
return (bits, sign)
def _enum_to_bits(enum_value):
bits, sign = _enum_shape(type(enum_value))
return format(enum_value.value & ((1 << bits) - 1), "b").rjust(bits, "0")
class Value(metaclass=ABCMeta):
@staticmethod
def wrap(obj):
@ -39,6 +54,8 @@ class Value(metaclass=ABCMeta):
return obj
elif isinstance(obj, (bool, int)):
return Const(obj)
elif isinstance(obj, Enum):
return Const(obj.value, _enum_shape(type(obj)))
else:
raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
@ -240,6 +257,10 @@ class Value(metaclass=ABCMeta):
"""
matches = []
for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits"
@ -248,9 +269,6 @@ class Value(metaclass=ABCMeta):
raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})"
.format(pattern, len(self)))
if not isinstance(pattern, (int, str)):
raise SyntaxError("Match pattern must be an integer or a string, not {}"
.format(pattern))
if isinstance(pattern, int) and bits_for(pattern) > len(self):
warnings.warn("Match pattern '{:b}' is wider than match value "
"(which has width {}); comparison will never be true"
@ -259,7 +277,9 @@ class Value(metaclass=ABCMeta):
continue
if isinstance(pattern, int):
matches.append(self == pattern)
elif isinstance(pattern, str):
elif isinstance(pattern, (str, Enum)):
if isinstance(pattern, Enum):
pattern = _enum_to_bits(pattern)
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern)
@ -784,6 +804,19 @@ class Signal(Value, DUID):
bits_for(value_range.stop - value_range.step, signed))
return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs)
@classmethod
def enum(cls, enum_type, *, src_loc_at=0, **kwargs):
"""Create Signal that can represent a given enumeration.
Parameters
----------
enum : type (inheriting from :class:`enum.Enum`)
Enumeration to base this Signal on.
"""
if not issubclass(enum_type, Enum):
raise TypeError("Type {!r} is not an enumeration")
return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
@classmethod
def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
"""Create Signal based on another.
@ -1230,6 +1263,8 @@ class Switch(Statement):
key = "{:0{}b}".format(key, len(self.test))
elif isinstance(key, str):
pass
elif isinstance(key, Enum):
key = _enum_to_bits(key)
else:
raise TypeError("Object '{!r}' cannot be used as a switch key"
.format(key))

View file

@ -1,6 +1,7 @@
from collections import OrderedDict, namedtuple
from collections.abc import Iterable
from contextlib import contextmanager
from enum import Enum
import warnings
from ..tools import flatten, bits_for, deprecated
@ -264,6 +265,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
switch_data = self._get_ctrl("Switch")
new_patterns = ()
for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits"
.format(pattern))
@ -271,9 +276,6 @@ class Module(_ModuleBuilderRoot, Elaboratable):
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
"(which is {})"
.format(pattern, len(switch_data["test"])))
if not isinstance(pattern, (int, str)):
raise SyntaxError("Case pattern must be an integer or a string, not {}"
.format(pattern))
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
warnings.warn("Case pattern '{:b}' is wider than switch value "
"(which has width {}); comparison will never be true"