back.pysim: use bare ints for signal values (-5% runtime).

This commit is contained in:
whitequark 2018-12-14 03:05:57 +00:00
parent 55e729f68a
commit 3e59d857e1
2 changed files with 42 additions and 35 deletions

View file

@ -21,20 +21,20 @@ class _State:
return self.curr[signal]
def set_curr(self, signal, value):
assert isinstance(value, Const)
if self.curr[signal].value != value.value:
assert isinstance(value, int)
if self.curr[signal] != value:
self.curr_dirty.add(signal)
self.curr[signal] = value
def set_next(self, signal, value):
assert isinstance(value, Const)
if self.next[signal].value != value.value:
assert isinstance(value, int)
if self.next[signal] != value:
self.next_dirty.add(signal)
self.next[signal] = value
def commit(self, signal):
old_value = self.curr[signal]
if self.curr[signal].value != self.next[signal].value:
if self.curr[signal] != self.next[signal]:
self.next_dirty.remove(signal)
self.curr_dirty.add(signal)
self.curr[signal] = self.next[signal]
@ -47,12 +47,15 @@ class _State:
yield signal, self.curr[signal], self.next[signal]
normalize = Const.normalize
class _RHSValueCompiler(ValueTransformer):
def __init__(self, sensitivity):
self.sensitivity = sensitivity
def on_Const(self, value):
return lambda state: value
return lambda state: value.value
def on_Signal(self, value):
self.sensitivity.add(value)
@ -69,28 +72,27 @@ class _RHSValueCompiler(ValueTransformer):
if len(value.operands) == 1:
arg, = map(self, value.operands)
if value.op == "~":
return lambda state: Const(~arg(state).value, shape)
elif value.op == "-":
return lambda state: Const(-arg(state).value, shape)
return lambda state: normalize(~arg(state), shape)
if value.op == "-":
return lambda state: normalize(-arg(state), shape)
elif len(value.operands) == 2:
lhs, rhs = map(self, value.operands)
if value.op == "+":
return lambda state: Const(lhs(state).value + rhs(state).value, shape)
return lambda state: normalize(lhs(state) + rhs(state), shape)
if value.op == "-":
return lambda state: Const(lhs(state).value - rhs(state).value, shape)
return lambda state: normalize(lhs(state) - rhs(state), shape)
if value.op == "&":
return lambda state: Const(lhs(state).value & rhs(state).value, shape)
return lambda state: normalize(lhs(state) & rhs(state), shape)
if value.op == "|":
return lambda state: Const(lhs(state).value | rhs(state).value, shape)
return lambda state: normalize(lhs(state) | rhs(state), shape)
if value.op == "^":
return lambda state: Const(lhs(state).value ^ rhs(state).value, shape)
elif value.op == "==":
lhs, rhs = map(self, value.operands)
return lambda state: Const(lhs(state).value == rhs(state).value, shape)
return lambda state: normalize(lhs(state) ^ rhs(state), shape)
if value.op == "==":
return lambda state: normalize(lhs(state) == rhs(state), shape)
elif len(value.operands) == 3:
if value.op == "m":
sel, val1, val0 = map(self, value.operands)
return lambda state: val1(state) if sel(state).value else val0(state)
return lambda state: val1(state) if sel(state) else val0(state)
raise NotImplementedError("Operator '{}' not implemented".format(value.op))
def on_Slice(self, value):
@ -98,7 +100,7 @@ class _RHSValueCompiler(ValueTransformer):
arg = self(value.value)
shift = value.start
mask = (1 << (value.end - value.start)) - 1
return lambda state: Const((arg(state).value >> shift) & mask, shape)
return lambda state: normalize((arg(state) >> shift) & mask, shape)
def on_Part(self, value):
raise NotImplementedError
@ -113,8 +115,8 @@ class _RHSValueCompiler(ValueTransformer):
def eval(state):
result = 0
for offset, mask, opnd in parts:
result |= (opnd(state).value & mask) << offset
return Const(result, shape)
result |= (opnd(state) & mask) << offset
return normalize(result, shape)
return eval
def on_Repl(self, value):
@ -127,8 +129,8 @@ class _RHSValueCompiler(ValueTransformer):
result = 0
for _ in range(count):
result <<= offset
result |= opnd(state).value
return Const(result, shape)
result |= opnd(state)
return normalize(result, shape)
return eval
@ -147,7 +149,7 @@ class _StatementCompiler(StatementTransformer):
lhs = self.lhs_compiler(stmt.lhs)
rhs = self.rhs_compiler(stmt.rhs)
def run(state):
lhs(state, Const(rhs(state).value, shape))
lhs(state, normalize(rhs(state), shape))
return run
def on_Switch(self, stmt):
@ -164,7 +166,7 @@ class _StatementCompiler(StatementTransformer):
cases.append((lambda test: test & mask == value,
self.on_statements(stmts)))
def run(state):
test_value = test(state).value
test_value = test(state)
for check, body in cases:
if check(test_value):
body(state)
@ -255,7 +257,7 @@ class Simulator:
self._signals.add(signal)
self._state.curr[signal] = self._state.next[signal] = \
Const(signal.reset, signal.shape())
normalize(signal.reset, signal.shape())
self._state.curr_dirty.add(signal)
if signal not in self._vcd_signals:
@ -295,7 +297,7 @@ class Simulator:
def _commit_signal(self, signal):
old, new = self._state.commit(signal)
if old.value == 0 and new.value == 1 and signal in self._domain_triggers:
if (old, new) == (0, 1) and signal in self._domain_triggers:
domain = self._domain_triggers[signal]
for sync_signal in self._state.next_dirty:
if sync_signal in self._domain_signals[domain]:
@ -303,7 +305,7 @@ class Simulator:
if self._vcd_writer:
for vcd_signal in self._vcd_signals[signal]:
self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new.value)
self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new)
def _handle_event(self):
handlers = set()
@ -320,7 +322,7 @@ class Simulator:
self._commit_signal(signal)
def _force_signal(self, signal, value):
assert signal in self._comb_signals or signal in self._user_signals
assert signal in self._user_signals
self._state.set_next(signal, value)
self._commit_signal(signal)
@ -339,7 +341,7 @@ class Simulator:
elif isinstance(stmt, Assign):
assert isinstance(stmt.lhs, Signal)
assert isinstance(stmt.rhs, Const)
self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape()))
self._force_signal(stmt.lhs, normalize(stmt.rhs.value, stmt.lhs.shape()))
else:
raise TypeError("Received unsupported statement '{!r}' from process {}"
.format(stmt, proc))

View file

@ -218,6 +218,15 @@ class Const(Value):
"""
src_loc = None
@staticmethod
def normalize(value, shape):
nbits, signed = shape
mask = (1 << nbits) - 1
value &= mask
if signed and value >> (nbits - 1):
value |= ~mask
return value
def __init__(self, value, shape=None):
self.value = int(value)
if shape is None:
@ -227,11 +236,7 @@ class Const(Value):
self.nbits, self.signed = shape
if not isinstance(self.nbits, int) or self.nbits < 0:
raise TypeError("Width must be a positive integer")
mask = (1 << self.nbits) - 1
self.value &= mask
if self.signed and self.value >> (self.nbits - 1):
self.value |= ~mask
self.value = self.normalize(self.value, shape)
def shape(self):
return self.nbits, self.signed