sim._pyrtl: optimize uses of reflexive operators.
When a literal is used on the left-hand side of a numeric operator, Python is able to constant-fold some expressions:

    >>> dis.dis(lambda x: 0 + 0 + x)
      1           0 LOAD_CONST               1 (0)
                  2 LOAD_FAST                0 (x)
                  4 BINARY_ADD
                  6 RETURN_VALUE

If a literal is used on the right-hand side such that the left-hand side is variable, this doesn't happen:

    >>> dis.dis(lambda x: x + 0 + 0)
      1           0 LOAD_FAST                0 (x)
                  2 LOAD_CONST               1 (0)
                  4 BINARY_ADD
                  6 LOAD_CONST               1 (0)
                  8 BINARY_ADD
                 10 RETURN_VALUE

PyRTL generates fairly redundant code due to the pervasive masking, and because of that, transforming expressions into the former form, where possible, improves runtime by about 10% on Minerva SRAM SoC.
This commit is contained in:
parent
38b75ba4bc
commit
8c6c3643cd
|
@ -103,7 +103,7 @@ class _RHSValueCompiler(_ValueCompiler):
|
||||||
def on_Operator(self, value):
|
def on_Operator(self, value):
|
||||||
def mask(value):
|
def mask(value):
|
||||||
value_mask = (1 << len(value)) - 1
|
value_mask = (1 << len(value)) - 1
|
||||||
return f"({self(value)} & {value_mask})"
|
return f"({value_mask} & {self(value)})"
|
||||||
|
|
||||||
def sign(value):
|
def sign(value):
|
||||||
if value.shape().signed:
|
if value.shape().signed:
|
||||||
|
@ -120,9 +120,9 @@ class _RHSValueCompiler(_ValueCompiler):
|
||||||
if value.operator == "b":
|
if value.operator == "b":
|
||||||
return f"bool({mask(arg)})"
|
return f"bool({mask(arg)})"
|
||||||
if value.operator == "r|":
|
if value.operator == "r|":
|
||||||
return f"({mask(arg)} != 0)"
|
return f"(0 != {mask(arg)})"
|
||||||
if value.operator == "r&":
|
if value.operator == "r&":
|
||||||
return f"({mask(arg)} == {(1 << len(arg)) - 1})"
|
return f"({(1 << len(arg)) - 1} == {mask(arg)})"
|
||||||
if value.operator == "r^":
|
if value.operator == "r^":
|
||||||
# Believe it or not, this is the fastest way to compute a sideways XOR in Python.
|
# Believe it or not, this is the fastest way to compute a sideways XOR in Python.
|
||||||
return f"(format({mask(arg)}, 'b').count('1') % 2)"
|
return f"(format({mask(arg)}, 'b').count('1') % 2)"
|
||||||
|
@ -172,20 +172,20 @@ class _RHSValueCompiler(_ValueCompiler):
|
||||||
raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:
|
raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:
|
||||||
|
|
||||||
def on_Slice(self, value):
|
def on_Slice(self, value):
|
||||||
return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"
|
return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"
|
||||||
|
|
||||||
def on_Part(self, value):
|
def on_Part(self, value):
|
||||||
offset_mask = (1 << len(value.offset)) - 1
|
offset_mask = (1 << len(value.offset)) - 1
|
||||||
offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
|
offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
|
||||||
return f"({self(value.value)} >> {offset} & " \
|
return f"({(1 << value.width) - 1} & " \
|
||||||
f"{(1 << value.width) - 1})"
|
f"{self(value.value)} >> {offset})"
|
||||||
|
|
||||||
def on_Cat(self, value):
|
def on_Cat(self, value):
|
||||||
gen_parts = []
|
gen_parts = []
|
||||||
offset = 0
|
offset = 0
|
||||||
for part in value.parts:
|
for part in value.parts:
|
||||||
part_mask = (1 << len(part)) - 1
|
part_mask = (1 << len(part)) - 1
|
||||||
gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
|
gen_parts.append(f"(({part_mask} & {self(part)}) << {offset})")
|
||||||
offset += len(part)
|
offset += len(part)
|
||||||
if gen_parts:
|
if gen_parts:
|
||||||
return f"({' | '.join(gen_parts)})"
|
return f"({' | '.join(gen_parts)})"
|
||||||
|
@ -193,7 +193,7 @@ class _RHSValueCompiler(_ValueCompiler):
|
||||||
|
|
||||||
def on_Repl(self, value):
|
def on_Repl(self, value):
|
||||||
part_mask = (1 << len(value.value)) - 1
|
part_mask = (1 << len(value.value)) - 1
|
||||||
gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
|
gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
|
||||||
gen_parts = []
|
gen_parts = []
|
||||||
offset = 0
|
offset = 0
|
||||||
for _ in range(value.count):
|
for _ in range(value.count):
|
||||||
|
@ -205,15 +205,15 @@ class _RHSValueCompiler(_ValueCompiler):
|
||||||
|
|
||||||
def on_ArrayProxy(self, value):
|
def on_ArrayProxy(self, value):
|
||||||
index_mask = (1 << len(value.index)) - 1
|
index_mask = (1 << len(value.index)) - 1
|
||||||
gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
|
gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
|
||||||
gen_value = self.emitter.gen_var("rhs_proxy")
|
gen_value = self.emitter.gen_var("rhs_proxy")
|
||||||
if value.elems:
|
if value.elems:
|
||||||
gen_elems = []
|
gen_elems = []
|
||||||
for index, elem in enumerate(value.elems):
|
for index, elem in enumerate(value.elems):
|
||||||
if index == 0:
|
if index == 0:
|
||||||
self.emitter.append(f"if {gen_index} == {index}:")
|
self.emitter.append(f"if {index} == {gen_index}:")
|
||||||
else:
|
else:
|
||||||
self.emitter.append(f"elif {gen_index} == {index}:")
|
self.emitter.append(f"elif {index} == {gen_index}:")
|
||||||
with self.emitter.indent():
|
with self.emitter.indent():
|
||||||
self.emitter.append(f"{gen_value} = {self(elem)}")
|
self.emitter.append(f"{gen_value} = {self(elem)}")
|
||||||
self.emitter.append(f"else:")
|
self.emitter.append(f"else:")
|
||||||
|
@ -253,9 +253,9 @@ class _LHSValueCompiler(_ValueCompiler):
|
||||||
def gen(arg):
|
def gen(arg):
|
||||||
value_mask = (1 << len(value)) - 1
|
value_mask = (1 << len(value)) - 1
|
||||||
if value.shape().signed:
|
if value.shape().signed:
|
||||||
value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
|
value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
|
||||||
else: # unsigned
|
else: # unsigned
|
||||||
value_sign = f"{arg} & {value_mask}"
|
value_sign = f"{value_mask} & {arg}"
|
||||||
self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
|
self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
|
||||||
return gen
|
return gen
|
||||||
|
|
||||||
|
@ -267,17 +267,17 @@ class _LHSValueCompiler(_ValueCompiler):
|
||||||
width_mask = (1 << (value.stop - value.start)) - 1
|
width_mask = (1 << (value.stop - value.start)) - 1
|
||||||
self(value.value)(f"({self.lrhs(value.value)} & " \
|
self(value.value)(f"({self.lrhs(value.value)} & " \
|
||||||
f"{~(width_mask << value.start)} | " \
|
f"{~(width_mask << value.start)} | " \
|
||||||
f"(({arg} & {width_mask}) << {value.start}))")
|
f"(({width_mask} & {arg}) << {value.start}))")
|
||||||
return gen
|
return gen
|
||||||
|
|
||||||
def on_Part(self, value):
|
def on_Part(self, value):
|
||||||
def gen(arg):
|
def gen(arg):
|
||||||
width_mask = (1 << value.width) - 1
|
width_mask = (1 << value.width) - 1
|
||||||
offset_mask = (1 << len(value.offset)) - 1
|
offset_mask = (1 << len(value.offset)) - 1
|
||||||
offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
|
offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
|
||||||
self(value.value)(f"({self.lrhs(value.value)} & " \
|
self(value.value)(f"({self.lrhs(value.value)} & " \
|
||||||
f"~({width_mask} << {offset}) | " \
|
f"~({width_mask} << {offset}) | " \
|
||||||
f"(({arg} & {width_mask}) << {offset}))")
|
f"(({width_mask} & {arg}) << {offset}))")
|
||||||
return gen
|
return gen
|
||||||
|
|
||||||
def on_Cat(self, value):
|
def on_Cat(self, value):
|
||||||
|
@ -287,7 +287,7 @@ class _LHSValueCompiler(_ValueCompiler):
|
||||||
offset = 0
|
offset = 0
|
||||||
for part in value.parts:
|
for part in value.parts:
|
||||||
part_mask = (1 << len(part)) - 1
|
part_mask = (1 << len(part)) - 1
|
||||||
self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
|
self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
|
||||||
offset += len(part)
|
offset += len(part)
|
||||||
return gen
|
return gen
|
||||||
|
|
||||||
|
@ -302,9 +302,9 @@ class _LHSValueCompiler(_ValueCompiler):
|
||||||
gen_elems = []
|
gen_elems = []
|
||||||
for index, elem in enumerate(value.elems):
|
for index, elem in enumerate(value.elems):
|
||||||
if index == 0:
|
if index == 0:
|
||||||
self.emitter.append(f"if {gen_index} == {index}:")
|
self.emitter.append(f"if {index} == {gen_index}:")
|
||||||
else:
|
else:
|
||||||
self.emitter.append(f"elif {gen_index} == {index}:")
|
self.emitter.append(f"elif {index} == {gen_index}:")
|
||||||
with self.emitter.indent():
|
with self.emitter.indent():
|
||||||
self(elem)(arg)
|
self(elem)(arg)
|
||||||
self.emitter.append(f"else:")
|
self.emitter.append(f"else:")
|
||||||
|
@ -332,7 +332,7 @@ class _StatementCompiler(StatementVisitor, _Compiler):
|
||||||
|
|
||||||
def on_Switch(self, stmt):
|
def on_Switch(self, stmt):
|
||||||
gen_test = self.emitter.def_var("test",
|
gen_test = self.emitter.def_var("test",
|
||||||
f"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}")
|
f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
|
||||||
for index, (patterns, stmts) in enumerate(stmt.cases.items()):
|
for index, (patterns, stmts) in enumerate(stmt.cases.items()):
|
||||||
gen_checks = []
|
gen_checks = []
|
||||||
if not patterns:
|
if not patterns:
|
||||||
|
@ -342,10 +342,10 @@ class _StatementCompiler(StatementVisitor, _Compiler):
|
||||||
if "-" in pattern:
|
if "-" in pattern:
|
||||||
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
|
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
|
||||||
value = int("".join("0" if b == "-" else b for b in pattern), 2)
|
value = int("".join("0" if b == "-" else b for b in pattern), 2)
|
||||||
gen_checks.append(f"({gen_test} & {mask}) == {value}")
|
gen_checks.append(f"{value} == ({mask} & {gen_test})")
|
||||||
else:
|
else:
|
||||||
value = int(pattern, 2)
|
value = int(pattern, 2)
|
||||||
gen_checks.append(f"{gen_test} == {value}")
|
gen_checks.append(f"{value} == {gen_test}")
|
||||||
if index == 0:
|
if index == 0:
|
||||||
self.emitter.append(f"if {' or '.join(gen_checks)}:")
|
self.emitter.append(f"if {' or '.join(gen_checks)}:")
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in a new issue