Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def compute_output_shape(self, input_shape):
    """Compute the static output shape produced for ``input_shape``.

    ``None`` entries of ``input_shape`` (unknown sizes) are wrapped as
    ``UnknownSize()`` placeholders; every other entry is converted with
    ``int``. The result is fed to ``self.recipe().reconstruct_from_shape``.
    Any placeholder left in the reconstructed final shape is mapped back to
    ``None``.

    :param input_shape: iterable of dimension sizes, ``None`` for unknown.
    :return: tuple of ints, with ``None`` where the output size is unknown.
    """
    symbolic_shape = tuple(UnknownSize() if d is None else int(d) for d in input_shape)
    # Only the final (output) shape is needed; the other three reconstructed
    # values (initial shapes, reduced axes, axes reordering) are discarded.
    _, _, _, final_shape = self.recipe().reconstruct_from_shape(symbolic_shape)
    return tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape)
def compute_output_shape(self, input_shape):
    """Compute the static output shape produced for ``input_shape``.

    ``None`` entries of ``input_shape`` (unknown sizes) are wrapped as
    ``UnknownSize()`` placeholders; every other entry is converted with
    ``int``. The result is fed to ``self.recipe().reconstruct_from_shape``.
    Any placeholder left in the reconstructed final shape is mapped back to
    ``None``.

    :param input_shape: iterable of dimension sizes, ``None`` for unknown.
    :return: tuple of ints, with ``None`` where the output size is unknown.
    """
    symbolic_shape = tuple(UnknownSize() if d is None else int(d) for d in input_shape)
    # Only the final (output) shape is needed; the other three reconstructed
    # values (initial shapes, reduced axes, axes reordering) are discarded.
    _, _, _, final_shape = self.recipe().reconstruct_from_shape(symbolic_shape)
    return tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape)
def shape(self, x):
    """Return the statically inferred shape of the mxnet symbol ``x``.

    Why it is done this way (mxnet limitations):
    - mxnet does not provide shape symbols, and ``shape_array`` seems to be
      impossible to use in shape inference, so ``infer_shape_partial`` is used;
    - ``infer_shape_partial`` returns an empty tuple if it was not able to
      infer the shape;
    - reductions such as ``sum`` can't return scalars and return 1-element
      vectors instead.

    An empty inferred shape only triggers a warning and ``()`` is returned.
    Every dimension reported as 0 is replaced by an ``UnknownSize()``
    placeholder; other dimensions are returned as-is.
    """
    inference_result = x.infer_shape_partial()
    # NOTE(review): index 1 is presumably mxnet's list of output shapes
    # (arg/out/aux ordering); index 0 selects the first output's shape.
    dims = inference_result[1][0]
    if len(dims) == 0:
        warnings.warn('mxnet inferred shape to be (), which probably means it could not be inferred')
    resolved = []
    for dim in dims:
        if dim == 0:
            # A zero size is treated as "could not be inferred".
            resolved.append(UnknownSize())
        else:
            resolved.append(dim)
    return tuple(resolved)
def reshape(self, x, shape):
    """Reshape the mxnet symbol ``x`` to ``shape``.

    An empty ``shape`` returns ``x`` unchanged, because mxnet has poor
    support for scalars. Otherwise every dimension must be statically known.
    If any entry of ``shape`` is an ``UnknownSize`` placeholder, an
    ``EinopsError`` is raised asking the caller to provide the sizes via
    ``axes_lengths``.

    :param x: symbol to reshape.
    :param shape: target shape, a sequence of dimension sizes.
    :return: ``x.reshape(shape)``, or ``x`` itself when ``shape`` is empty.
    :raises EinopsError: if ``shape`` contains an ``UnknownSize`` entry.
    """
    if len(shape) == 0:
        return x  # poor support of scalars in mxnet
    if any(isinstance(dimension, UnknownSize) for dimension in shape):
        # Imported lazily, presumably to avoid a circular import with .einops.
        from .einops import EinopsError
        raise EinopsError("Mxnet couldn't infer all dimensions statically, please provide those with axes_lengths")
    return x.reshape(shape)