Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_integer(self, tree: Union[ast.Primary, ast.ComponentRef, ast.Expression, ast.Slice]) -> Union[int, ca.MX, np.ndarray]:
    """
    Evaluate an AST node that is expected to describe an integer value.

    :param tree: node to evaluate (``Primary``, ``ComponentRef`` or ``Expression``
        are handled in the code shown here)
    :return: ``int(tree.value)`` for a ``Primary`` (``None`` if it has no value);
        for a ``ComponentRef``, the recursively evaluated value of the referenced
        Integer symbol.

    NOTE(review): this snippet appears truncated. The ``ast.Expression`` branch
    collects ``vals`` but, as shown, never evaluates ``expr`` with them or returns,
    and ``ast.Slice`` (listed in the annotation) is not handled. Both paths fall
    through and return ``None`` here. Confirm against the full source.
    """
    # CasADi needs to know the dimensions of symbols at instantiation.
    # We therefore need a mechanism to evaluate expressions that define dimensions of symbols.
    if isinstance(tree, ast.Primary):
        # Literal node: its value may be absent (None).
        return None if tree.value is None else int(tree.value)
    if isinstance(tree, ast.ComponentRef):
        # Reference to a symbol of the current class: evaluate that symbol's value.
        s = self.current_class.symbols[tree.name]
        # NOTE(review): assert is stripped under `python -O`; consider raising an exception instead.
        assert (s.type.name == 'Integer')
        return self.get_integer(s.value)
    if isinstance(tree, ast.Expression):
        # Make sure that the expression has been converted to MX by (re)visiting the
        # relevant part of the AST.
        ast_walker = TreeWalker()
        ast_walker.walk(self, tree)
        # Obtain expression
        expr = self.get_mx(tree)
        # Obtain the symbols it depends on
        free_vars = ca.symvar(expr)
        # Find the values of the symbols
        vals = []
        for free_var in free_vars:
            if free_var.is_symbolic():
                # The innermost active for-loop's variable is replaced by that loop's index variable.
                if (len(self.for_loops) > 0) and (free_var.name() == self.for_loops[-1].name):
                    vals.append(self.for_loops[-1].index_variable)
                else:
                    # Any other symbol: recursively evaluate the value stored for it in the current class.
                    vals.append(self.get_integer(self.current_class.symbols[free_var.name()].value))
def generate(ast_tree: ast.Tree, model_name: str):
    """
    Flatten one model of an AST and render it as an XML document.

    :param ast_tree: AST to generate from (it is deep-copied, not flattened in place)
    :param model_name: class to generate
    :return: pretty-printed XML for the flattened model, decoded as a UTF-8 ``str``
    """
    component_ref = ast.ComponentRef.from_string(model_name)
    # Flatten a deep copy so the caller's tree is not the one being flattened.
    ast_tree_new = copy.deepcopy(ast_tree)
    flat_tree = flatten(ast_tree_new, component_ref)
    gen = XmlGenerator()
    ast_walker = TreeWalker()
    ast_walker.walk(gen, flat_tree)
    # gen.xml is indexed by AST node; serialize the element built for the flattened root.
    return etree.tostring(gen.xml[flat_tree], pretty_print=True).decode('utf-8')
def register_indexed_symbol(self, e, index_function, transpose, tree, index_expr=None):
    """
    Store a ``ForLoopIndexedSymbol`` for ``e`` in ``self.indexed_symbols``.

    :param e: key under which the indexed symbol is stored
    :param index_function: callable applied to the shifted loop indices
    :param transpose: forwarded to ``ForLoopIndexedSymbol``
    :param tree: forwarded to ``ForLoopIndexedSymbol``
    :param index_expr: optional ``ca.MX`` expression in ``self.index_variable``. When
        given (and not the index variable itself), it is evaluated for every value in
        ``self.values``; otherwise ``self.values`` are used directly.
    """
    if isinstance(index_expr, ca.MX) and index_expr is not self.index_variable:
        # Evaluate index_expr for all loop values in a single mapped CasADi call.
        F = ca.Function('index_expr', [self.index_variable], [index_expr])
        Fmap = F.map("map", self.generator.map_mode, len(self.values), [], [])
        res = Fmap.call([self.values])
        # np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin int is equivalent.
        indices = np.array(res[0].T, dtype=int)
    else:
        indices = self.values
    # Shift indices down by one: presumably Modelica's 1-based indices to 0-based (TODO confirm).
    self.indexed_symbols[e] = ForLoopIndexedSymbol(tree, transpose, index_function(indices - 1))
# Immutable two-field record (left, right); presumably the two sides of an assignment.
Assignment = namedtuple(typename='Assignment', field_names=('left', 'right'))
class GeneratorWalker(TreeWalker):
    """Tree walker that never descends into class annotations and visits symbols first."""

    def skip_child(self, tree: ast.Node, child_name: str) -> bool:
        # Ask the base walker first, then additionally force-skip the
        # ``annotation`` child of every class node.
        inherited = super().skip_child(tree, child_name)
        is_class_annotation = isinstance(tree, ast.Class) and child_name == "annotation"
        return True if is_class_annotation else inherited

    def order_keys(self, keys: Iterable[str]):
        # 'symbols' is visited before everything else, because symbol values are
        # needed when CasADi interpolant functions are created for classes.
        # sorted() is stable and False < True, so only 'symbols' moves to the front.
        return sorted(keys, key=lambda name: name != 'symbols')
# noinspection PyPep8Naming,PyUnresolvedReferences
class Generator(TreeListener):
def generate(ast_tree: ast.Tree, model_name: str):
    """
    Produce SymPy source code for one model contained in an AST.

    A deep copy of the tree is flattened for the requested class, and a
    ``SympyGenerator`` is walked over the flattened result.

    :param ast_tree: AST to generate from
    :param model_name: class to generate
    :return: sympy source code for model
    """
    target = ast.ComponentRef.from_string(model_name)
    walker = TreeWalker()
    flattened = flatten(copy.deepcopy(ast_tree), target)
    generator = SympyGenerator()
    walker.walk(generator, flattened)
    # The generator keeps generated source per AST node; return the root's.
    return generator.src[flattened]