back.pysim: implement LHS for Part, Slice, Cat, ArrayProxy.
This commit is contained in:
parent
d9579219ee
commit
d4e8d3e95a
3 changed files with 93 additions and 17 deletions
|
|
@ -45,8 +45,9 @@ normalize = Const.normalize
|
|||
|
||||
|
||||
class _RHSValueCompiler(ValueTransformer):
|
||||
def __init__(self, sensitivity=None):
|
||||
def __init__(self, sensitivity=None, mode="rhs"):
|
||||
self.sensitivity = sensitivity
|
||||
self.signal_mode = mode
|
||||
|
||||
def on_Const(self, value):
|
||||
return lambda state: value.value
|
||||
|
|
@ -54,7 +55,12 @@ class _RHSValueCompiler(ValueTransformer):
|
|||
def on_Signal(self, value):
|
||||
if self.sensitivity is not None:
|
||||
self.sensitivity.add(value)
|
||||
return lambda state: state.curr[value]
|
||||
if self.signal_mode == "rhs":
|
||||
return lambda state: state.curr[value]
|
||||
elif self.signal_mode == "lhs":
|
||||
return lambda state: state.next[value]
|
||||
else:
|
||||
raise ValueError # :nocov:
|
||||
|
||||
def on_ClockSignal(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
|
@ -160,11 +166,17 @@ class _RHSValueCompiler(ValueTransformer):
|
|||
|
||||
|
||||
class _LHSValueCompiler(ValueTransformer):
|
||||
def __init__(self, rhs_compiler):
|
||||
self.rhs_compiler = rhs_compiler
|
||||
|
||||
def on_Const(self, value):
|
||||
raise TypeError # :nocov:
|
||||
|
||||
def on_Signal(self, value):
|
||||
return lambda state, arg: state.set(value, arg)
|
||||
shape = value.shape()
|
||||
def eval(state, rhs):
|
||||
state.set(value, normalize(rhs, shape))
|
||||
return eval
|
||||
|
||||
def on_ClockSignal(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
|
@ -176,37 +188,69 @@ class _LHSValueCompiler(ValueTransformer):
|
|||
raise TypeError # :nocov:
|
||||
|
||||
def on_Slice(self, value):
|
||||
raise NotImplementedError
|
||||
lhs_r = self.rhs_compiler(value.value)
|
||||
lhs_l = self(value.value)
|
||||
shift = value.start
|
||||
mask = (1 << (value.end - value.start)) - 1
|
||||
def eval(state, rhs):
|
||||
lhs_value = lhs_r(state)
|
||||
lhs_value &= ~(mask << shift)
|
||||
lhs_value |= (rhs & mask) << shift
|
||||
lhs_l(state, lhs_value)
|
||||
return eval
|
||||
|
||||
def on_Part(self, value):
|
||||
raise NotImplementedError
|
||||
lhs_r = self.rhs_compiler(value.value)
|
||||
lhs_l = self(value.value)
|
||||
shift = self.rhs_compiler(value.offset)
|
||||
mask = (1 << value.width) - 1
|
||||
def eval(state, rhs):
|
||||
lhs_value = lhs_r(state)
|
||||
shift_value = shift(state)
|
||||
lhs_value &= ~(mask << shift_value)
|
||||
lhs_value |= (rhs & mask) << shift_value
|
||||
lhs_l(state, lhs_value)
|
||||
return eval
|
||||
|
||||
def on_Cat(self, value):
|
||||
raise NotImplementedError
|
||||
parts = []
|
||||
offset = 0
|
||||
for opnd in value.operands:
|
||||
parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
|
||||
offset += len(opnd)
|
||||
def eval(state, rhs):
|
||||
for offset, mask, opnd in parts:
|
||||
opnd(state, (rhs >> offset) & mask)
|
||||
return eval
|
||||
|
||||
def on_Repl(self, value):
|
||||
raise TypeError # :nocov:
|
||||
|
||||
def on_ArrayProxy(self, value):
|
||||
raise NotImplementedError
|
||||
elems = list(map(self, value.elems))
|
||||
index = self.rhs_compiler(value.index)
|
||||
def eval(state, rhs):
|
||||
elems[index(state)](state, rhs)
|
||||
return eval
|
||||
|
||||
|
||||
class _StatementCompiler(StatementTransformer):
|
||||
def __init__(self):
|
||||
self.sensitivity = ValueSet()
|
||||
self.rhs_compiler = _RHSValueCompiler(self.sensitivity)
|
||||
self.lhs_compiler = _LHSValueCompiler()
|
||||
self.sensitivity = ValueSet()
|
||||
self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs")
|
||||
self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs")
|
||||
self.lhs_compiler = _LHSValueCompiler(self.lrhs_compiler)
|
||||
|
||||
def on_Assign(self, stmt):
|
||||
shape = stmt.lhs.shape()
|
||||
lhs = self.lhs_compiler(stmt.lhs)
|
||||
rhs = self.rhs_compiler(stmt.rhs)
|
||||
rhs = self.rrhs_compiler(stmt.rhs)
|
||||
def run(state):
|
||||
lhs(state, normalize(rhs(state), shape))
|
||||
return run
|
||||
|
||||
def on_Switch(self, stmt):
|
||||
test = self.rhs_compiler(stmt.test)
|
||||
test = self.rrhs_compiler(stmt.test)
|
||||
cases = []
|
||||
for value, stmts in stmt.cases.items():
|
||||
if "-" in value:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue