hdl.{ast,dsl}: add Signal.enum; coerce Enum to Value; accept Enum patterns.

Fixes #207.
This commit is contained in:
whitequark 2019-09-16 18:59:28 +00:00
parent e8f79c5539
commit 4777a7b3a2
4 changed files with 107 additions and 9 deletions

View file

@ -30,6 +30,21 @@ class DUID:
DUID.__next_uid += 1 DUID.__next_uid += 1
def _enum_shape(enum_type):
min_value = min(member.value for member in enum_type)
max_value = max(member.value for member in enum_type)
if not isinstance(min_value, int) or not isinstance(max_value, int):
raise TypeError("Only enumerations with integer values can be converted to nMigen values")
sign = min_value < 0 or max_value < 0
bits = max(bits_for(min_value, sign), bits_for(max_value, sign))
return (bits, sign)
def _enum_to_bits(enum_value):
bits, sign = _enum_shape(type(enum_value))
return format(enum_value.value & ((1 << bits) - 1), "b").rjust(bits, "0")
class Value(metaclass=ABCMeta): class Value(metaclass=ABCMeta):
@staticmethod @staticmethod
def wrap(obj): def wrap(obj):
@ -39,6 +54,8 @@ class Value(metaclass=ABCMeta):
return obj return obj
elif isinstance(obj, (bool, int)): elif isinstance(obj, (bool, int)):
return Const(obj) return Const(obj)
elif isinstance(obj, Enum):
return Const(obj.value, _enum_shape(type(obj)))
else: else:
raise TypeError("Object '{!r}' is not an nMigen value".format(obj)) raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
@ -240,6 +257,10 @@ class Value(metaclass=ABCMeta):
""" """
matches = [] matches = []
for pattern in patterns: for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern): if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) " raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits" "bits"
@ -248,9 +269,6 @@ class Value(metaclass=ABCMeta):
raise SyntaxError("Match pattern '{}' must have the same width as match value " raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})" "(which is {})"
.format(pattern, len(self))) .format(pattern, len(self)))
if not isinstance(pattern, (int, str)):
raise SyntaxError("Match pattern must be an integer or a string, not {}"
.format(pattern))
if isinstance(pattern, int) and bits_for(pattern) > len(self): if isinstance(pattern, int) and bits_for(pattern) > len(self):
warnings.warn("Match pattern '{:b}' is wider than match value " warnings.warn("Match pattern '{:b}' is wider than match value "
"(which has width {}); comparison will never be true" "(which has width {}); comparison will never be true"
@ -259,7 +277,9 @@ class Value(metaclass=ABCMeta):
continue continue
if isinstance(pattern, int): if isinstance(pattern, int):
matches.append(self == pattern) matches.append(self == pattern)
elif isinstance(pattern, str): elif isinstance(pattern, (str, Enum)):
if isinstance(pattern, Enum):
pattern = _enum_to_bits(pattern)
mask = int(pattern.replace("0", "1").replace("-", "0"), 2) mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2) pattern = int(pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern) matches.append((self & mask) == pattern)
@ -784,6 +804,19 @@ class Signal(Value, DUID):
bits_for(value_range.stop - value_range.step, signed)) bits_for(value_range.stop - value_range.step, signed))
return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs) return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs)
@classmethod
def enum(cls, enum_type, *, src_loc_at=0, **kwargs):
"""Create Signal that can represent a given enumeration.
Parameters
----------
enum : type (inheriting from :class:`enum.Enum`)
Enumeration to base this Signal on.
"""
if not issubclass(enum_type, Enum):
raise TypeError("Type {!r} is not an enumeration")
return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
@classmethod @classmethod
def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs): def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
"""Create Signal based on another. """Create Signal based on another.
@ -1230,6 +1263,8 @@ class Switch(Statement):
key = "{:0{}b}".format(key, len(self.test)) key = "{:0{}b}".format(key, len(self.test))
elif isinstance(key, str): elif isinstance(key, str):
pass pass
elif isinstance(key, Enum):
key = _enum_to_bits(key)
else: else:
raise TypeError("Object '{!r}' cannot be used as a switch key" raise TypeError("Object '{!r}' cannot be used as a switch key"
.format(key)) .format(key))

View file

