diff --git a/amaranth/_utils.py b/amaranth/_utils.py index 886de59..35adcf9 100644 --- a/amaranth/_utils.py +++ b/amaranth/_utils.py @@ -2,15 +2,27 @@ import contextlib import functools import warnings import linecache +import operator import re from collections import OrderedDict from collections.abc import Iterable -__all__ = ["flatten", "union", "final", "deprecated", "get_linter_options", +__all__ = ["to_binary", "flatten", "union", "final", "deprecated", "get_linter_options", "get_linter_option"] +def to_binary(n: int, width: int) -> str: + """Formats ``n`` as exactly ``width`` binary digits, including when ``width`` is 0""" + n = operator.index(n) + width = operator.index(width) + if n not in range(1 << width): + raise ValueError(f"{n} does not fit in {width} bits") + if width == 0: + return "" + return f"{n:0{width}b}" + + def flatten(i): for e in i: if isinstance(e, str) or not isinstance(e, Iterable): diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index cd2b500..3b71bc1 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -2768,15 +2768,9 @@ class Switch(Statement): if isinstance(key, str): key = "".join(key.split()) # remove whitespace elif isinstance(key, int): - key = format(key & key_mask, "b").rjust(len(self.test), "0") - # fixup for 0-width test - if key_mask == 0: - key = "" + key = to_binary(key & key_mask, len(self.test)) elif isinstance(key, Enum): - key = format(key.value & key_mask, "b").rjust(len(self.test), "0") - # fixup for 0-width test - if key_mask == 0: - key = "" + key = to_binary(key.value & key_mask, len(self.test)) else: raise TypeError("Object {!r} cannot be used as a switch key" .format(key)) diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 04c1d9c..10b9ab7 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -3,7 +3,7 @@ from collections import defaultdict, OrderedDict import enum import warnings -from .._utils import flatten +from .._utils import flatten, to_binary from .. import tracer, _unused from . import _ast, _cd, _ir, _nir @@ -880,7 +880,7 @@ class NetlistEmitter: conds = [] for case_index in range(len(elems)): cell = _nir.Matches(module_idx, value=index, - patterns=(f"{case_index:0{len(index)}b}",), + patterns=(to_binary(case_index, len(index)),), src_loc=value.src_loc) subcond, = self.netlist.add_value_cell(1, cell) conds.append(subcond) @@ -985,7 +985,7 @@ class NetlistEmitter: conds = [] for case_index in range(num_cases): cell = _nir.Matches(module_idx, value=offset, - patterns=(f"{case_index:0{len(offset)}b}",), + patterns=(to_binary(case_index, len(offset)),), src_loc=lhs.src_loc) subcond, = self.netlist.add_value_cell(1, cell) conds.append(subcond) @@ -1006,7 +1006,7 @@ class NetlistEmitter: conds = [] for case_index in range(len(lhs.elems)): cell = _nir.Matches(module_idx, value=index, - patterns=(f"{case_index:0{len(index)}b}",), + patterns=(to_binary(case_index, len(index)),), src_loc=lhs.src_loc) subcond, = self.netlist.add_value_cell(1, cell) conds.append(subcond)