Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# SQL implementation of the `mutate` verb, registered for LazyTbl inputs.
# Each keyword argument is a (new column name, expression) pair.
# NOTE(review): this excerpt has lost its indentation and stops inside the
# loop below; the code that consumes `labs` and `new_call`, and the value
# returned by the function, are not visible here.
@mutate.register(LazyTbl)
def _mutate(__data, **kwargs):
# Cases
# - work with group by
# - window functions
# TODO: verify it can follow a renaming select
# track labeled columns in set
# start from the table's most recent operation
sel = __data.last_op
# evaluate each call
for colname, func in kwargs.items():
# keep set of columns labeled (aliased) in this select statement
# need to use inner cols, since sel.columns uses ColumnClause, not Label
labs = set(k for k,v in lift_inner_cols(sel).items() if isinstance(v, sql.elements.Label))
# let the table shape the user's expression; the verb and argument names are
# passed along (presumably for error reporting -- confirm in shape_call)
new_call = __data.shape_call(func, verb_name = "Mutate", arg_name = colname)
# SQL `distinct` verb (excerpt). Positional args must be simple column
# references; keyword args define columns that are first created via `mutate`.
# Raises NotImplementedError when variables are combined with _keep_all=True,
# and KeyError when a positional argument is not a simple column reference.
# NOTE(review): indentation has been lost, and the `_distinct` body breaks off
# at the `if not len(...)` test further down -- see the note placed there.
def _distinct(__data, *args, _keep_all = False, **kwargs):
if (args or kwargs) and _keep_all:
raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False")
# apply keyword expressions through mutate first; otherwise reuse the current select
inner_sel = mutate(__data, **kwargs).last_op if kwargs else __data.last_op
# TODO: this is copied from the df distinct version
# cols dict below is used as ordered set
cols = {simple_varname(x): True for x in args}
cols.update(kwargs)
# a None key means simple_varname returned None for some positional argument
if None in cols:
raise KeyError("positional arguments must be simple column, "
"e.g. _.colname or _['colname']"
)
# use all columns by default
if not cols:
cols = list(inner_sel.columns.keys())
# branch taken when the inner select carries no ORDER BY entries
if not len(inner_sel._order_by_clause):
# NOTE(review): from here on the excerpt appears to come from a different verb
# (the error text and the "tally (n)" comment refer to counting, and no `def`
# line is visible for it) -- confirm against the full source.
# similar to filter verb, we need two select statements,
# an inner one for derived cols, and outer to group by them
# inner select ----
# holds any mutation style columns
arg_names = []
for arg in args:
name = simple_varname(arg)
if name is None:
raise NotImplementedError(
"Count positional arguments must be single column name. "
"Use a named argument to count using complex expressions."
)
arg_names.append(name)
tbl_inner = mutate(__data, **kwargs)
sel_inner = tbl_inner.last_op
# grouping columns: positional names first, then keyword-defined columns
group_cols = arg_names + list(kwargs)
# outer select ----
# holds selected columns and tally (n)
# alias the inner select so the outer select can use it as its FROM source
sel_inner_cte = sel_inner.alias()
inner_cols = sel_inner_cte.columns
sel_outer = sql.select(from_obj = sel_inner_cte)
# apply any group vars from a group_by verb call first
prev_group_cols = [inner_cols[k] for k in tbl_inner.group_by]
if prev_group_cols:
sel_outer.append_group_by(*prev_group_cols)
# NOTE(review): every previous group column is unpacked into one call here;
# verify append_column accepts several columns in the SQLAlchemy version in use
sel_outer.append_column(*prev_group_cols)
# now any defined in the count verb call
def _fast_mutate_default(__data, **kwargs):
    """Default fast-mutate behaviour: hand the call to the regular ``mutate``.

    The implementation is looked up in ``mutate.registry`` under the exact
    type of ``__data`` and is invoked with the same keyword arguments; its
    result is returned unchanged.
    """
    # TODO: had to register object second, since singledispatch2 sets object dispatch
    # to be a pipe (e.g. unknown types become a pipe by default)
    # by default dispatch to regular mutate
    return mutate.registry[type(__data)](__data, **kwargs)
def _group_by(__data, *args, add = False, **kwargs):
if kwargs:
data = mutate(__data, **kwargs)
else:
data = __data
cols = data.last_op.columns
# put kwarg grouping vars last, so similar order to function call
groups = tuple(simple_varname(arg) for arg in args) + tuple(kwargs)
if None in groups:
raise NotImplementedError("Complex expressions not supported in sql group_by")
unmatched = set(groups) - set(cols.keys())
if unmatched:
raise KeyError("group_by specifies columns missing from table: %s" %unmatched)
if add:
groups = ordered_union(data.group_by, groups)