del b_lp_data[pr - 1]
if pr == b.comm.size - 1:
req[pr].wait()
st0 = index_map[pr, 0, 0, 0].item()
sp0 = index_map[pr, 0, 0, 1].item() + 1
st1 = index_map[pr, 1, 1, 0].item()
sp1 = index_map[pr, 1, 1, 1].item()
c._DNDarray__array[: sp0 - st0, st1:sp1] += a._DNDarray__array @ b_lp_data[pr]
del b_lp_data[pr]
c = (
c
if not vector_flag
else factories.array(c._DNDarray__array.squeeze(), is_split=0, device=a.device)
)
return c
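# Usage sketch (illustrative only, not library source; assumes the public
# heat API): matmul with a 1-D operand takes the vector_flag path above,
# which squeezes the padding dimension back out of the result.
#
#     import heat as ht
#
#     a = ht.ones((4, 3), split=0)  # rows distributed across processes
#     b = ht.ones((3,))             # 1-D operand -> vector_flag is set
#     c = ht.matmul(a, b)           # DNDarray of shape (4,), every entry 3.0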
elif split_10_flag:
# for this case, only a sum is needed at the end
# locations of the remainders in a
a_rem_locs1 = (rem_map[:, 0, 1] == 1).nonzero()
# locations of the remainders in b
b_rem_locs0 = (rem_map[:, 1, 0] == 1).nonzero()
res = torch.zeros(
(a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=c.device.torch_device
)
for i in range(a.lshape[-1] // kB):
res += (
a._DNDarray__array[:mB, i * kB : i * kB + kB]
@ b._DNDarray__array[i * kB : i * kB + kB, :nB]
)
c._DNDarray__array[r_loc.item(), st1:sp1] += r[st:sp] @ b_lp_data[pr]
else:
c._DNDarray__array[r_loc.item(), :] += r[st:sp] @ b_lp_data[pr]
# set the final blocks on the last loop, then adjust for the
# remainders which were collected in b_rem
if b_rem_locs0.numel():
c._DNDarray__array[: a_node_rem_s0.shape[0]] += a_node_rem_s0 @ b_rem
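# Illustration (standalone sketch, not library source): the loop over full
# kB-sized blocks plus the single remainder update above together reproduce
# an ordinary matrix product. The same idea in plain torch:
#
#     import torch
#
#     a, b, kB = torch.randn(4, 7), torch.randn(7, 5), 2
#     c = torch.zeros(4, 5)
#     nb = a.shape[1] // kB  # 3 full blocks, 1 leftover column/row
#     for i in range(nb):
#         c += a[:, i * kB : (i + 1) * kB] @ b[i * kB : (i + 1) * kB, :]
#     c += a[:, nb * kB :] @ b[nb * kB :, :]  # remainder correction
#     assert torch.allclose(c, a @ b, atol=1e-5)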
del b_lp_data[pr]
if vector_flag:
c_loc = c._DNDarray__array.squeeze()
if c_loc.nelement() == 1:
c_loc = torch.tensor(c_loc, device=c._DNDarray__array.device)
c = factories.array(c_loc, is_split=0, device=a.device)
return c
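# Illustration (not library source): vector operands are promoted to 2-D for
# the blocked algorithm and squeezed back afterwards, as above:
#
#     import torch
#
#     v = torch.ones(3)
#     m = v.unsqueeze(1)  # shape (3, 1), suitable for the 2-D block code
#     m.squeeze().shape   # torch.Size([3]) -- padding dimension removed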
elif split_1_flag:
# for this case, a is sent to b
# this is because 'b' has complete columns and the rows of 'a' are split
# locations of the remainders in b
b_rem_locs1 = (rem_map[:, 1, 1] == 1).nonzero()
a_rem_locs1 = (rem_map[:, 0, 1] == 1).nonzero()
b_node_rem_s1 = b._DNDarray__array[
kB : (kB + 1) * a_rem_locs1.numel() : kB + 1, :nB
]  # rows of b which correspond to the remainder columns of a
a_rem = torch.empty(
a.lshape[-2],
a_rem_locs1.numel(),
dtype=b.dtype.torch_type(),
)

def sanitize_input_type(x, y):
"""
Verifies that x is either a scalar or a ht.DNDarray. If x is a scalar, it gets wrapped in a ht.DNDarray.
Raises TypeError if x is neither.
"""
if not isinstance(x, dndarray.DNDarray):
if np.ndim(x) == 0:
dtype = getattr(x, "dtype", float)
device = getattr(y, "device", None)
x = factories.array(x, dtype=dtype, device=device)
else:
raise TypeError("Expected DNDarray or numeric scalar, input was {}".format(type(x)))
return x
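# Usage sketch (illustrative): scalars are wrapped so that downstream code
# can always rely on receiving a DNDarray; non-scalar, non-DNDarray input
# raises.
#
#     import heat as ht
#
#     y = ht.zeros((3,))
#     sanitize_input_type(2.0, y)     # -> DNDarray wrapping the scalar 2.0
#     sanitize_input_type([1, 2], y)  # -> TypeError (np.ndim([1, 2]) != 0)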
if unbiased:
kwargs["unbiased"] = unbiased
if Fischer:
kwargs["Fischer"] = Fischer
output_shape = list(x.shape)
if isinstance(axis, int):
if axis >= len(x.shape):
raise ValueError("axis must be < {}, currently is {}".format(len(x.shape), axis))
axis = stride_tricks.sanitize_axis(x.shape, axis)
# only one axis given
output_shape = [output_shape[it] for it in range(len(output_shape)) if it != axis]
output_shape = output_shape if output_shape else (1,)
if x.split is None:  # x is *not* distributed -> no communication needed
return factories.array(
function(x._DNDarray__array, **kwargs), dtype=x.dtype, device=x.device
)
elif axis == x.split:  # x is distributed and the chosen axis equals the split axis
return elementwise_function(output_shape)
# singular axis given (axis) not equal to split direction (x.split)
lcl = function(x._DNDarray__array, **kwargs)
return factories.array(
lcl, is_split=x.split if axis > x.split else x.split - 1, dtype=x.dtype, device=x.device
)
elif not isinstance(axis, (list, tuple, torch.Tensor)):
raise TypeError(
f"axis must be an int, tuple, list, or torch.Tensor; currently it is {type(axis)}."
)
# multiple axes given as a list, tuple, or torch.Tensor
if isinstance(axis, torch.Tensor):
axis = axis.tolist()
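# Worked example (illustrative): when the reduction axis differs from the
# split axis, the result stays distributed and the split index shifts down
# if an earlier dimension was removed, matching the expression above:
#
#     x_split, axis = 2, 0  # e.g. x.shape == (4, 5, 6), split along dim 2
#     new_split = x_split if axis > x_split else x_split - 1
#     assert new_split == 1  # axis 0 removed -> split moves from 2 to 1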
if not isinstance(overwrite_a, bool):
raise TypeError("overwrite_a must be a bool, currently {}".format(type(overwrite_a)))
if isinstance(tiles_per_proc, torch.Tensor) and tiles_per_proc.numel() != 1:
raise ValueError(
"tiles_per_proc must be a single element torch.Tensor or int, "
"currently has {} entries".format(tiles_per_proc.numel())
)
if len(a.shape) != 2:
raise ValueError("Array 'a' must be 2 dimensional")
QR = collections.namedtuple("QR", "Q, R")
if a.split is None:
q, r = a._DNDarray__array.qr(some=False)
q = factories.array(q, device=a.device)
r = factories.array(r, device=a.device)
ret = QR(q if calc_q else None, r)
return ret
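# Usage sketch (illustrative, assumes the public heat API): for an array
# that is not split, the factorization above is a single local torch QR.
#
#     import heat as ht
#
#     a = ht.random.randn(8, 8)         # split is None
#     q, r = ht.linalg.qr(a)            # namedtuple QR(Q, R)
#     ht.allclose(q @ r, a, atol=1e-5)  # ~True up to floating point error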
# =============================== Prep work ====================================================
r = a if overwrite_a else a.copy()
r_tiles = tiling.SquareDiagTiles(arr=r, tiles_per_proc=tiles_per_proc)
tile_columns = r_tiles.tile_columns
tile_rows = r_tiles.tile_rows
if calc_q:
q = factories.eye(
(r.gshape[0], r.gshape[0]), split=0, dtype=r.dtype, comm=r.comm, device=r.device
)
q_tiles = tiling.SquareDiagTiles(arr=q, tiles_per_proc=tiles_per_proc)
q_tiles.match_tiles(r_tiles)
else:
q, q_tiles = None, None
>>> ht.flip(a, [0,1])
(1/2) tensor([5,4,3])
(2/2) tensor([2,1,0])
"""
# flip all dimensions
if axis is None:
axis = tuple(range(a.ndim))
# torch.flip requires a sequence of dims; wrap a single int
if isinstance(axis, int):
axis = [axis]
flipped = torch.flip(a._DNDarray__array, axis)
if a.split not in axis:
return factories.array(
flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
)
# Need to redistribute tensors on split axis
# Get local shapes
old_lshape = a.lshape
dest_proc = a.comm.size - 1 - a.comm.rank
new_lshape = a.comm.sendrecv(old_lshape, dest=dest_proc, source=dest_proc)
# Exchange local tensors
req = a.comm.Isend(flipped, dest=dest_proc)
received = torch.empty(new_lshape, dtype=a._DNDarray__array.dtype, device=a.device.torch_device)
a.comm.Recv(received, source=dest_proc)
res = factories.array(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)
res.balance_()  # after swapping, first processes may be empty
req.Wait()  # ensure the non-blocking send has completed
return res
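# Illustration (not library source): the exchange above pairs each process
# with its mirror rank, so the locally flipped chunks land in reversed
# global order:
#
#     size = 4
#     [(rank, size - 1 - rank) for rank in range(size)]
#     # -> [(0, 3), (1, 2), (2, 1), (3, 0)]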
data.get_halo(1)
t_data = data.array_with_halos
# fill out percentile
t_ind_on_rank -= offset
t_map_sum = t_indices_map.sum(axis=1)
perc_ranks = torch.where(t_map_sum > -1 * nperc)[0].tolist()
for r_id, r in enumerate(perc_ranks):
# chunk of the global percentile that will be populated by rank r
_, _, perc_chunk = x.comm.chunk(output_shape, join, rank=r_id, w_size=len(perc_ranks))
perc_slice = perc_slice[:join] + (perc_chunk[join],) + perc_slice[join + 1 :]
local_p = factories.zeros(percentile[perc_slice].shape, dtype=perc_dtype, comm=x.comm)
if rank == r:
if rank > 0:
# correct indices for halo
t_ind_on_rank += 1
local_p = factories.array(_local_percentile(t_data, axis, t_ind_on_rank))
x.comm.Bcast(local_p, root=r)
percentile[perc_slice] = local_p
else:
if x.comm.is_distributed() and split is not None:
# split != axis, calculate percentiles locally, then gather
percentile = factories.empty(
output_shape, dtype=perc_dtype, split=join, device=x.device
)
percentile._DNDarray__array = _local_percentile(t_data, axis, t_indices)
percentile.resplit_(axis=None)
else:
# non-distributed case
percentile = factories.array(_local_percentile(t_data, axis, t_indices))
if percentile.shape[0] == 1:
percentile = manipulations.squeeze(percentile, axis=0)
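# Usage sketch (illustrative, assumes the public heat API): when the split
# axis survives the reduction, each process computes its percentiles locally
# and resplit_(axis=None) gathers the pieces, as in the branch above.
#
#     import heat as ht
#
#     x = ht.random.randn(4, 5, split=0)
#     p = ht.percentile(x, 50, axis=1)  # median along axis 1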
# different communicators may not be concatenated
if arr0.comm != arr1.comm:
raise RuntimeError("Communicators of passed arrays mismatch.")
# identify common data type
out_dtype = types.promote_types(arr0.dtype, arr1.dtype)
if arr0.dtype != out_dtype:
arr0 = out_dtype(arr0, device=arr0.device)
if arr1.dtype != out_dtype:
arr1 = out_dtype(arr1, device=arr1.device)
s0, s1 = arr0.split, arr1.split
# no splits, local concat
if s0 is None and s1 is None:
return factories.array(
torch.cat((arr0._DNDarray__array, arr1._DNDarray__array), dim=axis),
device=arr0.device,
comm=arr0.comm,
)
# non-matching splits when both arrays are split
elif s0 != s1 and all([s is not None for s in [s0, s1]]):
raise RuntimeError(
"DNDarrays given have differing split axes, arr0 {} arr1 {}".format(s0, s1)
)
# unsplit and split array
elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis):
out_shape = tuple(
arr1.gshape[x] if x != axis else arr0.gshape[x] + arr1.gshape[x]
for x in range(len(arr1.gshape))
)
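# Usage sketch (illustrative, assumes the public heat API):
#
#     import heat as ht
#
#     a = ht.zeros((2, 3), split=0)
#     b = ht.ones((2, 3), split=0)
#     ht.concatenate((a, b), axis=0).shape  # -> (4, 3)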
if out.dtype is not perc_dtype:
raise TypeError(
"Wrong datatype for out: expected {}, got {}".format(perc_dtype, out.dtype)
)
if out.gshape != output_shape:
raise ValueError("out must have shape {}, got {}".format(output_shape, out.gshape))
if out.split is not None:
raise ValueError(
"Split dimension mismatch for out: expected {}, got {}".format(None, out.split)
)
# END OF SANITATION
# edge case: x is a scalar -> broadcast it across the requested percentiles
if x.ndim == 0:
percentile = t_x * torch.ones(nperc, dtype=t_perc_dtype, device=t_x.device)
return factories.array(
percentile, split=None, dtype=perc_dtype, device=x.device, comm=x.comm
)
# compute indices
length = gshape[axis]
t_indices = t_q / 100 * (length - 1)
if interpolation == "linear":
# leave fractional indices, interpolate linearly
pass
elif interpolation == "lower":
t_indices = t_indices.floor().type(torch.int)
elif interpolation == "higher":
t_indices = t_indices.ceil().type(torch.int)
elif interpolation == "midpoint":
t_indices = 0.5 * (t_indices.floor() + t_indices.ceil())
elif interpolation == "nearest":