sim._pyrtl: optimize uses of reflexive operators.
When a literal is used on the left-hand side of a numeric operator, Python is able to constant-fold some expressions:

    >>> dis.dis(lambda x: 0 + 0 + x)
      1           0 LOAD_CONST               1 (0)
                  2 LOAD_FAST                0 (x)
                  4 BINARY_ADD
                  6 RETURN_VALUE

If a literal is used on the right-hand side such that the left-hand side is variable, this doesn't happen:

    >>> dis.dis(lambda x: x + 0 + 0)
      1           0 LOAD_FAST                0 (x)
                  2 LOAD_CONST               1 (0)
                  4 BINARY_ADD
                  6 LOAD_CONST               1 (0)
                  8 BINARY_ADD
                 10 RETURN_VALUE

PyRTL generates fairly redundant code due to the pervasive masking, and because of that, transforming expressions into the former form, where possible, improves runtime by about 10% on the Minerva SRAM SoC.
This commit is contained in:
parent
38b75ba4bc
commit
8c6c3643cd
|
@ -103,7 +103,7 @@ class _RHSValueCompiler(_ValueCompiler):
|
|||
def on_Operator(self, value):
|
||||
def mask(value):
|
||||
value_mask = (1 << len(value)) - 1
|
||||
return f"({self(value)} & {value_mask})"
|
||||
return f"({value_mask} & {self(value)})"
|
||||
|
||||
def sign(value):
|
||||
if value.shape().signed:
|
||||
|
@ -120,9 +120,9 @@ class _RHSValueCompiler(_ValueCompiler):
|
|||
if value.operator == "b":
|
||||
return f"bool({mask(arg)})"
|
||||
if value.operator == "r|":
|
||||
return f"({mask(arg)} != 0)"
|
||||
return f"(0 != {mask(arg)})"
|
||||
if value.operator == "r&":
|
||||
return f"({mask(arg)} == {(1 << len(arg)) - 1})"
|
||||
return f"({(1 << len(arg)) - 1} == {mask(arg)})"
|
||||
if value.operator == "r^":
|
||||
# Believe it or not, this is the fastest way to compute a sideways XOR in Python.
|
||||
return f"(format({mask(arg)}, 'b').count('1') % 2)"
|
||||
|
@ -172,20 +172,20 @@ class _RHSValueCompiler(_ValueCompiler):
|
|||
raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:
|
||||
|
||||
def on_Slice(self, value):
|
||||
return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"
|
||||
return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"
|
||||
|
||||
def on_Part(self, value):
|
||||
offset_mask = (1 << len(value.offset)) - 1
|
||||
offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
|
||||
return f"({self(value.value)} >> {offset} & " \
|
||||
f"{(1 << value.width) - 1})"
|
||||
offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
|
||||
return f"({(1 << value.width) - 1} & " \
|
||||
f"{self(value.value)} >> {offset})"
|
||||
|
||||
def on_Cat(self, value):
|
||||
gen_parts = []
|
||||
offset = 0
|
||||
for part in value.parts:
|
||||
part_mask = (1 << len(part)) - 1
|
||||
gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
|
||||
gen_parts.append(f"(({part_mask} & {self(part)}) << {offset})")
|
||||
offset += len(part)
|
||||
if gen_parts:
|
||||
return f"({' | '.join(gen_parts)})"
|
||||
|
@ -193,7 +193,7 @@ class _RHSValueCompiler(_ValueCompiler):
|
|||
|
||||
def on_Repl(self, value):
|
||||
part_mask = (1 << len(value.value)) - 1
|
||||
gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
|
||||
gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
|
||||
gen_parts = []
|
||||
offset = 0
|
||||
for _ in range(value.count):
|
||||
|
@ -205,15 +205,15 @@ class _RHSValueCompiler(_ValueCompiler):
|
|||
|
||||
def on_ArrayProxy(self, value):
|
||||
index_mask = (1 << len(value.index)) - 1
|
||||
gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
|
||||
gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
|
||||
gen_value = self.emitter.gen_var("rhs_proxy")
|
||||
if value.elems:
|
||||
gen_elems = []
|
||||
for index, elem in enumerate(value.elems):
|
||||
if index == 0:
|
||||
self.emitter.append(f"if {gen_index} == {index}:")
|
||||
self.emitter.append(f"if {index} == {gen_index}:")
|
||||
else:
|
||||
self.emitter.append(f"elif {gen_index} == {index}:")
|
||||
self.emitter.append(f"elif {index} == {gen_index}:")
|
||||
with self.emitter.indent():
|
||||
self.emitter.append(f"{gen_value} = {self(elem)}")
|
||||
self.emitter.append(f"else:")
|
||||
|
@ -253,9 +253,9 @@ class _LHSValueCompiler(_ValueCompiler):
|
|||
def gen(arg):
|
||||
value_mask = (1 << len(value)) - 1
|
||||
if value.shape().signed:
|
||||
value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
|
||||
value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
|
||||
else: # unsigned
|
||||
value_sign = f"{arg} & {value_mask}"
|
||||
value_sign = f"{value_mask} & {arg}"
|
||||
self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
|
||||
return gen
|
||||
|
||||
|
@ -267,17 +267,17 @@ class _LHSValueCompiler(_ValueCompiler):
|
|||
width_mask = (1 << (value.stop - value.start)) - 1
|
||||
self(value.value)(f"({self.lrhs(value.value)} & " \
|
||||
f"{~(width_mask << value.start)} | " \
|
||||
f"(({arg} & {width_mask}) << {value.start}))")
|
||||
f"(({width_mask} & {arg}) << {value.start}))")
|
||||
return gen
|
||||
|
||||
def on_Part(self, value):
|
||||
def gen(arg):
|
||||
width_mask = (1 << value.width) - 1
|
||||
offset_mask = (1 << len(value.offset)) - 1
|
||||
offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
|
||||
offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
|
||||
self(value.value)(f"({self.lrhs(value.value)} & " \
|
||||
f"~({width_mask} << {offset}) | " \
|
||||
f"(({arg} & {width_mask}) << {offset}))")
|
||||
f"(({width_mask} & {arg}) << {offset}))")
|
||||
return gen
|
||||
|
||||
def on_Cat(self, value):
|
||||
|
@ -287,7 +287,7 @@ class _LHSValueCompiler(_ValueCompiler):
|
|||
offset = 0
|
||||
for part in value.parts:
|
||||
part_mask = (1 << len(part)) - 1
|
||||
self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
|
||||
self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
|
||||
offset += len(part)
|
||||
return gen
|
||||
|
||||
|
@ -302,9 +302,9 @@ class _LHSValueCompiler(_ValueCompiler):
|
|||
gen_elems = []
|
||||
for index, elem in enumerate(value.elems):
|
||||
if index == 0:
|
||||
self.emitter.append(f"if {gen_index} == {index}:")
|
||||
self.emitter.append(f"if {index} == {gen_index}:")
|
||||
else:
|
||||
self.emitter.append(f"elif {gen_index} == {index}:")
|
||||
self.emitter.append(f"elif {index} == {gen_index}:")
|
||||
with self.emitter.indent():
|
||||
self(elem)(arg)
|
||||
self.emitter.append(f"else:")
|
||||
|
@ -332,7 +332,7 @@ class _StatementCompiler(StatementVisitor, _Compiler):
|
|||
|
||||
def on_Switch(self, stmt):
|
||||
gen_test = self.emitter.def_var("test",
|
||||
f"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}")
|
||||
f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
|
||||
for index, (patterns, stmts) in enumerate(stmt.cases.items()):
|
||||
gen_checks = []
|
||||
if not patterns:
|
||||
|
@ -342,10 +342,10 @@ class _StatementCompiler(StatementVisitor, _Compiler):
|
|||
if "-" in pattern:
|
||||
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
|
||||
value = int("".join("0" if b == "-" else b for b in pattern), 2)
|
||||
gen_checks.append(f"({gen_test} & {mask}) == {value}")
|
||||
gen_checks.append(f"{value} == ({mask} & {gen_test})")
|
||||
else:
|
||||
value = int(pattern, 2)
|
||||
gen_checks.append(f"{gen_test} == {value}")
|
||||
gen_checks.append(f"{value} == {gen_test}")
|
||||
if index == 0:
|
||||
self.emitter.append(f"if {' or '.join(gen_checks)}:")
|
||||
else:
|
||||
|
|
Loading…
Reference in a new issue