Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Compute the output shape of a convolution-like layer (given kernel/stride/dilation/padding)
# applied to `inputs`. Dims 0 and 1 (presumably batch and channels) are copied from the input
# shape; every dim from index 2 onward is treated as spatial and recomputed.
# Args:
#   inputs: anything `get_shape` accepts.
#   kernel_size, stride, dilation, padding: an int used for every spatial dim, or a tuple with
#     one entry per spatial dim. A padding entry may itself be a tuple, in which case its sum is
#     the total padding for that dim; an int padding entry is counted twice (once per side).
#   transposed: use the transposed-convolution size formula instead of the regular one.
#   **kwargs: only 'out_channels' is read; when truthy it replaces dim 1 of the result.
# NOTE(review): `stride` defaults to None, which fails in the size arithmetic below; callers
# presumably always pass it -- confirm.
# NOTE(review): this fragment ends without a `return`; presumably `output_shape` is returned
# past this view -- confirm against the full file.
def _calc_output_shape(inputs, kernel_size=None, stride=None, dilation=1, padding=0, transposed=False, **kwargs):
shape = get_shape(inputs)
output_shape = list(shape)
for i in range(2, len(shape)):
# A falsy spatial size (None or 0) is reported as None (unknown) in the output.
if shape[i]:
k = kernel_size[i - 2] if isinstance(kernel_size, tuple) else kernel_size
p = padding[i - 2] if isinstance(padding, tuple) else padding
# Total padding for this dim: the sum of a per-dim tuple, or twice a scalar.
p = sum(p) if isinstance(p, tuple) else p * 2
s = stride[i - 2] if isinstance(stride, tuple) else stride
d = dilation[i - 2] if isinstance(dilation, tuple) else dilation
if transposed:
# (L - 1) * s + k - total_padding; dilation is not used in this branch.
# NOTE(review): PyTorch's ConvTranspose size formula also involves dilation * (k - 1) and
# output_padding; this matches it only for dilation == 1 and output_padding == 0.
output_shape[i] = (shape[i] - 1) * s + k - p
else:
# floor((L + total_padding - d * (k - 1) - 1) / s) + 1 -- the standard convolution size formula.
output_shape[i] = (shape[i] + p - d * (k - 1) - 1) // s + 1
else:
output_shape[i] = None
# Override the channel dim when a truthy 'out_channels' is supplied.
output_shape[1] = kwargs.get('out_channels') or output_shape[1]
# NOTE(review): fragment from the middle of a training method; the method header and the `if`
# that the `else:` below belongs to are outside this view (presumably a check on `microbatch`).
# Micro-batch path: split every input array and the targets into chunks of `microbatch` items.
# NOTE(review): `steps` is len(targets) // microbatch (floor) while the slicing below produces
# ceil(len(targets) / microbatch) chunks, so `for i in range(steps)` never visits a trailing
# partial chunk unless the unseen condition guarantees divisibility -- confirm.
steps = len(targets) // microbatch
# Assumes `inputs` is a sequence of sliceable arrays here: each chunk is a list of slices.
splitted_inputs = [[item[i:i + microbatch] for item in inputs] for i in range(0, len(targets), microbatch)]
splitted_targets = [targets[i:i + microbatch] for i in range(0, len(targets), microbatch)]
else:
# No micro-batching: the whole batch is processed in a single step.
steps = 1
splitted_inputs = [inputs]
splitted_targets = [targets]
# When no model exists yet, record input/target shapes from the first (micro-)batch and build it.
if self.model is None:
if isinstance(splitted_inputs[0], (list, tuple)):
self.input_shapes = [get_shape(item) for item in splitted_inputs[0]]
else:
self.input_shapes = get_shape(splitted_inputs[0])
self.target_shape = get_shape(splitted_targets[0])
# If not configured, take the number of classes from dim 1 of multi-dim targets.
if self.classes is None:
if len(self.target_shape) > 1: # segmentation
self.classes = self.target_shape[1]
self.build_config()
self._build(splitted_inputs[0])
self.model.train()
# NOTE(review): the lock is acquired here; make sure it is released (e.g. in a `finally`)
# past the end of this fragment.
if use_lock:
self.train_lock.acquire()
outputs = []
# Per-micro-batch loop; its body continues past this view.
for i in range(steps):
_inputs = splitted_inputs[i]
_targets = splitted_targets[i]
def __init__(self, inputs=None, ratio=4, squeeze_layout='Vfafa', squeeze_units=None, squeeze_activations=None):
    """Build the squeeze/excitation layers (presumably a squeeze-and-excitation block).

    Parameters
    ----------
    inputs : tensor or shape
        Example input; its size along axis 1 is the output width of the last layer.
    ratio : int
        Reduction factor for the first (bottleneck) layer when `squeeze_units` is not given.
        The bottleneck width is clamped to at least one unit.
    squeeze_layout : str
        Layout string forwarded to `ConvBlock`.
    squeeze_units : sequence, optional
        Explicit unit counts for the layers; overrides `ratio` when truthy.
    squeeze_activations : sequence, optional
        Activations for the layers; defaults to ['relu', 'sigmoid'].
    """
    # Imported here rather than at module level to avoid recursive imports.
    from .conv_block import ConvBlock
    super().__init__()
    in_units = get_shape(inputs)[1]
    # With fewer input units than `ratio`, `in_units // ratio` would be 0, leaving a zero-width
    # bottleneck so the excitation output could no longer depend on the input; keep >= 1 unit.
    units = squeeze_units or [max(1, in_units // ratio), in_units]
    activations = squeeze_activations or ['relu', 'sigmoid']
    self.layer = ConvBlock(layout=squeeze_layout, units=units, activations=activations, inputs=inputs)
# Constructor of a pyramid-pooling module (fragment: the loop body continues past this view).
# For each level in `pyramid` a branch is built: level 0 is an identity branch; any other level
# runs a ConvBlock with layout 'p' + `layout` (presumably a pooling step followed by `layout`)
# whose pool size and strides are derived from the input's spatial shape and the level, and the
# result is then passed to an `Upsample` layer.
# Args (as visible here):
#   inputs: example input; dims from index 2 onward are used as the spatial shape.
#   layout, filters, kernel_size, pool_op, **kwargs: forwarded to each pooling ConvBlock.
#   pyramid: pyramid levels; 0 means an identity branch.
def __init__(self, inputs, layout='cna', filters=None, kernel_size=1, pool_op='mean',
pyramid=(0, 1, 2, 3, 6), **kwargs):
super().__init__()
spatial_shape = np.array(get_shape(inputs)[2:])
# Default filters is a string expression, presumably resolved by ConvBlock as "input channels
# divided by the number of levels" -- confirm against ConvBlock's handling of 'same'.
filters = filters if filters else 'same // {}'.format(len(pyramid))
modules = nn.ModuleList()
for level in pyramid:
if level == 0:
module = nn.Identity()
else:
x = inputs
# Pool window per spatial dim: ceil(spatial / level).
pool_size = tuple(np.ceil(spatial_shape / level).astype(np.int32).tolist())
# Pool stride per spatial dim: floor((spatial - 1) / level + 1).
pool_strides = tuple(np.floor((spatial_shape - 1) / level + 1).astype(np.int32).tolist())
# NOTE(review): `ConvBlock` is not imported inside this method; presumably it is in scope at
# module level -- confirm, since elsewhere it is imported locally to avoid recursive imports.
layer = ConvBlock(inputs=x, layout='p' + layout, filters=filters, kernel_size=kernel_size,
pool_op=pool_op, pool_size=pool_size, pool_strides=pool_strides, **kwargs)
x = layer(x)
upsample_layer = Upsample(inputs=x, factor=None, layout='b',
def __init__(self, inputs=None, output_size=None, **kwargs):
    """Adaptive average pooling with a precomputed ``output_shape``.

    Parameters
    ----------
    inputs : tensor or shape
        Example input; its first two dims are copied into ``output_shape``.
    output_size : int or sequence
        Target spatial size, forwarded unchanged to the parent class. For ``output_shape``,
        an int is applied to every spatial dim and a ``None`` entry keeps the input's size for
        that dim, matching PyTorch's adaptive pooling semantics.
    kwargs : dict
        Forwarded to the parent class. Any ``padding`` key is discarded; padding is always
        passed as None.
    """
    shape = get_shape(inputs)
    kwargs.pop('padding', None)
    super().__init__(_fn=ADAPTIVE_AVGPOOL, inputs=inputs, output_size=output_size, padding=None, **kwargs)

    spatial_shape = tuple(shape[2:])
    if isinstance(output_size, int):
        # A single int is the target size for every spatial dim; tuple(int) would raise.
        output_size = (output_size,) * len(spatial_shape)
    # Resolve `None` entries ("keep the input size") against the input's spatial shape;
    # entries without a matching input dim are kept exactly as given.
    resolved = tuple(
        spatial_shape[i] if size is None and i < len(spatial_shape) else size
        for i, size in enumerate(output_size)
    )
    self.output_shape = tuple(shape[:2]) + resolved
def __init__(self, inputs=None, **kwargs):
    """Wrap a batch-normalization layer chosen for the dimensionality of `inputs`.

    The layer class is looked up in ``BATCH_NORM`` by ``get_num_dims(inputs)`` and sized with
    ``get_num_channels(inputs)``; any extra keyword arguments go straight to that class.
    """
    super().__init__()
    channels = get_num_channels(inputs)
    norm_class = BATCH_NORM[get_num_dims(inputs)]
    self.norm = norm_class(num_features=channels, **kwargs)
    # Normalization leaves the shape untouched.
    self.output_shape = get_shape(inputs)
def __init__(self, inputs, skip, filters, upsample, decoder, **kwargs):
    """Build an upsampling ConvBlock followed by a decoder ConvBlock.

    Parameters
    ----------
    inputs : tensor or shape
        Input of the upsampling ConvBlock.
    skip
        Not used when building the layers.
    filters
        Passed to both ConvBlocks.
    upsample, decoder : mapping
        Per-stage ConvBlock parameters; each overrides the shared `kwargs`.
    """
    super().__init__()
    del skip  # not needed to build the layers
    self.upsample = ConvBlock(inputs, filters=filters, **{**kwargs, **upsample})

    # The decoder is built for twice the upsampled channel count, presumably because the
    # upsampled tensor is concatenated with an equally wide skip tensor -- confirm in `forward`.
    up_shape = get_shape(self.upsample)
    decoder_input_shape = (up_shape[0], up_shape[1] * 2, *up_shape[2:])
    self.decoder = ConvBlock(decoder_input_shape, filters=filters, **{**kwargs, **decoder})
    self.output_shape = self.decoder.output_shape
def __init__(self, *args, inputs=None, base_block=BaseConvBlock, n_repeats=1, n_branches=1, combine='+', **kwargs):
    """Record the block configuration and build the sub-modules via ``_make_modules(inputs)``.

    Parameters
    ----------
    args, kwargs
        Stored as ``self.args`` / ``self.kwargs`` (after removing a ``'base'`` key).
    inputs : tensor or sequence of tensors
        Example input; its ``device`` (or that of its first item) and its shape are recorded.
    base_block
        Block class to use; a truthy ``base`` keyword argument takes precedence over it.
    n_repeats, n_branches, combine
        Stored as attributes (presumably read by ``_make_modules``).
    """
    super().__init__()
    base_alias = kwargs.pop('base', None)
    if base_alias:
        base_block = base_alias

    # Take the device from `inputs` itself, falling back to its first element.
    device = getattr(inputs, 'device', None)
    if not device:
        device = inputs[0].device
    self.device = device

    self.input_shape = get_shape(inputs)
    self.n_repeats = n_repeats
    self.n_branches = n_branches
    self.base_block = base_block
    self.combine = combine
    self.args = args
    self.kwargs = kwargs
    self._make_modules(inputs)
# Crop or zero-pad `inputs` along every spatial dim (index 2 onward) so it matches the spatial
# shape of `self.resize_to`; dims 0 and 1 are not compared.
# (Fragment: the zero-padding branch and the return continue past this view.)
def forward(self, inputs):
i_shape = get_shape(inputs)
r_shape = get_shape(self.resize_to)
output = inputs
for i, (i_shape_, r_shape_) in enumerate(zip(i_shape[2:], r_shape[2:])):
if i_shape_ > r_shape_:
# Decrease input tensor's shape by slicing desired shape out of it
shape = [slice(None, None)] * len(i_shape)
shape[i + 2] = slice(None, r_shape_)
# NOTE(review): indexing with a *list* of slices relies on legacy sequence-as-tuple indexing;
# `output[tuple(shape)]` is the unambiguous form -- confirm with the torch version in use.
output = output[shape]
elif i_shape_ < r_shape_:
# Increase input tensor's shape by zero padding
zeros_shape = list(i_shape)
zeros_shape[i + 2] = r_shape_
# NOTE(review): `zeros_shape` comes from the ORIGINAL input shape, so if an earlier spatial dim
# was already cropped above, `zeros` and `output` disagree on that dim. Also, `torch.zeros`
# here uses the default dtype/device rather than those of `inputs`.
zeros = torch.zeros(zeros_shape)
shape = [slice(None, None)] * len(i_shape)
shape[i + 2] = slice(None, i_shape_)
# Resolve a padding specification for a convolution-like layer
# (fragment: the remaining branches and the return are past this view).
# 'valid' becomes 0; 'same' becomes 0 for transposed layers and otherwise a per-spatial-dim tuple
# computed by `_get_padding` from kernel size, input size, dilation and stride.
# Args:
#   inputs: anything accepted by `get_num_dims` / `get_shape`; dims from index 2 are spatial.
#   padding: a string mode such as 'valid' or 'same' (other forms are handled past this view).
#   kernel_size, dilation, stride: an int for every spatial dim, or a per-dim tuple.
#   transposed: whether the layer is a transposed convolution.
#   **kwargs: ignored.
def _calc_padding(inputs, padding=0, kernel_size=None, dilation=1, transposed=False, stride=1, **kwargs):
_ = kwargs
# Number of spatial dims: it sizes the per-dim tuples that are indexed against shape[i+2] below.
dims = get_num_dims(inputs)
shape = get_shape(inputs)
if isinstance(padding, str):
if padding == 'valid':
padding = 0
elif padding == 'same':
# 'same' is not computed for transposed layers: zero padding is used instead.
if transposed:
padding = 0
else:
# Broadcast scalar hyper-parameters to one value per spatial dim.
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * dims
if isinstance(dilation, int):
dilation = (dilation,) * dims
if isinstance(stride, int):
stride = (stride,) * dims
# NOTE(review): shape[i+2] may be None for an unknown size; whether `_get_padding` handles that
# cannot be seen here.
padding = tuple(_get_padding(kernel_size[i], shape[i+2], dilation[i], stride[i]) for i in range(dims))
else: