back.pysim: use bare ints for signal values (-5% runtime).
This commit is contained in:
parent
55e729f68a
commit
3e59d857e1
|
@ -21,20 +21,20 @@ class _State:
|
||||||
return self.curr[signal]
|
return self.curr[signal]
|
||||||
|
|
||||||
def set_curr(self, signal, value):
|
def set_curr(self, signal, value):
|
||||||
assert isinstance(value, Const)
|
assert isinstance(value, int)
|
||||||
if self.curr[signal].value != value.value:
|
if self.curr[signal] != value:
|
||||||
self.curr_dirty.add(signal)
|
self.curr_dirty.add(signal)
|
||||||
self.curr[signal] = value
|
self.curr[signal] = value
|
||||||
|
|
||||||
def set_next(self, signal, value):
|
def set_next(self, signal, value):
|
||||||
assert isinstance(value, Const)
|
assert isinstance(value, int)
|
||||||
if self.next[signal].value != value.value:
|
if self.next[signal] != value:
|
||||||
self.next_dirty.add(signal)
|
self.next_dirty.add(signal)
|
||||||
self.next[signal] = value
|
self.next[signal] = value
|
||||||
|
|
||||||
def commit(self, signal):
|
def commit(self, signal):
|
||||||
old_value = self.curr[signal]
|
old_value = self.curr[signal]
|
||||||
if self.curr[signal].value != self.next[signal].value:
|
if self.curr[signal] != self.next[signal]:
|
||||||
self.next_dirty.remove(signal)
|
self.next_dirty.remove(signal)
|
||||||
self.curr_dirty.add(signal)
|
self.curr_dirty.add(signal)
|
||||||
self.curr[signal] = self.next[signal]
|
self.curr[signal] = self.next[signal]
|
||||||
|
@ -47,12 +47,15 @@ class _State:
|
||||||
yield signal, self.curr[signal], self.next[signal]
|
yield signal, self.curr[signal], self.next[signal]
|
||||||
|
|
||||||
|
|
||||||
|
normalize = Const.normalize
|
||||||
|
|
||||||
|
|
||||||
class _RHSValueCompiler(ValueTransformer):
|
class _RHSValueCompiler(ValueTransformer):
|
||||||
def __init__(self, sensitivity):
|
def __init__(self, sensitivity):
|
||||||
self.sensitivity = sensitivity
|
self.sensitivity = sensitivity
|
||||||
|
|
||||||
def on_Const(self, value):
|
def on_Const(self, value):
|
||||||
return lambda state: value
|
return lambda state: value.value
|
||||||
|
|
||||||
def on_Signal(self, value):
|
def on_Signal(self, value):
|
||||||
self.sensitivity.add(value)
|
self.sensitivity.add(value)
|
||||||
|
@ -69,28 +72,27 @@ class _RHSValueCompiler(ValueTransformer):
|
||||||
if len(value.operands) == 1:
|
if len(value.operands) == 1:
|
||||||
arg, = map(self, value.operands)
|
arg, = map(self, value.operands)
|
||||||
if value.op == "~":
|
if value.op == "~":
|
||||||
return lambda state: Const(~arg(state).value, shape)
|
return lambda state: normalize(~arg(state), shape)
|
||||||
elif value.op == "-":
|
if value.op == "-":
|
||||||
return lambda state: Const(-arg(state).value, shape)
|
return lambda state: normalize(-arg(state), shape)
|
||||||
elif len(value.operands) == 2:
|
elif len(value.operands) == 2:
|
||||||
lhs, rhs = map(self, value.operands)
|
lhs, rhs = map(self, value.operands)
|
||||||
if value.op == "+":
|
if value.op == "+":
|
||||||
return lambda state: Const(lhs(state).value + rhs(state).value, shape)
|
return lambda state: normalize(lhs(state) + rhs(state), shape)
|
||||||
if value.op == "-":
|
if value.op == "-":
|
||||||
return lambda state: Const(lhs(state).value - rhs(state).value, shape)
|
return lambda state: normalize(lhs(state) - rhs(state), shape)
|
||||||
if value.op == "&":
|
if value.op == "&":
|
||||||
return lambda state: Const(lhs(state).value & rhs(state).value, shape)
|
return lambda state: normalize(lhs(state) & rhs(state), shape)
|
||||||
if value.op == "|":
|
if value.op == "|":
|
||||||
return lambda state: Const(lhs(state).value | rhs(state).value, shape)
|
return lambda state: normalize(lhs(state) | rhs(state), shape)
|
||||||
if value.op == "^":
|
if value.op == "^":
|
||||||
return lambda state: Const(lhs(state).value ^ rhs(state).value, shape)
|
return lambda state: normalize(lhs(state) ^ rhs(state), shape)
|
||||||
elif value.op == "==":
|
if value.op == "==":
|
||||||
lhs, rhs = map(self, value.operands)
|
return lambda state: normalize(lhs(state) == rhs(state), shape)
|
||||||
return lambda state: Const(lhs(state).value == rhs(state).value, shape)
|
|
||||||
elif len(value.operands) == 3:
|
elif len(value.operands) == 3:
|
||||||
if value.op == "m":
|
if value.op == "m":
|
||||||
sel, val1, val0 = map(self, value.operands)
|
sel, val1, val0 = map(self, value.operands)
|
||||||
return lambda state: val1(state) if sel(state).value else val0(state)
|
return lambda state: val1(state) if sel(state) else val0(state)
|
||||||
raise NotImplementedError("Operator '{}' not implemented".format(value.op))
|
raise NotImplementedError("Operator '{}' not implemented".format(value.op))
|
||||||
|
|
||||||
def on_Slice(self, value):
|
def on_Slice(self, value):
|
||||||
|
@ -98,7 +100,7 @@ class _RHSValueCompiler(ValueTransformer):
|
||||||
arg = self(value.value)
|
arg = self(value.value)
|
||||||
shift = value.start
|
shift = value.start
|
||||||
mask = (1 << (value.end - value.start)) - 1
|
mask = (1 << (value.end - value.start)) - 1
|
||||||
return lambda state: Const((arg(state).value >> shift) & mask, shape)
|
return lambda state: normalize((arg(state) >> shift) & mask, shape)
|
||||||
|
|
||||||
def on_Part(self, value):
|
def on_Part(self, value):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -113,8 +115,8 @@ class _RHSValueCompiler(ValueTransformer):
|
||||||
def eval(state):
|
def eval(state):
|
||||||
result = 0
|
result = 0
|
||||||
for offset, mask, opnd in parts:
|
for offset, mask, opnd in parts:
|
||||||
result |= (opnd(state).value & mask) << offset
|
result |= (opnd(state) & mask) << offset
|
||||||
return Const(result, shape)
|
return normalize(result, shape)
|
||||||
return eval
|
return eval
|
||||||
|
|
||||||
def on_Repl(self, value):
|
def on_Repl(self, value):
|
||||||
|
@ -127,8 +129,8 @@ class _RHSValueCompiler(ValueTransformer):
|
||||||
result = 0
|
result = 0
|
||||||
for _ in range(count):
|
for _ in range(count):
|
||||||
result <<= offset
|
result <<= offset
|
||||||
result |= opnd(state).value
|
result |= opnd(state)
|
||||||
return Const(result, shape)
|
return normalize(result, shape)
|
||||||
return eval
|
return eval
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,7 +149,7 @@ class _StatementCompiler(StatementTransformer):
|
||||||
lhs = self.lhs_compiler(stmt.lhs)
|
lhs = self.lhs_compiler(stmt.lhs)
|
||||||
rhs = self.rhs_compiler(stmt.rhs)
|
rhs = self.rhs_compiler(stmt.rhs)
|
||||||
def run(state):
|
def run(state):
|
||||||
lhs(state, Const(rhs(state).value, shape))
|
lhs(state, normalize(rhs(state), shape))
|
||||||
return run
|
return run
|
||||||
|
|
||||||
def on_Switch(self, stmt):
|
def on_Switch(self, stmt):
|
||||||
|
@ -164,7 +166,7 @@ class _StatementCompiler(StatementTransformer):
|
||||||
cases.append((lambda test: test & mask == value,
|
cases.append((lambda test: test & mask == value,
|
||||||
self.on_statements(stmts)))
|
self.on_statements(stmts)))
|
||||||
def run(state):
|
def run(state):
|
||||||
test_value = test(state).value
|
test_value = test(state)
|
||||||
for check, body in cases:
|
for check, body in cases:
|
||||||
if check(test_value):
|
if check(test_value):
|
||||||
body(state)
|
body(state)
|
||||||
|
@ -255,7 +257,7 @@ class Simulator:
|
||||||
self._signals.add(signal)
|
self._signals.add(signal)
|
||||||
|
|
||||||
self._state.curr[signal] = self._state.next[signal] = \
|
self._state.curr[signal] = self._state.next[signal] = \
|
||||||
Const(signal.reset, signal.shape())
|
normalize(signal.reset, signal.shape())
|
||||||
self._state.curr_dirty.add(signal)
|
self._state.curr_dirty.add(signal)
|
||||||
|
|
||||||
if signal not in self._vcd_signals:
|
if signal not in self._vcd_signals:
|
||||||
|
@ -295,7 +297,7 @@ class Simulator:
|
||||||
|
|
||||||
def _commit_signal(self, signal):
|
def _commit_signal(self, signal):
|
||||||
old, new = self._state.commit(signal)
|
old, new = self._state.commit(signal)
|
||||||
if old.value == 0 and new.value == 1 and signal in self._domain_triggers:
|
if (old, new) == (0, 1) and signal in self._domain_triggers:
|
||||||
domain = self._domain_triggers[signal]
|
domain = self._domain_triggers[signal]
|
||||||
for sync_signal in self._state.next_dirty:
|
for sync_signal in self._state.next_dirty:
|
||||||
if sync_signal in self._domain_signals[domain]:
|
if sync_signal in self._domain_signals[domain]:
|
||||||
|
@ -303,7 +305,7 @@ class Simulator:
|
||||||
|
|
||||||
if self._vcd_writer:
|
if self._vcd_writer:
|
||||||
for vcd_signal in self._vcd_signals[signal]:
|
for vcd_signal in self._vcd_signals[signal]:
|
||||||
self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new.value)
|
self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new)
|
||||||
|
|
||||||
def _handle_event(self):
|
def _handle_event(self):
|
||||||
handlers = set()
|
handlers = set()
|
||||||
|
@ -320,7 +322,7 @@ class Simulator:
|
||||||
self._commit_signal(signal)
|
self._commit_signal(signal)
|
||||||
|
|
||||||
def _force_signal(self, signal, value):
|
def _force_signal(self, signal, value):
|
||||||
assert signal in self._comb_signals or signal in self._user_signals
|
assert signal in self._user_signals
|
||||||
self._state.set_next(signal, value)
|
self._state.set_next(signal, value)
|
||||||
self._commit_signal(signal)
|
self._commit_signal(signal)
|
||||||
|
|
||||||
|
@ -339,7 +341,7 @@ class Simulator:
|
||||||
elif isinstance(stmt, Assign):
|
elif isinstance(stmt, Assign):
|
||||||
assert isinstance(stmt.lhs, Signal)
|
assert isinstance(stmt.lhs, Signal)
|
||||||
assert isinstance(stmt.rhs, Const)
|
assert isinstance(stmt.rhs, Const)
|
||||||
self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape()))
|
self._force_signal(stmt.lhs, normalize(stmt.rhs.value, stmt.lhs.shape()))
|
||||||
else:
|
else:
|
||||||
raise TypeError("Received unsupported statement '{!r}' from process {}"
|
raise TypeError("Received unsupported statement '{!r}' from process {}"
|
||||||
.format(stmt, proc))
|
.format(stmt, proc))
|
||||||
|
|
|
@ -218,6 +218,15 @@ class Const(Value):
|
||||||
"""
|
"""
|
||||||
src_loc = None
|
src_loc = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize(value, shape):
|
||||||
|
nbits, signed = shape
|
||||||
|
mask = (1 << nbits) - 1
|
||||||
|
value &= mask
|
||||||
|
if signed and value >> (nbits - 1):
|
||||||
|
value |= ~mask
|
||||||
|
return value
|
||||||
|
|
||||||
def __init__(self, value, shape=None):
|
def __init__(self, value, shape=None):
|
||||||
self.value = int(value)
|
self.value = int(value)
|
||||||
if shape is None:
|
if shape is None:
|
||||||
|
@ -227,11 +236,7 @@ class Const(Value):
|
||||||
self.nbits, self.signed = shape
|
self.nbits, self.signed = shape
|
||||||
if not isinstance(self.nbits, int) or self.nbits < 0:
|
if not isinstance(self.nbits, int) or self.nbits < 0:
|
||||||
raise TypeError("Width must be a positive integer")
|
raise TypeError("Width must be a positive integer")
|
||||||
|
self.value = self.normalize(self.value, shape)
|
||||||
mask = (1 << self.nbits) - 1
|
|
||||||
self.value &= mask
|
|
||||||
if self.signed and self.value >> (self.nbits - 1):
|
|
||||||
self.value |= ~mask
|
|
||||||
|
|
||||||
def shape(self):
|
def shape(self):
|
||||||
return self.nbits, self.signed
|
return self.nbits, self.signed
|
||||||
|
|
Loading…
Reference in a new issue