@ -1,6 +1,7 @@
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum
import warnings import warnings
from ..tools import flatten, bits_for, deprecated from ..tools import flatten, bits_for, deprecated
@ -264,6 +265,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
switch_data = self._get_ctrl("Switch") switch_data = self._get_ctrl("Switch")
new_patterns = () new_patterns = ()
for pattern in patterns: for pattern in patterns:
if not isinstance(pattern, (int, str, Enum)):
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
"not {!r}"
.format(pattern))
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern): if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits" raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits"
.format(pattern)) .format(pattern))
@ -271,9 +276,6 @@ class Module(_ModuleBuilderRoot, Elaboratable):
raise SyntaxError("Case pattern '{}' must have the same width as switch value " raise SyntaxError("Case pattern '{}' must have the same width as switch value "
"(which is {})" "(which is {})"
.format(pattern, len(switch_data["test"]))) .format(pattern, len(switch_data["test"])))
if not isinstance(pattern, (int, str)):
raise SyntaxError("Case pattern must be an integer or a string, not {}"
.format(pattern))
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]): if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
warnings.warn("Case pattern '{:b}' is wider than switch value " warnings.warn("Case pattern '{:b}' is wider than switch value "
"(which has width {}); comparison will never be true" "(which has width {}); comparison will never be true"

View file

@ -5,6 +5,23 @@ from ..hdl.ast import *
from .tools import * from .tools import *
class UnsignedEnum(Enum):
FOO = 1
BAR = 2
BAZ = 3
class SignedEnum(Enum):
FOO = -1
BAR = 0
BAZ = +1
class StringEnum(Enum):
FOO = "a"
BAR = "b"
class ValueTestCase(FHDLTestCase): class ValueTestCase(FHDLTestCase):
def test_wrap(self): def test_wrap(self):
self.assertIsInstance(Value.wrap(0), Const) self.assertIsInstance(Value.wrap(0), Const)
@ -15,6 +32,19 @@ class ValueTestCase(FHDLTestCase):
msg="Object ''str'' is not an nMigen value"): msg="Object ''str'' is not an nMigen value"):
Value.wrap("str") Value.wrap("str")
def test_wrap_enum(self):
e1 = Value.wrap(UnsignedEnum.FOO)
self.assertIsInstance(e1, Const)
self.assertEqual(e1.shape(), (2, False))
e2 = Value.wrap(SignedEnum.FOO)
self.assertIsInstance(e2, Const)
self.assertEqual(e2.shape(), (2, True))
def test_wrap_enum_wrong(self):
with self.assertRaises(TypeError,
msg="Only enumerations with integer values can be converted to nMigen values"):
Value.wrap(StringEnum.FOO)
def test_bool(self): def test_bool(self):
with self.assertRaises(TypeError, with self.assertRaises(TypeError,
msg="Attempted to convert nMigen value to boolean"): msg="Attempted to convert nMigen value to boolean"):
@ -276,6 +306,12 @@ class OperatorTestCase(FHDLTestCase):
(== (& (sig s) (const 4'd12)) (const 4'd8)) (== (& (sig s) (const 4'd12)) (const 4'd8))
""") """)
def test_matches_enum(self):
s = Signal.enum(SignedEnum)
self.assertRepr(s.matches(SignedEnum.FOO), """
(== (& (sig s) (const 2'd3)) (const 2'd3))
""")
def test_matches_width_wrong(self): def test_matches_width_wrong(self):
s = Signal(4) s = Signal(4)
with self.assertRaises(SyntaxError, with self.assertRaises(SyntaxError,
@ -295,7 +331,7 @@ class OperatorTestCase(FHDLTestCase):
def test_matches_pattern_wrong(self): def test_matches_pattern_wrong(self):
s = Signal(4) s = Signal(4)
with self.assertRaises(SyntaxError, with self.assertRaises(SyntaxError,
msg="Match pattern must be an integer or a string, not 1.0"): msg="Match pattern must be an integer, a string, or an enumeration, not 1.0"):
s.matches(1.0) s.matches(1.0)
def test_hash(self): def test_hash(self):
@ -605,6 +641,13 @@ class SignalTestCase(FHDLTestCase):
self.assertEqual(s.decoder(1), "RED/1") self.assertEqual(s.decoder(1), "RED/1")
self.assertEqual(s.decoder(3), "3") self.assertEqual(s.decoder(3), "3")
def test_enum(self):
s1 = Signal.enum(UnsignedEnum)
self.assertEqual(s1.shape(), (2, False))
s2 = Signal.enum(SignedEnum)
self.assertEqual(s2.shape(), (2, True))
self.assertEqual(s2.decoder(SignedEnum.FOO), "FOO/-1")
class ClockSignalTestCase(FHDLTestCase): class ClockSignalTestCase(FHDLTestCase):
def test_domain(self): def test_domain(self):

View file

@ -1,4 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from enum import Enum
from ..hdl.ast import * from ..hdl.ast import *
from ..hdl.cd import * from ..hdl.cd import *
@ -355,6 +356,23 @@ class DSLTestCase(FHDLTestCase):
) )
""") """)
def test_Switch_enum(self):
class Color(Enum):
RED = 1
BLUE = 2
m = Module()
se = Signal.enum(Color)
with m.Switch(se):
with m.Case(Color.RED):
m.d.comb += self.c1.eq(1)
self.assertRepr(m._statements, """
(
(switch (sig se)
(case 01 (eq (sig c1) (const 1'd1)))
)
)
""")
def test_Case_width_wrong(self): def test_Case_width_wrong(self):
m = Module() m = Module()
with m.Switch(self.w1): with m.Switch(self.w1):
@ -385,7 +403,7 @@ class DSLTestCase(FHDLTestCase):
m = Module() m = Module()
with m.Switch(self.w1): with m.Switch(self.w1):
with self.assertRaises(SyntaxError, with self.assertRaises(SyntaxError,
msg="Case pattern must be an integer or a string, not 1.0"): msg="Case pattern must be an integer, a string, or an enumeration, not 1.0"):
with m.Case(1.0): with m.Case(1.0):
pass pass