hdl.xfrm: separate AST traversal from AST identity mapping.

This is useful because backends don't generally want or need AST
identity mapping (unlike all other transforms) and when adding a new
node, it results in confusing type errors.
This commit is contained in:
whitequark 2018-12-16 11:24:23 +00:00
parent 286a8009c8
commit 2be76fda3c
3 changed files with 94 additions and 26 deletions

View file

@ -1,3 +1,4 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Iterable
@ -8,41 +9,52 @@ from .cd import *
from .ir import *
__all__ = ["ValueTransformer", "StatementTransformer", "FragmentTransformer",
__all__ = ["AbstractValueTransformer", "ValueTransformer",
"AbstractStatementTransformer", "StatementTransformer",
"FragmentTransformer",
"DomainRenamer", "DomainLowerer", "ResetInserter", "CEInserter"]
class ValueTransformer:
class AbstractValueTransformer(metaclass=ABCMeta):
@abstractmethod
def on_Const(self, value):
return value
pass
@abstractmethod
def on_Signal(self, value):
return value
pass
@abstractmethod
def on_ClockSignal(self, value):
return value
pass
@abstractmethod
def on_ResetSignal(self, value):
return value
pass
@abstractmethod
def on_Operator(self, value):
return Operator(value.op, [self.on_value(o) for o in value.operands])
pass
@abstractmethod
def on_Slice(self, value):
return Slice(self.on_value(value.value), value.start, value.end)
pass
@abstractmethod
def on_Part(self, value):
return Part(self.on_value(value.value), self.on_value(value.offset), value.width)
pass
@abstractmethod
def on_Cat(self, value):
return Cat(self.on_value(o) for o in value.operands)
pass
@abstractmethod
def on_Repl(self, value):
return Repl(self.on_value(value.value), value.count)
pass
@abstractmethod
def on_ArrayProxy(self, value):
return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
self.on_value(value.index))
pass
def on_unknown_value(self, value):
raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov:
@ -78,19 +90,51 @@ class ValueTransformer:
return self.on_value(value)
class StatementTransformer:
def on_value(self, value):
class ValueTransformer(AbstractValueTransformer):
def on_Const(self, value):
return value
def on_Signal(self, value):
return value
def on_ClockSignal(self, value):
return value
def on_ResetSignal(self, value):
return value
def on_Operator(self, value):
return Operator(value.op, [self.on_value(o) for o in value.operands])
def on_Slice(self, value):
return Slice(self.on_value(value.value), value.start, value.end)
def on_Part(self, value):
return Part(self.on_value(value.value), self.on_value(value.offset), value.width)
def on_Cat(self, value):
return Cat(self.on_value(o) for o in value.operands)
def on_Repl(self, value):
return Repl(self.on_value(value.value), value.count)
def on_ArrayProxy(self, value):
return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
self.on_value(value.index))
class AbstractStatementTransformer(metaclass=ABCMeta):
@abstractmethod
def on_Assign(self, stmt):
return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
pass
@abstractmethod
def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
return Switch(self.on_value(stmt.test), cases)
pass
@abstractmethod
def on_statements(self, stmt):
return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt))
pass
def on_unknown_statement(self, stmt):
raise TypeError("Cannot transform statement '{!r}'".format(stmt)) # :nocov:
@ -109,6 +153,21 @@ class StatementTransformer:
return self.on_statement(value)
class StatementTransformer(AbstractStatementTransformer):
def on_value(self, value):
return value
def on_Assign(self, stmt):
return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
def on_Switch(self, stmt):
cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
return Switch(self.on_value(stmt.test), cases)
def on_statements(self, stmt):
return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt))
class FragmentTransformer:
def map_subfragments(self, fragment, new_fragment):
for subfragment, name in fragment.subfragments: