back.rtlil: reorganize value compiler into LHS/RHS.

This also implements Cat on LHS.
This commit is contained in:
whitequark 2018-12-16 13:30:20 +00:00
parent ed39748889
commit 9794e732e2

View file

@ -195,7 +195,94 @@ def src(src_loc):
return "{}:{}".format(file, line)
class _ValueTransformer(xfrm.AbstractValueTransformer):
class _ValueCompilerState:
def __init__(self, rtlil):
self.rtlil = rtlil
self.wires = ast.ValueDict()
self.driven = ast.ValueDict()
self.ports = ast.ValueDict()
self.sub_name = None
def add_driven(self, signal, sync):
self.driven[signal] = sync
def add_port(self, signal, kind):
assert kind in ("i", "o", "io")
if kind == "i":
kind = "input"
elif kind == "o":
kind = "output"
elif kind == "io":
kind = "inout"
self.ports[signal] = (len(self.ports), kind)
def resolve(self, signal):
if signal in self.wires:
return self.wires[signal]
if signal in self.ports:
port_id, port_kind = self.ports[signal]
else:
port_id = port_kind = None
if self.sub_name:
wire_name = "{}_{}".format(self.sub_name, signal.name)
else:
wire_name = signal.name
for attr_name, attr_signal in signal.attrs.items():
self.rtlil.attribute(attr_name, attr_signal)
wire_curr = self.rtlil.wire(width=signal.nbits, name=wire_name,
port_id=port_id, port_kind=port_kind,
src=src(signal.src_loc))
if signal in self.driven:
wire_next = self.rtlil.wire(width=signal.nbits, name=wire_curr + "$next",
src=src(signal.src_loc))
else:
wire_next = None
self.wires[signal] = (wire_curr, wire_next)
return wire_curr, wire_next
def resolve_curr(self, signal):
wire_curr, wire_next = self.resolve(signal)
return wire_curr
@contextmanager
def hierarchy(self, sub_name):
try:
self.sub_name = sub_name
yield
finally:
self.sub_name = None
class _ValueCompiler(xfrm.AbstractValueTransformer):
def __init__(self, state):
self.s = state
def on_unknown(self, value):
if value is None:
return None
else:
super().on_unknown(value)
def on_ClockSignal(self, value):
raise NotImplementedError # :nocov:
def on_ResetSignal(self, value):
raise NotImplementedError # :nocov:
def on_Slice(self, value):
if value.end == value.start + 1:
return "{} [{}]".format(self(value.value), value.start)
else:
return "{} [{}:{}]".format(self(value.value), value.end - 1, value.start)
def on_Cat(self, value):
return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.operands])))
class _RHSValueCompiler(_ValueCompiler):
operator_map = {
(1, "~"): "$not",
(1, "-"): "$neg",
@ -220,49 +307,6 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
(3, "m"): "$mux",
}
def __init__(self, rtlil):
self.rtlil = rtlil
self.wires = ast.ValueDict()
self.driven = ast.ValueDict()
self.ports = ast.ValueDict()
self.is_lhs = False
self.sub_name = None
def add_driven(self, signal, sync):
self.driven[signal] = sync
def add_port(self, signal, kind):
assert kind in ("i", "o", "io")
if kind == "i":
kind = "input"
elif kind == "o":
kind = "output"
elif kind == "io":
kind = "inout"
self.ports[signal] = (len(self.ports), kind)
@contextmanager
def lhs(self):
try:
self.is_lhs = True
yield
finally:
self.is_lhs = False
@contextmanager
def hierarchy(self, sub_name):
try:
self.sub_name = sub_name
yield
finally:
self.sub_name = None
def on_unknown(self, value):
if value is None:
return None
else:
super().on_unknown(value)
def on_Const(self, value):
if isinstance(value.value, str):
return "{}'{}".format(value.nbits, value.value)
@ -270,48 +314,15 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
return "{}'{:b}".format(value.nbits, value.value)
def on_Signal(self, value):
if value in self.wires:
wire_curr, wire_next = self.wires[value]
else:
if value in self.ports:
port_id, port_kind = self.ports[value]
else:
port_id = port_kind = None
if self.sub_name:
wire_name = "{}_{}".format(self.sub_name, value.name)
else:
wire_name = value.name
for attr_name, attr_value in value.attrs.items():
self.rtlil.attribute(attr_name, attr_value)
wire_curr = self.rtlil.wire(width=value.nbits, name=wire_name,
port_id=port_id, port_kind=port_kind,
src=src(value.src_loc))
if value in self.driven:
wire_next = self.rtlil.wire(width=value.nbits, name=wire_curr + "$next",
src=src(value.src_loc))
else:
wire_next = None
self.wires[value] = (wire_curr, wire_next)
if self.is_lhs:
if wire_next is None:
raise ValueError("Cannot return lhs for non-driven signal {}".format(repr(value)))
return wire_next
else:
return wire_curr
def on_ClockSignal(self, value):
raise NotImplementedError # :nocov:
def on_ResetSignal(self, value):
raise NotImplementedError # :nocov:
wire_curr, wire_next = self.s.resolve(value)
return wire_curr
def on_Operator_unary(self, value):
arg, = value.operands
arg_bits, arg_sign = arg.shape()
res_bits, res_sign = value.shape()
res = self.rtlil.wire(width=res_bits)
self.rtlil.cell(self.operator_map[(1, value.op)], ports={
res = self.s.rtlil.wire(width=res_bits)
self.s.rtlil.cell(self.operator_map[(1, value.op)], ports={
"\\A": self(arg),
"\\Y": res,
}, params={
@ -327,8 +338,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
value_bits, value_sign = value.shape()
if new_bits > value_bits:
res = self.rtlil.wire(width=new_bits)
self.rtlil.cell("$pos", ports={
res = self.s.rtlil.wire(width=new_bits)
self.s.rtlil.cell("$pos", ports={
"\\A": self(value),
"\\Y": res,
}, params={
@ -353,8 +364,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign)
rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign)
res_bits, res_sign = value.shape()
res = self.rtlil.wire(width=res_bits)
self.rtlil.cell(self.operator_map[(2, value.op)], ports={
res = self.s.rtlil.wire(width=res_bits)
self.s.rtlil.cell(self.operator_map[(2, value.op)], ports={
"\\A": lhs_wire,
"\\B": rhs_wire,
"\\Y": res,
@ -375,8 +386,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
lhs_bits = rhs_bits = res_bits = max(lhs_bits, rhs_bits, res_bits)
lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign)
rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign)
res = self.rtlil.wire(width=res_bits)
self.rtlil.cell("$mux", ports={
res = self.s.rtlil.wire(width=res_bits)
self.s.rtlil.cell("$mux", ports={
"\\A": lhs_wire,
"\\B": rhs_wire,
"\\S": self(sel),
@ -395,20 +406,11 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
assert value.op == "m"
return self.on_Operator_mux(value)
else:
raise TypeError
def on_Slice(self, value):
if value.end == value.start + 1:
return "{} [{}]".format(self(value.value), value.start)
else:
return "{} [{}:{}]".format(self(value.value), value.end - 1, value.start)
raise TypeError # :nocov:
def on_Part(self, value):
raise NotImplementedError
def on_Cat(self, value):
return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.operands])))
def on_Repl(self, value):
return "{{ {} }}".format(" ".join(self(value.value) for _ in range(value.count)))
@ -416,27 +418,52 @@ class _ValueTransformer(xfrm.AbstractValueTransformer):
raise NotImplementedError
class _LHSValueCompiler(_ValueCompiler):
def on_Const(self, value):
raise TypeError # :nocov:
def on_Operator(self, value):
raise TypeError # :nocov:
def on_Signal(self, value):
wire_curr, wire_next = self.s.resolve(value)
if wire_next is None:
raise ValueError("Cannot return lhs for non-driven signal {}".format(repr(value)))
return wire_next
def on_Part(self, value):
raise NotImplementedError
def on_Repl(self, value):
raise TypeError # :nocov:
def on_ArrayProxy(self, value):
raise NotImplementedError
def convert_fragment(builder, fragment, name, top):
with builder.module(name or "anonymous", attrs={"top": 1} if top else {}) as module:
xformer = _ValueTransformer(module)
compiler_state = _ValueCompilerState(module)
rhs_compiler = _RHSValueCompiler(compiler_state)
lhs_compiler = _LHSValueCompiler(compiler_state)
# Register all signals driven in the current fragment. This must be done first, as it
# affects further codegen; e.g. whether sig$next signals will be generated and used.
for domain, signal in fragment.iter_drivers():
xformer.add_driven(signal, sync=domain is not None)
compiler_state.add_driven(signal, sync=domain is not None)
# Transform all signals used as ports in the current fragment eagerly and outside of
# any hierarchy, to make sure they get sensible (non-prefixed) names.
for signal in fragment.ports:
xformer.add_port(signal, fragment.ports[signal])
xformer(signal)
compiler_state.add_port(signal, fragment.ports[signal])
rhs_compiler(signal)
# Transform all clocks clocks and resets eagerly and outside of any hierarchy, to make
# sure they get sensible (non-prefixed) names. This does not affect semantics.
for domain, _ in fragment.iter_sync():
cd = fragment.domains[domain]
xformer(cd.clk)
xformer(cd.rst)
rhs_compiler(cd.clk)
rhs_compiler(cd.rst)
# Transform all subfragments to their respective cells. Transforming signals connected
# to their ports into wires eagerly makes sure they get sensible (prefixed with submodule
@ -444,9 +471,9 @@ def convert_fragment(builder, fragment, name, top):
for subfragment, sub_name in fragment.subfragments:
sub_name, sub_port_map = \
convert_fragment(builder, subfragment, top=False, name=sub_name)
with xformer.hierarchy(sub_name):
with compiler_state.hierarchy(sub_name):
module.cell(sub_name, name=sub_name, ports={
p: xformer(s) for p, s in sub_port_map.items()
p: rhs_compiler(s) for p, s in sub_port_map.items()
})
with module.process() as process:
@ -455,11 +482,10 @@ def convert_fragment(builder, fragment, name, top):
# For every signal in sync domains, assign \sig$next to the current value (\sig).
for domain, signal in fragment.iter_drivers():
if domain is None:
prev_value = xformer(ast.Const(signal.reset, signal.nbits))
prev_value = ast.Const(signal.reset, signal.nbits)
else:
prev_value = xformer(signal)
with xformer.lhs():
case.assign(xformer(signal), prev_value)
prev_value = signal
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
# Convert statements into decision trees.
def _convert_stmts(case, stmts):
@ -468,17 +494,15 @@ def convert_fragment(builder, fragment, name, top):
lhs_bits, lhs_sign = stmt.lhs.shape()
rhs_bits, rhs_sign = stmt.rhs.shape()
if lhs_bits == rhs_bits:
rhs_sigspec = xformer(stmt.rhs)
rhs_sigspec = rhs_compiler(stmt.rhs)
else:
# In RTLIL, LHS and RHS of assignment must have exactly same width.
rhs_sigspec = xformer.match_shape(
rhs_sigspec = rhs_compiler.match_shape(
stmt.rhs, lhs_bits, rhs_sign)
with xformer.lhs():
lhs_sigspec = xformer(stmt.lhs)
case.assign(lhs_sigspec, rhs_sigspec)
case.assign(lhs_compiler(stmt.lhs), rhs_sigspec)
elif isinstance(stmt, ast.Switch):
with case.switch(xformer(stmt.test)) as switch:
with case.switch(rhs_compiler(stmt.test)) as switch:
for value, nested_stmts in stmt.cases.items():
with switch.case(value) as nested_case:
_convert_stmts(nested_case, nested_stmts)
@ -489,12 +513,11 @@ def convert_fragment(builder, fragment, name, top):
_convert_stmts(case, fragment.statements)
# For every signal in the sync domain, assign \sig's initial value (which will end up
# as the \init reg attribute) to the reset value. Note that this assigns \sig,
# not \sig$next.
# as the \init reg attribute) to the reset value.
with process.sync("init") as sync:
for domain, signal in fragment.iter_sync():
sync.update(xformer(signal),
xformer(ast.Const(signal.reset, signal.nbits)))
wire_curr, wire_next = compiler_state.resolve(signal)
sync.update(wire_curr, rhs_compiler(ast.Const(signal.reset, signal.nbits)))
# For every signal in every domain, assign \sig to \sig$next. The sensitivity list,
# however, differs between domains: for comb domains, it is `always`, for sync domains
@ -506,23 +529,22 @@ def convert_fragment(builder, fragment, name, top):
triggers.append(("always",))
else:
cd = fragment.domains[domain]
triggers.append(("posedge", xformer(cd.clk)))
triggers.append(("posedge", compiler_state.resolve_curr(cd.clk)))
if cd.async_reset:
triggers.append(("posedge", xformer(cd.rst)))
triggers.append(("posedge", compiler_state.resolve_curr(cd.rst)))
for trigger in triggers:
with process.sync(*trigger) as sync:
for signal in signals:
lhs_sigspec = xformer(signal)
with xformer.lhs():
sync.update(lhs_sigspec, xformer(signal))
wire_curr, wire_next = compiler_state.resolve(signal)
sync.update(wire_curr, wire_next)
# Finally, collect the names we've given to our ports in RTLIL, and correlate these with
# the signals represented by these ports. If we are a submodule, this will be necessary
# to create a cell for us in the parent module.
port_map = OrderedDict()
for signal in fragment.ports:
port_map[xformer(signal)] = signal
port_map[compiler_state.resolve_curr(signal)] = signal
return module.name, port_map