Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""Generate an If node."""
test = self.generate_Compare()
# Generate true branch statements
body = self.sample_node_list(
low=1,
high=N_CONTROLFLOW_STATEMENTS // 2,
generator=self.generate_statement)
# Generate false branch statements
orelse = self.sample_node_list(
low=1,
high=N_CONTROLFLOW_STATEMENTS // 2,
generator=self.generate_statement)
node = gast.If(test, body, orelse)
return node
def visit(self, node):
    """Visit ``node`` and drop statements that were marked for removal.

    The method records on ``self.remove`` / ``self.is_call`` whether the
    node is listed in ``self.to_remove`` and whether it carries a
    ``pri_call`` or ``adj_call`` annotation, delegates to the parent
    visitor, and then decides what to hand back.

    Args:
        node: The AST node being visited.

    Returns:
        * ``None`` when ``node`` is a statement flagged for removal that is
          not a call, or an ``If`` whose body and ``orelse`` are both empty.
        * ``node.orelse`` when ``node`` is a ``While``/``For`` whose body is
          empty.
        * Otherwise whatever the parent visitor returned.
    """
    if node in self.to_remove:
        self.remove = True
    if anno.hasanno(node, 'pri_call') or anno.hasanno(node, 'adj_call'):
        # We don't remove function calls for now; removing them also
        # removes the push statements inside of them, but not the
        # corresponding pop statements
        self.is_call = True
    new_node = super(Remove, self).visit(node)
    if isinstance(node, grammar.STATEMENTS):
        if self.remove and not self.is_call:
            new_node = None
        # The flags describe a single statement; reset them once that
        # statement has been handled.
        self.remove = self.is_call = False
    if isinstance(node, gast.If) and not node.body:
        # If we optimized away an entire if block, we need to handle that
        if not node.orelse:
            return None
        # Only the else branch survived: negate the test and promote the
        # else branch to the body. The operator must be an *instance*
        # (gast.Not()), not the class itself, to form a valid AST node.
        node.test = gast.UnaryOp(op=gast.Not(), operand=node.test)
        node.body, node.orelse = node.orelse, node.body
    elif isinstance(node, (gast.While, gast.For)) and not node.body:
        # A loop without a body reduces to its else clause.
        return node.orelse
    return new_node
def visit_nodelist(self, nodelist):
for i in range(len(nodelist)):
node = nodelist[i]
if isinstance(node, gast.If):
true_branch_returns = isinstance(node.body[-1], gast.Return)
false_branch_returns = len(node.orelse) and isinstance(
node.orelse[-1], gast.Return)
# If the last node in the if body is a return,
# then every line after this if statement effectively
# belongs in the else.
if true_branch_returns and not false_branch_returns:
for j in range(i + 1, len(nodelist)):
nodelist[i].orelse.append(ast_util.copy_clean(nodelist[j]))
if nodelist[i + 1:]:
self.changes_made = True
return nodelist[:i + 1]
elif not true_branch_returns and false_branch_returns:
for j in range(i + 1, len(nodelist)):
nodelist[i].body.append(ast_util.copy_clean(nodelist[j]))
if nodelist[i + 1:]:
assign = cont_ass = [ast.Assign(
[ast.Tuple(expected_return, ast.Store())],
ast.Name(cont_n, ast.Load(), None))]
else:
assign = cont_ass = []
if has_cont:
cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None),
[ast.Eq()], [ast.Num(LOOP_CONT)])
cont_ass = [ast.If(cmpr,
deepcopy(assign) + [ast.Continue()],
cont_ass)]
if has_break:
cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None),
[ast.Eq()], [ast.Num(LOOP_BREAK)])
cont_ass = [ast.If(cmpr,
deepcopy(assign) + [ast.Break()],
cont_ass)]
return cont_ass
unify(ty_iteration, TyList(ty_i))
if ty_iteration.is_fixed_len:
ty_iteration.coerce_to_variable_len(ty_i)
for stmt in node.body:
self.infer_stmt(stmt)
self.nodetype[node] = TyNone()
elif isinstance(node, gast.While):
# While(expr test, stmt* body, stmt* orelse)
pass
elif isinstance(node, gast.If):
# If(expr test, stmt* body, stmt* orelse)
ty_test = self.infer_expr(node.test)
# TODO(momohatt): determine what type should ty_test be
if node.orelse == []:
tc = TypeChecker(self.tyenv)
for stmt in node.body:
tc.infer_stmt(stmt)
# 1. unify the intersection of 2 tyenvs
for name, ty in tc.tyenv.items():
if name in self.tyenv.keys():
unify(ty, self.tyenv[name])
# 2. update local tyenv
for name, ty in tc.tyenv.items():
def make_fake(stmts):
    """Wrap ``stmts`` in an ``If`` node whose test is the constant ``0``.

    The resulting node has ``stmts`` as its body and an empty else branch.
    """
    zero_test = ast.Num(0)
    empty_orelse = []
    return ast.If(zero_test, stmts, empty_orelse)
bool_values.append(gast.Name(id=self.breaked_flag + str(breaked_id), ctx=gast.Load(), annotation=None))
if len(self.func_returned_stack) > 0:
returned_id = len(self.func_returned_stack)
returned_flags = self.func_returned_stack[-1]
if returned_flags:
bool_values.append(gast.Name(id=self.returned_flag + str(returned_id), ctx=gast.Load(), annotation=None))
if len(bool_values) > 0:
if len(bool_values) == 1:
cond = bool_values[0]
elif len(bool_values) > 1:
cond = gast.BoolOp(op=gast.Or(), values=bool_values)
node.body.append(gast.Assign(targets=[gast.Name(id=self.keepgoing_flag, ctx=gast.Store(), annotation=None)], value=gast.UnaryOp(op=gast.Not(), operand=cond)))
node.body.append(gast.If(test=cond, body=[gast.Break()], orelse=[]))
return modified_node
if node_transonic_obj is not None:
break
if node_transonic_obj is None:
return blocks
def_ = duc.chains[node_transonic_obj]
nodes_using_ts = [user.node for user in def_.users()]
for user in def_.users():
for user1 in user.users():
if isinstance(user1.node, ast.Attribute):
attribute = user1.node
if attribute.attr == "is_transpiled":
parent = ancestors.parent(attribute)
if isinstance(parent, ast.If):
# it could be the beginning of a block
if_node = parent
if len(parent.body) != 1:
# no it's not a block definition
continue
node = parent.body[0]
call = node.value
if isinstance(node, ast.Expr):
results = []
elif isinstance(node, ast.Assign):
results = [target.id for target in node.targets]
else:
# no it's not a block definition
continue
attribute = call.func
# We will store the condition on the stack
cond = self.namer.cond()
push, pop, op_id = get_push_pop()
# Fill in the templates
primal_template = grads.primals[gast.If]
primal = template.replace(
primal_template,
body=body,
cond=cond,
test=node.test,
orelse=orelse,
push=push,
_stack=self.stack,
op_id=op_id)
adjoint_template = grads.adjoints[gast.If]
adjoint = template.replace(
adjoint_template,
cond=cond,
adjoint_body=adjoint_body,
adjoint_orelse=adjoint_orelse,
pop=pop,
_stack=self.stack,
op_id=op_id)
return primal, adjoint
# Adjoint registered for gast.If through the @adjoint decorator.
# NOTE(review): the parameters look like template placeholders that are
# substituted (cf. template.replace being called with the same keyword
# names) rather than ordinary runtime arguments -- confirm before editing.
# A docstring is deliberately not added: it would become a statement in the
# function body.
@adjoint(gast.If)
def dif_(cond, adjoint_body, adjoint_orelse, pop, _stack, op_id):
    # The incoming `cond` is overwritten with the value popped from `_stack`
    # for this `op_id`.
    cond = pop(_stack, op_id)
    if cond:
        # Bare expression statement; presumably replaced by the adjoint of
        # the true branch -- TODO confirm.
        adjoint_body
    else:
        # Bare expression statement; presumably replaced by the adjoint of
        # the false branch -- TODO confirm.
        adjoint_orelse