hdl._ir: Fix Array lowering with 0-width index.

This commit is contained in:
Wanda 2024-03-28 03:18:03 +01:00 committed by Catherine
parent 1d5de80347
commit 5577f4e703
3 changed files with 19 additions and 13 deletions

View file

@ -2,15 +2,27 @@ import contextlib
import functools import functools
import warnings import warnings
import linecache import linecache
import operator
import re import re
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable from collections.abc import Iterable
__all__ = ["flatten", "union", "final", "deprecated", "get_linter_options", __all__ = ["to_binary", "flatten", "union", "final", "deprecated", "get_linter_options",
"get_linter_option"] "get_linter_option"]
def to_binary(n: int, width: int) -> str:
"""Formats ``n`` as exactly ``width`` binary digits, including when ``width`` is 0"""
n = operator.index(n)
width = operator.index(width)
if n not in range(1 << width):
raise ValueError(f"{n} does not fit in {width} bits")
if width == 0:
return ""
return f"{n:0{width}b}"
def flatten(i): def flatten(i):
for e in i: for e in i:
if isinstance(e, str) or not isinstance(e, Iterable): if isinstance(e, str) or not isinstance(e, Iterable):

View file

@ -2768,15 +2768,9 @@ class Switch(Statement):
if isinstance(key, str): if isinstance(key, str):
key = "".join(key.split()) # remove whitespace key = "".join(key.split()) # remove whitespace
elif isinstance(key, int): elif isinstance(key, int):
key = format(key & key_mask, "b").rjust(len(self.test), "0") key = to_binary(key & key_mask, len(self.test))
# fixup for 0-width test
if key_mask == 0:
key = ""
elif isinstance(key, Enum): elif isinstance(key, Enum):
key = format(key.value & key_mask, "b").rjust(len(self.test), "0") key = to_binary(key.value & key_mask, len(self.test))
# fixup for 0-width test
if key_mask == 0:
key = ""
else: else:
raise TypeError("Object {!r} cannot be used as a switch key" raise TypeError("Object {!r} cannot be used as a switch key"
.format(key)) .format(key))

View file

@ -3,7 +3,7 @@ from collections import defaultdict, OrderedDict
import enum import enum
import warnings import warnings
from .._utils import flatten from .._utils import flatten, to_binary
from .. import tracer, _unused from .. import tracer, _unused
from . import _ast, _cd, _ir, _nir from . import _ast, _cd, _ir, _nir
@ -880,7 +880,7 @@ class NetlistEmitter:
conds = [] conds = []
for case_index in range(len(elems)): for case_index in range(len(elems)):
cell = _nir.Matches(module_idx, value=index, cell = _nir.Matches(module_idx, value=index,
patterns=(f"{case_index:0{len(index)}b}",), patterns=(to_binary(case_index, len(index)),),
src_loc=value.src_loc) src_loc=value.src_loc)
subcond, = self.netlist.add_value_cell(1, cell) subcond, = self.netlist.add_value_cell(1, cell)
conds.append(subcond) conds.append(subcond)
@ -985,7 +985,7 @@ class NetlistEmitter:
conds = [] conds = []
for case_index in range(num_cases): for case_index in range(num_cases):
cell = _nir.Matches(module_idx, value=offset, cell = _nir.Matches(module_idx, value=offset,
patterns=(f"{case_index:0{len(offset)}b}",), patterns=(to_binary(case_index, len(offset)),),
src_loc=lhs.src_loc) src_loc=lhs.src_loc)
subcond, = self.netlist.add_value_cell(1, cell) subcond, = self.netlist.add_value_cell(1, cell)
conds.append(subcond) conds.append(subcond)
@ -1006,7 +1006,7 @@ class NetlistEmitter:
conds = [] conds = []
for case_index in range(len(lhs.elems)): for case_index in range(len(lhs.elems)):
cell = _nir.Matches(module_idx, value=index, cell = _nir.Matches(module_idx, value=index,
patterns=(f"{case_index:0{len(index)}b}",), patterns=(to_binary(case_index, len(index)),),
src_loc=lhs.src_loc) src_loc=lhs.src_loc)
subcond, = self.netlist.add_value_cell(1, cell) subcond, = self.netlist.add_value_cell(1, cell)
conds.append(subcond) conds.append(subcond)