How to use the heat.core.factories.array function in heat

To help you get started, we’ve selected a few examples from the heat library, based on popular ways `heat.core.factories.array` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

Source: helmholtz-analytics/heat — heat/core/linalg/basics.py (view on GitHub)
del b_lp_data[pr - 1]

            if pr == b.comm.size - 1:
                req[pr].wait()
                st0 = index_map[pr, 0, 0, 0].item()
                sp0 = index_map[pr, 0, 0, 1].item() + 1
                st1 = index_map[pr, 1, 1, 0].item()
                sp1 = index_map[pr, 1, 1, 1].item()
                c._DNDarray__array[: sp0 - st0, st1:sp1] += a._DNDarray__array @ b_lp_data[pr]

                del b_lp_data[pr]

        c = (
            c
            if not vector_flag
            else factories.array(c._DNDarray__array.squeeze(), is_split=0, device=a.device)
        )
        return c

    elif split_10_flag:
        # for this case, only a sum is needed at the end
        a_rem_locs1 = (rem_map[:, 0, 1] == 1).nonzero()
        # locations of the remainders in b
        b_rem_locs0 = (rem_map[:, 1, 0] == 1).nonzero()
        res = torch.zeros(
            (a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=c.device.torch_device
        )
        for i in range(a.lshape[-1] // kB):
            res += (
                a._DNDarray__array[:mB, i * kB : i * kB + kB]
                @ b._DNDarray__array[i * kB : i * kB + kB, :nB]
            )
Source: helmholtz-analytics/heat — heat/core/linalg/basics.py (view on GitHub)
c._DNDarray__array[r_loc.item(), st1:sp1] += r[st:sp] @ b_lp_data[pr]
                        else:
                            c._DNDarray__array[r_loc.item(), :] += r[st:sp] @ b_lp_data[pr]

                # set the final blocks on the last loop, then adjust for the
                # the remainders which were collected in b_rem
                if b_rem_locs0.numel():
                    c._DNDarray__array[: a_node_rem_s0.shape[0]] += a_node_rem_s0 @ b_rem

                del b_lp_data[pr]

        if vector_flag:
            c_loc = c._DNDarray__array.squeeze()
            if c_loc.nelement() == 1:
                c = torch.tensor(c_loc, device=c._DNDarray__array.device)
            c = factories.array(c_loc, is_split=0, device=a.device)

        return c

    elif split_1_flag:
        # for this case, a is sent to b
        #   this is because 'b' has complete columns and the rows of 'a' are split
        # locations of the remainders in b
        b_rem_locs1 = (rem_map[:, 1, 1] == 1).nonzero()
        a_rem_locs1 = (rem_map[:, 0, 1] == 1).nonzero()
        b_node_rem_s1 = b._DNDarray__array[
            kB : (kB + 1) * a_rem_locs1.numel() : kB + 1, :nB
        ]  # remainders for a in the
        a_rem = torch.empty(
            a.lshape[-2],
            a_rem_locs1.numel(),
            dtype=b.dtype.torch_type(),
Source: helmholtz-analytics/heat — heat/core/logical.py (view on GitHub)
def sanitize_input_type(x, y):
        """
        Coerce ``x`` into a DNDarray, using ``y`` only as a source for the device.

        A DNDarray is passed through untouched. A zero-dimensional value (as
        reported by ``np.ndim``) is wrapped with ``factories.array``. Its dtype
        comes from ``x.dtype`` when that attribute exists and is ``float``
        otherwise. Its device comes from ``y.device`` when that attribute exists
        and is ``None`` otherwise.

        Parameters
        ----------
        x : object
            Value to validate and, if it is a scalar, wrap.
        y : object
            Companion operand; only its ``device`` attribute (if any) is read.

        Returns
        -------
        DNDarray
            ``x`` itself, or a new DNDarray wrapping the scalar ``x``.

        Raises
        ------
        TypeError
            If ``x`` is not a DNDarray and has one or more dimensions.
        """
        # Fast path: already the right type, nothing to do.
        if isinstance(x, dndarray.DNDarray):
            return x

        # Reject anything array-like with at least one dimension.
        # NOTE(review): np.ndim also reports 0 for objects such as None or str,
        # so those are not caught here and reach factories.array instead.
        if np.ndim(x) != 0:
            raise TypeError("Expected DNDarray or numeric scalar, input was {}".format(type(x)))

        # Scalar: wrap it, preferring its own dtype and the partner's device.
        return factories.array(
            x,
            dtype=getattr(x, "dtype", float),
            device=getattr(y, "device", None),
        )
Source: helmholtz-analytics/heat — heat/core/statistics.py (view on GitHub)
if unbiased:
        kwargs["unbiased"] = unbiased
    if Fischer:
        kwargs["Fischer"] = Fischer

    output_shape = list(x.shape)
    if isinstance(axis, int):
        if axis >= len(x.shape):
            raise ValueError("axis must be < {}, currently is {}".format(len(x.shape), axis))
        axis = stride_tricks.sanitize_axis(x.shape, axis)
        # only one axis given
        output_shape = [output_shape[it] for it in range(len(output_shape)) if it != axis]
        output_shape = output_shape if output_shape else (1,)

        if x.split is None:  # x is *not* distributed -> no need to distributed
            return factories.array(
                function(x._DNDarray__array, **kwargs), dtype=x.dtype, device=x.device
            )
        elif axis == x.split:  # x is distributed and axis chosen is == to split
            return elementwise_function(output_shape)
        # singular axis given (axis) not equal to split direction (x.split)
        lcl = function(x._DNDarray__array, **kwargs)
        return factories.array(
            lcl, is_split=x.split if axis > x.split else x.split - 1, dtype=x.dtype, device=x.device
        )
    elif not isinstance(axis, (list, tuple, torch.Tensor)):
        raise TypeError(
            f"axis must be an int, tuple, list, or torch.Tensor; currently it is {type(axis)}."
        )
    # else:
    if isinstance(axis, torch.Tensor):
        axis = axis.tolist()
Source: helmholtz-analytics/heat — heat/core/linalg/qr.py (view on GitHub)
if not isinstance(overwrite_a, bool):
        raise TypeError("overwrite_a must be a bool, currently {}".format(type(overwrite_a)))
    if isinstance(tiles_per_proc, torch.Tensor):
        raise ValueError(
            "tiles_per_proc must be a single element torch.Tenor or int, "
            "currently has {} entries".format(tiles_per_proc.numel())
        )
    if len(a.shape) != 2:
        raise ValueError("Array 'a' must be 2 dimensional")

    QR = collections.namedtuple("QR", "Q, R")

    if a.split is None:
        q, r = a._DNDarray__array.qr(some=False)
        q = factories.array(q, device=a.device)
        r = factories.array(r, device=a.device)
        ret = QR(q if calc_q else None, r)
        return ret
    # =============================== Prep work ====================================================
    r = a if overwrite_a else a.copy()
    # r.create_square_diag_tiles(tiles_per_proc=tiles_per_proc)
    r_tiles = tiling.SquareDiagTiles(arr=r, tiles_per_proc=tiles_per_proc)
    tile_columns = r_tiles.tile_columns
    tile_rows = r_tiles.tile_rows
    if calc_q:
        q = factories.eye(
            (r.gshape[0], r.gshape[0]), split=0, dtype=r.dtype, comm=r.comm, device=r.device
        )
        q_tiles = tiling.SquareDiagTiles(arr=q, tiles_per_proc=tiles_per_proc)
        q_tiles.match_tiles(r_tiles)
    else:
        q, q_tiles = None, None
Source: helmholtz-analytics/heat — heat/core/manipulations.py (view on GitHub)
>>> ht.flip(a, [0,1])
    (1/2) tensor([5,4,3])
    (2/2) tensor([2,1,0])
    """
    # flip all dimensions
    if axis is None:
        axis = tuple(range(a.ndim))

    # torch.flip only accepts tuples
    if isinstance(axis, int):
        axis = [axis]

    flipped = torch.flip(a._DNDarray__array, axis)

    if a.split not in axis:
        return factories.array(
            flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
        )

    # Need to redistribute tensors on split axis
    # Get local shapes
    old_lshape = a.lshape
    dest_proc = a.comm.size - 1 - a.comm.rank
    new_lshape = a.comm.sendrecv(old_lshape, dest=dest_proc, source=dest_proc)

    # Exchange local tensors
    req = a.comm.Isend(flipped, dest=dest_proc)
    received = torch.empty(new_lshape, dtype=a._DNDarray__array.dtype, device=a.device.torch_device)
    a.comm.Recv(received, source=dest_proc)

    res = factories.array(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)
    res.balance_()  # after swapping, first processes may be empty
Source: helmholtz-analytics/heat — heat/core/statistics.py (view on GitHub)
data.get_halo(1)
        t_data = data.array_with_halos
        # fill out percentile
        t_ind_on_rank -= offset
        t_map_sum = t_indices_map.sum(axis=1)
        perc_ranks = torch.where(t_map_sum > -1 * nperc)[0].tolist()
        for r_id, r in enumerate(perc_ranks):
            # chunk of the global percentile that will be populated by rank r
            _, _, perc_chunk = x.comm.chunk(output_shape, join, rank=r_id, w_size=len(perc_ranks))
            perc_slice = perc_slice[:join] + (perc_chunk[join],) + perc_slice[join + 1 :]
            local_p = factories.zeros(percentile[perc_slice].shape, dtype=perc_dtype, comm=x.comm)
            if rank == r:
                if rank > 0:
                    # correct indices for halo
                    t_ind_on_rank += 1
                local_p = factories.array(_local_percentile(t_data, axis, t_ind_on_rank))
            x.comm.Bcast(local_p, root=r)
            percentile[perc_slice] = local_p
    else:
        if x.comm.is_distributed() and split is not None:
            # split != axis, calculate percentiles locally, then gather
            percentile = factories.empty(
                output_shape, dtype=perc_dtype, split=join, device=x.device
            )
            percentile._DNDarray__array = _local_percentile(t_data, axis, t_indices)
            percentile.resplit_(axis=None)
        else:
            # non-distributed case
            percentile = factories.array(_local_percentile(t_data, axis, t_indices))

    if percentile.shape[0] == 1:
        percentile = manipulations.squeeze(percentile, axis=0)
Source: helmholtz-analytics/heat — heat/core/manipulations.py (view on GitHub)
# different communicators may not be concatenated
    if arr0.comm != arr1.comm:
        raise RuntimeError("Communicators of passed arrays mismatch.")

    # identify common data type
    out_dtype = types.promote_types(arr0.dtype, arr1.dtype)
    if arr0.dtype != out_dtype:
        arr0 = out_dtype(arr0, device=arr0.device)
    if arr1.dtype != out_dtype:
        arr1 = out_dtype(arr1, device=arr1.device)

    s0, s1 = arr0.split, arr1.split
    # no splits, local concat
    if s0 is None and s1 is None:
        return factories.array(
            torch.cat((arr0._DNDarray__array, arr1._DNDarray__array), dim=axis),
            device=arr0.device,
            comm=arr0.comm,
        )

    # non-matching splits when both arrays are split
    elif s0 != s1 and all([s is not None for s in [s0, s1]]):
        raise RuntimeError(
            "DNDarrays given have differing split axes, arr0 {} arr1 {}".format(s0, s1)
        )

    # unsplit and split array
    elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis):
        out_shape = tuple(
            arr1.gshape[x] if x != axis else arr0.gshape[x] + arr1.gshape[x]
            for x in range(len(arr1.gshape))
Source: helmholtz-analytics/heat — heat/core/statistics.py (view on GitHub)
if out.dtype is not perc_dtype:
            raise TypeError(
                "Wrong datatype for out: expected {}, got {}".format(perc_dtype, out.dtype)
            )
        if out.gshape != output_shape:
            raise ValueError("out must have shape {}, got {}".format(output_shape, out.gshape))
        if out.split is not None:
            raise ValueError(
                "Split dimension mismatch for out: expected {}, got {}".format(None, out.split)
            )
    # END OF SANITATION

    # edge-case: x is a scalar. Return x
    if x.ndim == 0:
        percentile = t_x * torch.ones(nperc, dtype=t_perc_dtype, device=t_x.device)
        return factories.array(
            percentile, split=None, dtype=perc_dtype, device=x.device, comm=x.comm
        )

    # compute indices
    length = gshape[axis]
    t_indices = t_q / 100 * (length - 1)
    if interpolation == "linear":
        # leave fractional indices, interpolate linearly
        pass
    elif interpolation == "lower":
        t_indices = t_indices.floor().type(torch.int)
    elif interpolation == "higher":
        t_indices = t_indices.ceil().type(torch.int)
    elif interpolation == "midpoint":
        t_indices = 0.5 * (t_indices.floor() + t_indices.ceil())
    elif interpolation == "nearest":