hdl.{ast,dsl}: add Signal.enum; coerce Enum to Value; accept Enum patterns.
Fixes #207.
This commit is contained in:
parent
e8f79c5539
commit
4777a7b3a2
|
@ -30,6 +30,21 @@ class DUID:
|
||||||
DUID.__next_uid += 1
|
DUID.__next_uid += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _enum_shape(enum_type):
|
||||||
|
min_value = min(member.value for member in enum_type)
|
||||||
|
max_value = max(member.value for member in enum_type)
|
||||||
|
if not isinstance(min_value, int) or not isinstance(max_value, int):
|
||||||
|
raise TypeError("Only enumerations with integer values can be converted to nMigen values")
|
||||||
|
sign = min_value < 0 or max_value < 0
|
||||||
|
bits = max(bits_for(min_value, sign), bits_for(max_value, sign))
|
||||||
|
return (bits, sign)
|
||||||
|
|
||||||
|
|
||||||
|
def _enum_to_bits(enum_value):
|
||||||
|
bits, sign = _enum_shape(type(enum_value))
|
||||||
|
return format(enum_value.value & ((1 << bits) - 1), "b").rjust(bits, "0")
|
||||||
|
|
||||||
|
|
||||||
class Value(metaclass=ABCMeta):
|
class Value(metaclass=ABCMeta):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wrap(obj):
|
def wrap(obj):
|
||||||
|
@ -39,6 +54,8 @@ class Value(metaclass=ABCMeta):
|
||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, (bool, int)):
|
elif isinstance(obj, (bool, int)):
|
||||||
return Const(obj)
|
return Const(obj)
|
||||||
|
elif isinstance(obj, Enum):
|
||||||
|
return Const(obj.value, _enum_shape(type(obj)))
|
||||||
else:
|
else:
|
||||||
raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
|
raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
|
||||||
|
|
||||||
|
@ -240,6 +257,10 @@ class Value(metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
matches = []
|
matches = []
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
|
if not isinstance(pattern, (int, str, Enum)):
|
||||||
|
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
|
||||||
|
"not {!r}"
|
||||||
|
.format(pattern))
|
||||||
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
|
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
|
||||||
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
|
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
|
||||||
"bits"
|
"bits"
|
||||||
|
@ -248,9 +269,6 @@ class Value(metaclass=ABCMeta):
|
||||||
raise SyntaxError("Match pattern '{}' must have the same width as match value "
|
raise SyntaxError("Match pattern '{}' must have the same width as match value "
|
||||||
"(which is {})"
|
"(which is {})"
|
||||||
.format(pattern, len(self)))
|
.format(pattern, len(self)))
|
||||||
if not isinstance(pattern, (int, str)):
|
|
||||||
raise SyntaxError("Match pattern must be an integer or a string, not {}"
|
|
||||||
.format(pattern))
|
|
||||||
if isinstance(pattern, int) and bits_for(pattern) > len(self):
|
if isinstance(pattern, int) and bits_for(pattern) > len(self):
|
||||||
warnings.warn("Match pattern '{:b}' is wider than match value "
|
warnings.warn("Match pattern '{:b}' is wider than match value "
|
||||||
"(which has width {}); comparison will never be true"
|
"(which has width {}); comparison will never be true"
|
||||||
|
@ -259,7 +277,9 @@ class Value(metaclass=ABCMeta):
|
||||||
continue
|
continue
|
||||||
if isinstance(pattern, int):
|
if isinstance(pattern, int):
|
||||||
matches.append(self == pattern)
|
matches.append(self == pattern)
|
||||||
elif isinstance(pattern, str):
|
elif isinstance(pattern, (str, Enum)):
|
||||||
|
if isinstance(pattern, Enum):
|
||||||
|
pattern = _enum_to_bits(pattern)
|
||||||
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
|
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
|
||||||
pattern = int(pattern.replace("-", "0"), 2)
|
pattern = int(pattern.replace("-", "0"), 2)
|
||||||
matches.append((self & mask) == pattern)
|
matches.append((self & mask) == pattern)
|
||||||
|
@ -784,6 +804,19 @@ class Signal(Value, DUID):
|
||||||
bits_for(value_range.stop - value_range.step, signed))
|
bits_for(value_range.stop - value_range.step, signed))
|
||||||
return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs)
|
return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enum(cls, enum_type, *, src_loc_at=0, **kwargs):
|
||||||
|
"""Create Signal that can represent a given enumeration.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
enum : type (inheriting from :class:`enum.Enum`)
|
||||||
|
Enumeration to base this Signal on.
|
||||||
|
"""
|
||||||
|
if not issubclass(enum_type, Enum):
|
||||||
|
raise TypeError("Type {!r} is not an enumeration")
|
||||||
|
return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
|
def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
|
||||||
"""Create Signal based on another.
|
"""Create Signal based on another.
|
||||||
|
@ -1230,6 +1263,8 @@ class Switch(Statement):
|
||||||
key = "{:0{}b}".format(key, len(self.test))
|
key = "{:0{}b}".format(key, len(self.test))
|
||||||
elif isinstance(key, str):
|
elif isinstance(key, str):
|
||||||
pass
|
pass
|
||||||
|
elif isinstance(key, Enum):
|
||||||
|
key = _enum_to_bits(key)
|
||||||
else:
|
else:
|
||||||
raise TypeError("Object '{!r}' cannot be used as a switch key"
|
raise TypeError("Object '{!r}' cannot be used as a switch key"
|
||||||
.format(key))
|
.format(key))
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from enum import Enum
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ..tools import flatten, bits_for, deprecated
|
from ..tools import flatten, bits_for, deprecated
|
||||||
|
@ -264,6 +265,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
||||||
switch_data = self._get_ctrl("Switch")
|
switch_data = self._get_ctrl("Switch")
|
||||||
new_patterns = ()
|
new_patterns = ()
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
|
if not isinstance(pattern, (int, str, Enum)):
|
||||||
|
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
|
||||||
|
"not {!r}"
|
||||||
|
.format(pattern))
|
||||||
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
|
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
|
||||||
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits"
|
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits"
|
||||||
.format(pattern))
|
.format(pattern))
|
||||||
|
@ -271,9 +276,6 @@ class Module(_ModuleBuilderRoot, Elaboratable):
|
||||||
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
|
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
|
||||||
"(which is {})"
|
"(which is {})"
|
||||||
.format(pattern, len(switch_data["test"])))
|
.format(pattern, len(switch_data["test"])))
|
||||||
if not isinstance(pattern, (int, str)):
|
|
||||||
raise SyntaxError("Case pattern must be an integer or a string, not {}"
|
|
||||||
.format(pattern))
|
|
||||||
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
|
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
|
||||||
warnings.warn("Case pattern '{:b}' is wider than switch value "
|
warnings.warn("Case pattern '{:b}' is wider than switch value "
|
||||||
"(which has width {}); comparison will never be true"
|
"(which has width {}); comparison will never be true"
|
||||||
|
|
|
@ -5,6 +5,23 @@ from ..hdl.ast import *
|
||||||
from .tools import *
|
from .tools import *
|
||||||
|
|
||||||
|
|
||||||
|
class UnsignedEnum(Enum):
|
||||||
|
FOO = 1
|
||||||
|
BAR = 2
|
||||||
|
BAZ = 3
|
||||||
|
|
||||||
|
|
||||||
|
class SignedEnum(Enum):
|
||||||
|
FOO = -1
|
||||||
|
BAR = 0
|
||||||
|
BAZ = +1
|
||||||
|
|
||||||
|
|
||||||
|
class StringEnum(Enum):
|
||||||
|
FOO = "a"
|
||||||
|
BAR = "b"
|
||||||
|
|
||||||
|
|
||||||
class ValueTestCase(FHDLTestCase):
|
class ValueTestCase(FHDLTestCase):
|
||||||
def test_wrap(self):
|
def test_wrap(self):
|
||||||
self.assertIsInstance(Value.wrap(0), Const)
|
self.assertIsInstance(Value.wrap(0), Const)
|
||||||
|
@ -15,6 +32,19 @@ class ValueTestCase(FHDLTestCase):
|
||||||
msg="Object ''str'' is not an nMigen value"):
|
msg="Object ''str'' is not an nMigen value"):
|
||||||
Value.wrap("str")
|
Value.wrap("str")
|
||||||
|
|
||||||
|
def test_wrap_enum(self):
|
||||||
|
e1 = Value.wrap(UnsignedEnum.FOO)
|
||||||
|
self.assertIsInstance(e1, Const)
|
||||||
|
self.assertEqual(e1.shape(), (2, False))
|
||||||
|
e2 = Value.wrap(SignedEnum.FOO)
|
||||||
|
self.assertIsInstance(e2, Const)
|
||||||
|
self.assertEqual(e2.shape(), (2, True))
|
||||||
|
|
||||||
|
def test_wrap_enum_wrong(self):
|
||||||
|
with self.assertRaises(TypeError,
|
||||||
|
msg="Only enumerations with integer values can be converted to nMigen values"):
|
||||||
|
Value.wrap(StringEnum.FOO)
|
||||||
|
|
||||||
def test_bool(self):
|
def test_bool(self):
|
||||||
with self.assertRaises(TypeError,
|
with self.assertRaises(TypeError,
|
||||||
msg="Attempted to convert nMigen value to boolean"):
|
msg="Attempted to convert nMigen value to boolean"):
|
||||||
|
@ -276,6 +306,12 @@ class OperatorTestCase(FHDLTestCase):
|
||||||
(== (& (sig s) (const 4'd12)) (const 4'd8))
|
(== (& (sig s) (const 4'd12)) (const 4'd8))
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
def test_matches_enum(self):
|
||||||
|
s = Signal.enum(SignedEnum)
|
||||||
|
self.assertRepr(s.matches(SignedEnum.FOO), """
|
||||||
|
(== (& (sig s) (const 2'd3)) (const 2'd3))
|
||||||
|
""")
|
||||||
|
|
||||||
def test_matches_width_wrong(self):
|
def test_matches_width_wrong(self):
|
||||||
s = Signal(4)
|
s = Signal(4)
|
||||||
with self.assertRaises(SyntaxError,
|
with self.assertRaises(SyntaxError,
|
||||||
|
@ -295,7 +331,7 @@ class OperatorTestCase(FHDLTestCase):
|
||||||
def test_matches_pattern_wrong(self):
|
def test_matches_pattern_wrong(self):
|
||||||
s = Signal(4)
|
s = Signal(4)
|
||||||
with self.assertRaises(SyntaxError,
|
with self.assertRaises(SyntaxError,
|
||||||
msg="Match pattern must be an integer or a string, not 1.0"):
|
msg="Match pattern must be an integer, a string, or an enumeration, not 1.0"):
|
||||||
s.matches(1.0)
|
s.matches(1.0)
|
||||||
|
|
||||||
def test_hash(self):
|
def test_hash(self):
|
||||||
|
@ -605,6 +641,13 @@ class SignalTestCase(FHDLTestCase):
|
||||||
self.assertEqual(s.decoder(1), "RED/1")
|
self.assertEqual(s.decoder(1), "RED/1")
|
||||||
self.assertEqual(s.decoder(3), "3")
|
self.assertEqual(s.decoder(3), "3")
|
||||||
|
|
||||||
|
def test_enum(self):
|
||||||
|
s1 = Signal.enum(UnsignedEnum)
|
||||||
|
self.assertEqual(s1.shape(), (2, False))
|
||||||
|
s2 = Signal.enum(SignedEnum)
|
||||||
|
self.assertEqual(s2.shape(), (2, True))
|
||||||
|
self.assertEqual(s2.decoder(SignedEnum.FOO), "FOO/-1")
|
||||||
|
|
||||||
|
|
||||||
class ClockSignalTestCase(FHDLTestCase):
|
class ClockSignalTestCase(FHDLTestCase):
|
||||||
def test_domain(self):
|
def test_domain(self):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from ..hdl.ast import *
|
from ..hdl.ast import *
|
||||||
from ..hdl.cd import *
|
from ..hdl.cd import *
|
||||||
|
@ -355,6 +356,23 @@ class DSLTestCase(FHDLTestCase):
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
def test_Switch_enum(self):
|
||||||
|
class Color(Enum):
|
||||||
|
RED = 1
|
||||||
|
BLUE = 2
|
||||||
|
m = Module()
|
||||||
|
se = Signal.enum(Color)
|
||||||
|
with m.Switch(se):
|
||||||
|
with m.Case(Color.RED):
|
||||||
|
m.d.comb += self.c1.eq(1)
|
||||||
|
self.assertRepr(m._statements, """
|
||||||
|
(
|
||||||
|
(switch (sig se)
|
||||||
|
(case 01 (eq (sig c1) (const 1'd1)))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
def test_Case_width_wrong(self):
|
def test_Case_width_wrong(self):
|
||||||
m = Module()
|
m = Module()
|
||||||
with m.Switch(self.w1):
|
with m.Switch(self.w1):
|
||||||
|
@ -385,7 +403,7 @@ class DSLTestCase(FHDLTestCase):
|
||||||
m = Module()
|
m = Module()
|
||||||
with m.Switch(self.w1):
|
with m.Switch(self.w1):
|
||||||
with self.assertRaises(SyntaxError,
|
with self.assertRaises(SyntaxError,
|
||||||
msg="Case pattern must be an integer or a string, not 1.0"):
|
msg="Case pattern must be an integer, a string, or an enumeration, not 1.0"):
|
||||||
with m.Case(1.0):
|
with m.Case(1.0):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue