Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}
parse_shape output can be used to specify axes_lengths for other operations
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)
For symbolic frameworks may return symbols, not integers.
:param x: tensor of any of supported frameworks
:param pattern: str, space separated names for axes, underscore means skip axis
:return: dict, maps axes names to their lengths
"""
names = [elementary_axis for elementary_axis in pattern.split(' ') if len(elementary_axis) > 0]
shape = get_backend(x).shape(x)
if len(shape) != len(names):
raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format(
pattern=pattern, shape=shape))
result = {}
for axis_name, axis_length in zip(names, shape):
if axis_name != '_':
result[axis_name] = axis_length
return result
def apply(self, tensor):
    """Execute this recipe on ``tensor`` and return the transformed tensor.

    The backend is picked from the tensor itself, and the concrete shapes
    and axis lists are obtained from ``self.reconstruct_from_shape`` using
    the tensor's runtime shape. The pipeline is: reshape -> reduce ->
    transpose -> reshape.
    """
    bk = get_backend(tensor)
    plan = self.reconstruct_from_shape(bk.shape(tensor))
    init_shapes, reduced_axes, axes_reordering, final_shapes = plan

    # Step 1: reshape the input to the recipe's initial shapes.
    reshaped = bk.reshape(tensor, init_shapes)
    # Step 2: apply the recipe's reduction over the listed axes.
    reduced = _reduce_axes(reshaped, reduction_type=self.reduction_type,
                           reduced_axes=reduced_axes, backend=bk)
    # Step 3: permute the remaining axes into the requested order.
    permuted = bk.transpose(reduced, axes_reordering)
    # Step 4: reshape to the final output shapes.
    return bk.reshape(permuted, final_shapes)
def _enumerate_directions(x):
    """
    For an n-dimensional tensor, returns tensors to enumerate each axis.
    >>> x = np.zeros([2, 3, 4]) # or any other tensor
    >>> i, j, k = _enumerate_directions(x)
    >>> result = i + 2 * j + 3 * k

    result[i, j, k] = i + 2 * j + 3 * k, and result has the same shape as x,
    because the returned tensors broadcast against each other to x's shape.
    Works very similarly to numpy.ogrid (open indexing grid)

    :param x: tensor of any supported framework; only its shape is used
    :return: list with one tensor per axis of x. The tensor for axis ``a``
        holds ``arange(0, x.shape[a])``, reshaped so that every other axis
        has length 1.
    """
    backend = get_backend(x)
    x_shape = backend.shape(x)
    ndim = len(x_shape)
    result = []
    for axis_id, axis_length in enumerate(x_shape):
        # Every axis is a singleton except the enumerated one, so the outputs
        # broadcast against each other (open grid, as in numpy.ogrid).
        # A separate name keeps x_shape intact instead of shadowing it mid-loop.
        axis_shape = [1] * ndim
        axis_shape[axis_id] = axis_length
        result.append(backend.reshape(backend.arange(0, axis_length), axis_shape))
    return result
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
(32, 15, 20, 12)
:param tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray).
list of tensors is also accepted, those should be of the same type and shape
:param pattern: string, rearrangement pattern
:param axes_lengths: any additional specifications for dimensions
:return: tensor of the same type as input. If possible, a view to the original tensor is returned.
When composing axes, C-order enumeration used (consecutive elements have different last axis)
More examples and explanations can be found in the einops guide.
"""
if isinstance(tensor, list):
if len(tensor) == 0:
raise TypeError("Rearrange can't be applied to an empty list")
tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor)
return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)
def asnumpy(tensor):
    """
    Turn a tensor of an imperative framework (numpy/cupy/torch/gluon/etc.) into a numpy.ndarray.

    The framework is detected from ``tensor`` itself, and the conversion is
    delegated to that backend's ``to_numpy``.

    :param tensor: tensor of any of known imperative framework
    :return: numpy.ndarray, converted to numpy
    """
    backend = get_backend(tensor)
    converted = backend.to_numpy(tensor)
    return converted
:param tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray).
list of tensors is also accepted, those should be of the same type and shape
:param pattern: string, reduction pattern
:param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
:param axes_lengths: any additional specifications for dimensions
:return: tensor of the same type as input
"""
try:
hashable_axes_lengths = tuple(sorted(axes_lengths.items()))
recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths)
return recipe.apply(tensor)
except EinopsError as e:
message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
if not isinstance(tensor, list):
message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor))
else:
message += '\n Input is list. '
message += 'Additional info: {}.'.format(axes_lengths)
raise EinopsError(message + '\n {}'.format(e))