How to use the dask.core.flatten function in dask

To help you get started, we've selected a few dask examples based on popular ways dask.core.flatten is used in public projects.
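
Before diving into the excerpts, here is a quick, self-contained sketch of what flatten does. It lazily yields the leaves of arbitrarily nested containers; by default only lists are flattened, which matters in dask because task keys are tuples and must pass through untouched.

from dask.core import flatten

# Nested lists are flattened to a single stream of leaves.
print(list(flatten([1, [2, [3, 4]], [5]])))   # [1, 2, 3, 4, 5]

# By default only lists count as containers, so tuples --
# e.g. dask task keys like ('x', 0, 0) -- survive as atoms.
keys = [[('x', 0, 0), ('x', 0, 1)], [('x', 1, 0), ('x', 1, 1)]]
print(list(flatten(keys)))                    # [('x', 0, 0), ..., ('x', 1, 1)]

# Pass container=(list, tuple) to flatten tuples as well.
print(list(flatten([(1, 2), [3, (4,)]], container=(list, tuple))))  # [1, 2, 3, 4]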

github dask / dask / dask / array / core.py
    >>> x = np.array([[1, 2]])
    >>> concatenate3([[x, x, x], [x, x, x]])
    array([[1, 2, 1, 2, 1, 2],
           [1, 2, 1, 2, 1, 2]])

    >>> concatenate3([[x, x], [x, x], [x, x]])
    array([[1, 2, 1, 2],
           [1, 2, 1, 2],
           [1, 2, 1, 2]])
    """
    arrays = concrete(arrays)
    if not arrays:
        return np.empty(0)

    advanced = max(core.flatten(arrays, container=(list, tuple)),
                   key=lambda x: getattr(x, '__array_priority__', 0))
    module = package_of(type(advanced)) or np
    if module is not np and hasattr(module, 'concatenate'):
        x = unpack_singleton(arrays)
        return _concatenate2(arrays, axes=list(range(x.ndim)))

    ndim = ndimlist(arrays)
    if not ndim:
        return arrays
    chunks = chunks_from_arrays(arrays)
    shape = tuple(map(sum, chunks))

    def dtype(x):
        try:
            return x.dtype
        except AttributeError:
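
This excerpt uses flatten with container=(list, tuple) to visit every leaf of a nested block grid so that max() can pick the array with the highest __array_priority__. A minimal sketch of that selection step, using plain numpy arrays:

import numpy as np
from dask.core import flatten

# A 2-D grid of blocks, shaped like the input to concatenate3.
x = np.array([[1, 2]])
grid = [[x, x, x], [x, x, x]]

# flatten visits every leaf regardless of nesting depth, so max()
# can pick the leaf advertising the highest __array_priority__
# (numpy subclasses use this attribute to win dispatch).
leaves = flatten(grid, container=(list, tuple))
winner = max(leaves, key=lambda a: getattr(a, '__array_priority__', 0))
print(type(winner))  # <class 'numpy.ndarray'>
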
github dask / dask / dask / array / routines.py
"See the numpy.histogram docstring for more information."
        )

    if not np.iterable(bins):
        bin_token = bins
        mn, mx = range
        if mn == mx:
            mn -= 0.5
            mx += 0.5

        bins = np.linspace(mn, mx, bins + 1, endpoint=True)
    else:
        bin_token = bins
    token = tokenize(a, bin_token, range, weights, density)

    nchunks = len(list(flatten(a.__dask_keys__())))
    chunks = ((1,) * nchunks, (len(bins) - 1,))

    name = "histogram-sum-" + token

    # Map the histogram to all bins
    def block_hist(x, range=None, weights=None):
        return np.histogram(x, bins, range=range, weights=weights)[0][np.newaxis]

    if weights is None:
        dsk = {
            (name, i, 0): (block_hist, k, range)
            for i, k in enumerate(flatten(a.__dask_keys__()))
        }
        dtype = np.histogram([])[0].dtype
    else:
        a_keys = flatten(a.__dask_keys__())
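
Here flatten linearizes the nested __dask_keys__() structure so the code can count chunks and enumerate one histogram task per chunk. A small sketch of the counting idiom, assuming dask.array is available:

import dask.array as da
from dask.core import flatten

a = da.ones(100, chunks=25)  # four chunks

# __dask_keys__() mirrors the chunk structure as nested lists;
# flattening it yields exactly one key per chunk.
nchunks = len(list(flatten(a.__dask_keys__())))
print(nchunks)  # 4
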
github dask / dask / dask / optimization.py
>>> dsk # doctest: +SKIP
    {'c': (inc, (inc, 1))}
    >>> dsk, dependencies = fuse(d, keys=['b'], rename_keys=False)
    >>> dsk  # doctest: +SKIP
    {'b': (inc, 1), 'c': (inc, 'b')}

    Returns
    -------
    dsk: output graph with keys fused
    dependencies: dict mapping dependencies after fusion.  Useful side effect
        to accelerate other downstream optimizations.
    """
    if keys is not None and not isinstance(keys, set):
        if not isinstance(keys, list):
            keys = [keys]
        keys = set(flatten(keys))

    if dependencies is None:
        dependencies = {k: get_dependencies(dsk, k, as_list=True) for k in dsk}

    # locate all members of linear chains
    child2parent = {}
    unfusible = set()
    for parent in dsk:
        deps = dependencies[parent]
        has_many_children = len(deps) > 1
        for child in deps:
            if keys is not None and child in keys:
                unfusible.add(child)
            elif child in child2parent:
                del child2parent[child]
                unfusible.add(child)
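
The set(flatten(keys)) idiom at the top of fuse normalizes whatever the caller passed -- a single key, a flat list, or nested lists of keys -- into one flat set of keys to protect from fusion. For example:

from dask.core import flatten

for keys in ('b', ['b'], [['b', 'c'], ['d']]):
    if not isinstance(keys, set):
        if not isinstance(keys, list):
            keys = [keys]
        keys = set(flatten(keys))
    print(keys)
# {'b'}
# {'b'}
# {'b', 'c', 'd'}   (set ordering may vary)
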
github dask / distributed / distributed / client.py
See Also
        --------
        Client.compute
        """
        if isinstance(collections, (tuple, list, set, frozenset)):
            singleton = False
        else:
            singleton = True
            collections = [collections]

        assert all(map(dask.is_dask_collection, collections))

        dsk = self.collections_to_dsk(collections, optimize_graph, **kwargs)

        names = {k for c in collections for k in flatten(c.__dask_keys__())}

        restrictions, loose_restrictions = self.get_restrictions(collections,
                                                                 workers, allow_other_workers)

        if resources:
            resources = self._expand_resources(resources,
                                               all_keys=itertools.chain(dsk, names))

        if retries:
            retries = self._expand_retries(retries,
                                           all_keys=itertools.chain(dsk, names))
        else:
            retries = None

        if not isinstance(priority, Number):
            priority = {k: p for c, p in priority.items()
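
The names set comprehension above gathers every output key of every collection in one pass; these keys are later combined with the graph's keys when expanding per-key resources and retries. A standalone sketch, assuming dask.array:

import dask.array as da
from dask.core import flatten

x = da.ones((4, 4), chunks=(2, 2))
y = x + 1
collections = [x, y]

# One flat set with every chunk key of every collection.
names = {k for c in collections for k in flatten(c.__dask_keys__())}
print(len(names))  # 8 -- four chunks per array
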
github dask / distributed / distributed / executor.py
def get_restrictions(collections, workers, allow_other_workers):
    """ Get restrictions from inputs to compute/persist """
    if isinstance(workers, (str, tuple, list)):
        workers = {tuple(collections): workers}
    if isinstance(workers, dict):
        restrictions = {}
        for colls, ws in workers.items():
            if isinstance(ws, str):
                ws = [ws]
            if hasattr(colls, '_keys'):
                keys = flatten(colls._keys())
            else:
                keys = list({k for c in flatten(colls)
                                for k in flatten(c._keys())})
            restrictions.update({k: ws for k in keys})
    else:
        restrictions = {}

    if allow_other_workers is True:
        loose_restrictions = list(restrictions)
    elif allow_other_workers:
        loose_restrictions = list({k for c in flatten(allow_other_workers)
                                     for k in c._keys()})
    else:
        loose_restrictions = []

    return restrictions, loose_restrictions
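
This older executor-era helper (note the pre-collections-protocol _keys() API) maps every key of a collection to its allowed workers. A sketch of the resulting structure using the modern __dask_keys__() spelling; the worker address here is hypothetical:

import dask.array as da
from dask.core import flatten

x = da.ones((4, 4), chunks=(2, 2))

# Pin every chunk of x to one (made-up) worker address.
ws = ['tcp://10.0.0.1:12345']
restrictions = {k: ws for k in flatten(x.__dask_keys__())}
print(len(restrictions))  # 4 -- one entry per chunk
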
github dask / dask / dask / array / core.py
if isinstance(regions, tuple) or regions is None:
        regions = [regions]

    if len(sources) > 1 and len(regions) == 1:
        regions *= len(sources)

    if len(sources) != len(regions):
        raise ValueError("Different number of sources [%d] and targets [%d] than regions [%d]"
                         % (len(sources), len(targets), len(regions)))

    # Optimize all sources together
    sources_dsk = sharedict.merge(*[e.__dask_graph__() for e in sources])
    sources_dsk = Array.__dask_optimize__(
        sources_dsk,
        list(core.flatten([e.__dask_keys__() for e in sources]))
    )
    sources2 = [Array(sources_dsk, e.name, e.chunks, e.dtype) for e in sources]

    # Optimize all targets together
    targets2 = []
    targets_keys = []
    targets_dsk = []
    for e in targets:
        if isinstance(e, Delayed):
            targets2.append(e.key)
            targets_keys.extend(e.__dask_keys__())
            targets_dsk.append(e.__dask_graph__())
        elif is_dask_collection(e):
            raise TypeError(
                "Targets must be either Delayed objects or array-likes"
            )
github dask / dask / dask / array / core.py
try:
            if region is None:
                out[index] = np.asanyarray(x)
            else:
                out[fuse_slice(region, index)] = np.asanyarray(x)
        finally:
            if lock:
                lock.release()

        return None

    slices = slices_from_chunks(arr.chunks)

    name = 'store-%s' % arr.name
    dsk = dict(((name,) + t[1:], (store, out, t, slc, lock, region))
               for t, slc in zip(core.flatten(arr._keys()), slices))

    return dsk
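
The graph construction above relies on flatten and slices_from_chunks walking the chunk grid in the same row-major order, so zip pairs each chunk key with its target slice. A sketch of that pairing (slices_from_chunks is an internal helper; its import path may shift between versions):

import dask.array as da
from dask.array.slicing import slices_from_chunks
from dask.core import flatten

x = da.ones((4, 4), chunks=(2, 2))

# Both iterators enumerate chunks in row-major order, so the
# i-th key lines up with the i-th output slice.
pairs = list(zip(flatten(x.__dask_keys__()), slices_from_chunks(x.chunks)))
print(pairs[0])  # (('ones-...', 0, 0), (slice(0, 2, None), slice(0, 2, None)))
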
github dask / dask / dask / array / reshape.py
"Array chunk size or shape is unknown. shape: %s\n\n"
            "Possible solution with x.compute_chunk_sizes()" % x.shape
        )

    if reduce(mul, shape, 1) != x.size:
        raise ValueError("total size of new array must be unchanged")

    if x.shape == shape:
        return x

    meta = meta_from_array(x, len(shape))

    name = "reshape-" + tokenize(x, shape)

    if x.npartitions == 1:
        key = next(flatten(x.__dask_keys__()))
        dsk = {(name,) + (0,) * len(shape): (M.reshape, key, shape)}
        chunks = tuple((d,) for d in shape)
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
        return Array(graph, name, chunks, meta=meta)

    # Logic for how to rechunk
    inchunks, outchunks = reshape_rechunk(x.shape, shape, x.chunks)
    x2 = x.rechunk(inchunks)

    # Construct graph
    in_keys = list(product([x2.name], *[range(len(c)) for c in inchunks]))
    out_keys = list(product([name], *[range(len(c)) for c in outchunks]))
    shapes = list(product(*outchunks))
    dsk = {a: (M.reshape, b, shape) for a, b, shape in zip(out_keys, in_keys, shapes)}

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x2])
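
The single-partition fast path uses next(flatten(...)) to pull out the lone chunk key without materializing a list, however deeply __dask_keys__() happens to be nested. For instance:

import dask.array as da
from dask.core import flatten

x = da.ones((4, 4), chunks=(4, 4))  # one partition
assert x.npartitions == 1

key = next(flatten(x.__dask_keys__()))
print(key)  # ('ones-...', 0, 0)
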
github dask / dask / dask / array / optimization.py
def optimize(
    dsk,
    keys,
    fuse_keys=None,
    fast_functions=None,
    inline_functions_fast_functions=(getter_inline,),
    rename_fused_keys=True,
    **kwargs
):
    """ Optimize dask for array computation

    1.  Cull tasks not necessary to evaluate keys
    2.  Remove full slicing, e.g. x[:]
    3.  Inline fast functions like getitem and np.transpose
    """
    keys = list(flatten(keys))

    # High level stage optimization
    if isinstance(dsk, HighLevelGraph):
        dsk = optimize_blockwise(dsk, keys=keys)
        dsk = fuse_roots(dsk, keys=keys)

    # Low level task optimizations
    dsk = ensure_dict(dsk)
    if fast_functions is not None:
        inline_functions_fast_functions = fast_functions

    dsk2, dependencies = cull(dsk, keys)
    hold = hold_keys(dsk2, dependencies)

    dsk3, dependencies = fuse(
        dsk2,
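
optimize flattens keys first because callers hand it nested key lists (one sublist per output block), while downstream passes such as cull expect a flat list. A minimal reproduction of that normalization:

from dask.core import flatten
from dask.optimization import cull

dsk = {'a': 1, 'b': (sum, ['a', 'a']), 'c': (sum, ['b', 'b']), 'd': (sum, ['a'])}

# A nested key list, as dask collections typically produce.
keys = [['c']]
keys = list(flatten(keys))   # ['c']

dsk2, dependencies = cull(dsk, keys)
print(sorted(dsk2))          # ['a', 'b', 'c'] -- 'd' was culled
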
github dask / distributed / distributed / client.py
def get_restrictions(cls, collections, workers, allow_other_workers):
        """ Get restrictions from inputs to compute/persist """
        if isinstance(workers, (str, tuple, list)):
            workers = {tuple(collections): workers}
        if isinstance(workers, dict):
            restrictions = {}
            for colls, ws in workers.items():
                if isinstance(ws, str):
                    ws = [ws]
                if dask.is_dask_collection(colls):
                    keys = flatten(colls.__dask_keys__())
                else:
                    keys = list({k for c in flatten(colls)
                                 for k in flatten(c.__dask_keys__())})
                restrictions.update({k: ws for k in keys})
        else:
            restrictions = {}

        if allow_other_workers is True:
            loose_restrictions = list(restrictions)
        elif allow_other_workers:
            loose_restrictions = list({k for c in flatten(allow_other_workers)
                                       for k in c.__dask_keys__()})
        else:
            loose_restrictions = []

        return restrictions, loose_restrictions
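
When the workers mapping targets a group of collections rather than a single one, the else branch needs two levels of flatten: one over the group itself and one over each collection's keys. A sketch:

import dask.array as da
from dask.core import flatten

x = da.ones(4, chunks=2)
y = da.zeros(4, chunks=2)

# Outer flatten walks the group; inner flatten walks each
# collection's (possibly nested) key structure.
colls = (x, y)
keys = list({k for c in flatten(colls) for k in flatten(c.__dask_keys__())})
print(len(keys))  # 4 -- two chunks from each array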