Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
if node in self.optimizable_comprehension:
self.update = True
self.generic_visit(node)
iters = [self.make_Iterator(gen) for gen in node.generators]
variables = [ast.Name(gen.target.id, ast.Param(), None, None)
for gen in node.generators]
# If dim = 1, product is useless
if len(iters) == 1:
iterAST = iters[0]
varAST = ast.arguments([variables[0]], [], None, [], [], None, [])
else:
self.use_itertools = True
prodName = ast.Attribute(
value=ast.Name(id=mangle('itertools'),
ctx=ast.Load(),
annotation=None, type_comment=None),
attr='product', ctx=ast.Load())
varid = variables[0].id # retarget this id, it's free
renamings = {v.id: (i,) for i, v in enumerate(variables)}
node.elt = ConvertToTuple(varid, renamings).visit(node.elt)
iterAST = ast.Call(prodName, iters, [])
varAST = ast.arguments([ast.Name(varid, ast.Param(), None, None)],
[], None, [], [], None, [])
ldBodymap = node.elt
ldmap = ast.Lambda(varAST, ldBodymap)
return make_attr(ldmap, iterAST)
def visit_FunctionDef(self, node):
self.generic_visit(node)
kept_decorators = []
for dec in node.decorator_list:
if isinstance(dec, gast.Call):
dec_func = dec.func
else:
dec_func = dec
# Special cases.
# TODO(mdan): Is there any way we can treat these more generically?
# We may want to forego using decorators altogether if we can't
# properly support them.
if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',):
# Assumption: decorators are only visible in the AST when converting
# a function inline (via another decorator).
# In that case, the converted function is no longer part of the
# original object that it was declared into.
# This is currently verified by tests.
continue
if not anno.hasanno(dec_func, 'live_val'):
raise ValueError(
'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func))
dec_value = anno.getanno(dec_func, 'live_val')
if dec_value not in self.remove_decorators:
kept_decorators.append((dec, dec_value))
for _, dec_value in kept_decorators:
def _process(self, node):
    """Replace `node` with a renamed Name when its qualified name is mapped.

    The qualified name is read from the node's `anno.Basic.QN` annotation.
    If it is a key of `self.name_map`, a fresh `gast.Name` carrying the
    string form of the mapped value (and the original node's ctx) is
    returned; otherwise the node's children are visited as usual.
    """
    qualified_name = anno.getanno(node, anno.Basic.QN)
    if qualified_name not in self.name_map:
        return self.generic_visit(node)
    new_id = str(self.name_map[qualified_name])
    return gast.Name(new_id, node.ctx, None)
def visit_Assign(self, node):
    """
    Propagate the r-value's type dependencies to each assigned variable.

    Subscript targets are covered as well: `a[i] = foo()` makes the type
    of `a` depend on the return type of `foo`.
    """
    rhs_deps = self.visit(node.value)
    for assigned in node.targets:
        variable = get_variable(assigned)
        if not isinstance(variable, ast.Name):
            continue
        # All targets of one assignment share the same dependency object.
        self.naming[variable.id] = rhs_deps
loads[n.id].append(n)
if isinstance(child, gast.Assign):
name = child.targets[0].id
if name in loads:
if name in lcds:
raise NotImplementedError("cannot process LCD "
"stored to twice")
lcds.add(name)
node = SplitAttributes().visit(node)
synchronizes = []
for name in lcds:
synchronize = gast.Assign(
[gast.Name(name, gast.Store(), None)],
gast.Call(
gast.Attribute(
gast.Name(name, gast.Load(), None),
gast.Name('_synchronize', gast.Load(), None),
None),
[], []))
synchronizes.append(synchronize)
node.body.extend(synchronizes)
return node
def visit_Compare(self, node):
    """Lower a (possibly chained) comparison into explicit function calls.

    Each ``left <op> right`` pair is rewritten with ``self._as_function``
    using the function selected by ``self._matching_func(op)``. Chained
    comparisons are converted to conjunctions, newest comparison first::

        a < b < c  ->  tf.logical_and(<b < c>, <a < b>)

    Args:
      node: The comparison node; its children are visited first.

    Returns:
      The expression tree that replaces ``node``.
    """
    node = self.generic_visit(node)
    left = node.left
    op_tree = None
    # Snapshot the (operator, comparator) pairs before rewriting.
    for op, right in list(zip(node.ops, node.comparators)):
        binary_comparison = self._as_function(
            self._matching_func(op), (left, right))
        # Comparisons between two plain names are tagged as safe operands.
        if isinstance(left, gast.Name) and isinstance(right, gast.Name):
            anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
        if op_tree is None:
            op_tree = binary_comparison
        else:
            # NOTE(review): presumably rejects complex operands because the
            # conjunction evaluates every operand, unlike Python's
            # short-circuiting chained comparison — confirm.
            self._expect_simple_symbol(right)
            op_tree = self._as_function('tf.logical_and',
                                        (binary_comparison, op_tree))
        # In a chain, each right operand becomes the next comparison's left.
        left = right
    # A well-formed comparison has at least one operator.
    assert op_tree is not None
    return op_tree
def _gate_symbols(self, guard_statement, guarded_args):
    """Append identity guards for `guarded_args` to `guard_statement.body`.

    The guard statements come from expanding `template` with
    `templates.replace`, substituting `args` with a tuple of `gast.Name`
    nodes (one per entry of `guarded_args`); presumably this rebinds each
    guarded symbol to `tf.identity` of itself. Returns the mutated
    `guard_statement`.
    """
    # The template body is consumed by `templates.replace`; keep it verbatim.
    def template(args):  # pylint:disable=unused-argument
        (args,) = (tf.identity(a) for a in (args,))  # pylint:disable=undefined-variable

    guarded_names = tuple(
        gast.Name(arg_name, None, None) for arg_name in guarded_args)
    guard_nodes = templates.replace(template, args=guarded_names)
    guard_statement.body.extend(guard_nodes)
    return guard_statement