Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def clean_generated_code(code: str) -> str:
    """
    Run cosmetic clean-up passes over generated source code.

    Each pass only tidies the form of the code (for instance collapsing a
    redundant ``Union[SingleType]``); none of them changes what the generated
    code does.

    Args:
        code: Python source text produced by a code generator.

    Returns:
        The cleaned-up source text.
    """
    tree = parse_module(code)
    # Passes run in this order; each must hand back a Module again.
    for transformer_cls in (
        SimplifyUnionsTransformer,
        DoubleQuoteForwardRefsTransformer,
    ):
        tree = ensure_type(tree.visit(transformer_cls()), cst.Module)
    return tree.code
# NOTE(review): headerless fragment -- the enclosing `def` that supplies
# `typestr` and `cleanser` is not visible here. The final `return` yields the
# rendered code for the cleaned type together with the `aliases` list.
# Parse the annotation text and normalize it with the caller's `cleanser` visitor.
typecst = parse_expression(typestr)
typecst = typecst.visit(cleanser)
# Passed into the _get_clean_type_from_* helpers below (which presumably
# append to it -- confirm against those helpers) and returned to the caller.
aliases: List[Alias] = []
# Now, convert the type to allow for MetadataMatchType and MatchIfTrue values.
# Only a subscript, a bare name, or a string-literal annotation is accepted
# at the top level; anything else is treated as a logic error.
if isinstance(typecst, cst.Subscript):
clean_type = _get_clean_type_from_subscript(aliases, typecst)
elif isinstance(typecst, (cst.Name, cst.SimpleString)):
clean_type = _get_clean_type_from_expression(aliases, typecst)
else:
raise Exception("Logic error, unexpected top level type!")
# Now, insert OneOf/AllOf and MatchIfTrue into unions so we can typecheck their usage.
# This allows us to put OneOf[SomeType] or MatchIfTrue[cst.SomeType] into any
# spot that we would have originally allowed a SomeType.
clean_type = ensure_type(clean_type.visit(AddLogicMatchersToUnions()), cst.CSTNode)
# Now, insert AtMostN and AtLeastN into sequence unions, so we can typecheck
# them. This relies on the previous OneOf/AllOf insertion to ensure that all
# sequences we care about are Sequence[Union[]].
clean_type = ensure_type(
clean_type.visit(AddWildcardsToSequenceUnions()), cst.CSTNode
)
# Finally, generate the code given a default Module so we can spit it out.
# An empty Module is used only as the rendering context for code_for_node.
return cst.Module(body=()).code_for_node(clean_type), aliases
# NOTE(review): this span repeats, line for line, the "convert / insert
# matchers / insert wildcards / render" steps that appear just above. Its
# parse and `aliases` set-up lines and its enclosing `def` are not present,
# so it looks like a copy/scrape duplicate -- confirm whether it should exist.
# Now, convert the type to allow for MetadataMatchType and MatchIfTrue values.
if isinstance(typecst, cst.Subscript):
clean_type = _get_clean_type_from_subscript(aliases, typecst)
elif isinstance(typecst, (cst.Name, cst.SimpleString)):
clean_type = _get_clean_type_from_expression(aliases, typecst)
else:
raise Exception("Logic error, unexpected top level type!")
# Now, insert OneOf/AllOf and MatchIfTrue into unions so we can typecheck their usage.
# This allows us to put OneOf[SomeType] or MatchIfTrue[cst.SomeType] into any
# spot that we would have originally allowed a SomeType.
clean_type = ensure_type(clean_type.visit(AddLogicMatchersToUnions()), cst.CSTNode)
# Now, insert AtMostN and AtLeastN into sequence unions, so we can typecheck
# them. This relies on the previous OneOf/AllOf insertion to ensure that all
# sequences we care about are Sequence[Union[]].
clean_type = ensure_type(
clean_type.visit(AddWildcardsToSequenceUnions()), cst.CSTNode
)
# Finally, generate the code given a default Module so we can spit it out.
return cst.Module(body=()).code_for_node(clean_type), aliases
def _visit_import_alike(self, node: Union[cst.Import, cst.ImportFrom]) -> bool:
# Record, via self.scope.record_assignment, the local name(s) bound by an
# `import` / `from ... import` statement. Star imports are not recorded here.
# NOTE(review): the body is truncated after the final `if` below; the rest of
# the logic and the bool that is ultimately returned are not visible.
names = node.names
if not isinstance(names, cst.ImportStar):
# make sure node.names is Sequence[ImportAlias]
for name in names:
asname = name.asname
# With an `as` clause the bound name is the alias itself.
if asname is not None:
name_value = cst.ensure_type(asname.name, cst.Name).value
else:
# Without an alias, follow `.value` through any dotted Attribute chain
# (e.g. `a.b.c`) down to its left-most name, which is what gets bound.
name_node = name.name
while isinstance(name_node, cst.Attribute):
# the value of Attribute in import alike can only be either Name or Attribute
name_node = name_node.value
if isinstance(name_node, cst.Name):
name_value = name_node.value
else:
raise Exception(
f"Unexpected ImportAlias name value: {name_node}"
)
self.scope.record_assignment(name_value, node)
# visit remaining attributes
if isinstance(node, cst.Import):
def _leave_union(
    self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.BaseExpression:
    """
    Collapse a single-member ``Union[...]`` subscript to that member.

    ``Union[SimpleType,]`` means the same as plain ``SimpleType``, so the bare
    member expression is returned in that case. A union with any other number
    of members is handed back untouched.

    Raises:
        Exception: if the subscript's slice is a bare ``Slice``/``Index``
            (the deprecated form), which is not supported.
    """
    members = updated_node.slice
    # TODO: We can remove the instance check after ExtSlice is deprecated.
    if isinstance(members, (cst.Slice, cst.Index)):
        raise Exception("Unexpected Slice in Union!")
    if len(members) != 1:
        # Multi-member unions are left exactly as they are.
        return updated_node
    sole_member = members[0]
    return cst.ensure_type(sole_member.slice, cst.Index).value
def _check_formatted_string(
self,
_original_node: libcst.FormattedString,
updated_node: libcst.FormattedString,
) -> libcst.BaseExpression:
# Evaluate the f-string's first text part both with and without the `f`
# prefix and compare the results (presumably to decide whether the prefix is
# redundant).
# NOTE(review): the body is truncated after the final `if`; what happens when
# the two evaluations differ, and the final return value, are not visible.
# Only parts[0] is read: an empty `parts` raises IndexError, and a first part
# that is not FormattedStringText makes ensure_type raise -- presumably the
# caller guarantees a single plain-text part; confirm.
old_string_inner = libcst.ensure_type(
updated_node.parts[0], libcst.FormattedStringText
).value
# Doubled braces render as single braces only inside an f-string, so dropping
# the prefix would change the text; such strings are returned unchanged.
if "{{" in old_string_inner or "}}" in old_string_inner:
# there are only two characters we need to worry about escaping.
return updated_node
old_string_literal = updated_node.start + old_string_inner + updated_node.end
# Same literal with every `f`/`F` removed from the opening prefix (e.g. `rf"`
# becomes `r"`), keeping the original quotes and text.
new_string_literal = (
updated_node.start.replace("f", "").replace("F", "")
+ old_string_inner
+ updated_node.end
)
# NOTE(review): eval() executes text rebuilt from the parsed source. For the
# non-f variant, ast.literal_eval would be the safer choice -- confirm before
# changing (literal_eval does not accept the f-prefixed form).
old_string_evaled = eval(old_string_literal) # noqa
new_string_evaled = eval(new_string_literal) # noqa
if old_string_evaled != new_string_evaled:
def leave_EmptyLine(
    self, original_node: libcst.EmptyLine, updated_node: libcst.EmptyLine
) -> Union[libcst.EmptyLine, libcst.RemovalSentinel]:
    """
    Remove empty lines whose comment matches this transformer's tag pattern.

    Lines with no comment, and lines whose comment does not match
    ``self._regex_pattern``, are returned unchanged.
    """
    comment = updated_node.comment
    if comment is not None:
        comment_text = libcst.ensure_type(comment, libcst.Comment).value
        if self._regex_pattern.search(comment_text):
            # A directive comment matching our tag: drop the whole line.
            return libcst.RemoveFromParent()
    # No comment, or an ordinary comment: keep the line as it is.
    return updated_node
def leave_Call( # noqa: C901
self, original_node: cst.Call, updated_node: cst.Call
) -> cst.BaseExpression:
# Lets figure out if this is a "".format() call
if self.matches(
updated_node,
m.Call(func=m.Attribute(value=m.SimpleString(), attr=m.Name("format"))),
):
fstring: List[cst.BaseFormattedStringContent] = []
inserted_sequence: int = 0
# TODO: Use `extract` when it becomes available.
stringvalue = cst.ensure_type(
cst.ensure_type(updated_node.func, cst.Attribute).value,
cst.SimpleString,
).value
prefix, quote, innards = _string_prefix_and_quotes(stringvalue)
tokens = _get_tokens(innards)
for (literal_text, field_name, format_spec, conversion) in tokens:
if literal_text:
fstring.append(cst.FormattedStringText(literal_text))
if field_name is None:
# This is not a format-specification
continue
if format_spec is not None and len(format_spec) > 0:
# TODO: This is supportable since format specs are compatible
# with f-string format specs, but it would require matching
# format specifier expansions.
self.warn(f"Unsupported format_spec {format_spec} in format() call")
return updated_node