How to use the batchgenerators.transforms.abstract_transforms.AbstractTransform class in batchgenerators

To help you get started, we’ve selected a few batchgenerators examples, based on popular ways it is used in public projects.

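All of the snippets below follow the same pattern: subclass AbstractTransform, store the configuration in __init__, and implement __call__(**data_dict), which returns the (possibly modified) batch dict. As a starting point, here is a minimal, hypothetical transform illustrating that pattern (the class name, the noise augmentation and all concrete values are illustrative, not taken from the projects below):

import numpy as np
from batchgenerators.transforms.abstract_transforms import AbstractTransform


class ExampleNoiseTransform(AbstractTransform):
    """Adds zero-mean Gaussian noise to the array stored under data_key (illustrative only)."""

    def __init__(self, noise_std=0.05, data_key="data"):
        self.noise_std = noise_std
        self.data_key = data_key

    def __call__(self, **data_dict):
        data = data_dict[self.data_key]
        data_dict[self.data_key] = data + np.random.normal(0.0, self.noise_std, size=data.shape)
        return data_dict


# Transforms are applied by unpacking the batch dict as keyword arguments:
batch = {"data": np.random.rand(2, 1, 64, 64).astype(np.float32)}
batch = ExampleNoiseTransform(noise_std=0.1)(**batch)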

github MIC-DKFZ / TractSeg / tractseg / data / spatial_transform_peaks.py View on Github
raise ValueError("invalid slice_dir passed as argument")

        data_aug = data_result[sample_id]
        if data_aug.shape[0] == 9:
            data_result[sample_id] = rotate_multiple_peaks(data_aug, a_x, a_y, a_z)
        elif data_aug.shape[0] == 18:
            data_result[sample_id] = rotate_multiple_tensors(data_aug, a_x, a_y, a_z)
        else:
            raise ValueError("Incorrect number of channels (expected 9 or 18)")

    return data_result, seg_result


# This is identical to batchgenerators.transforms.spatial_transforms.SpatialTransform except for another
# augment_spatial function, which also rotates the peaks when doing rotation.
class SpatialTransformPeaks(AbstractTransform):
    """The ultimate spatial transform generator. Rotation, deformation, scaling, cropping: It has all you ever dreamed
    of. Computational time scales only with patch_size, not with input patch size or type of augmentations used.
    Internally, this transform will use a coordinate grid of shape patch_size to which the transformations are
    applied (very fast). Interpolation on the image data will only be done at the very end

    Args:
        patch_size (tuple/list/ndarray of int): Output patch size

        patch_center_dist_from_border (tuple/list/ndarray of int, or int): How far should the center pixel of the
        extracted patch be from the image border? Recommended to use patch_size//2.
        This only applies when random_crop=True

        do_elastic_deform (bool): Whether or not to apply elastic deformation

        alpha (tuple of float): magnitude of the elastic deformation; randomly sampled from interval
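
(The docstring above is truncated in this excerpt.) As a hedged sketch of how such a transform is typically wired into a batchgenerators pipeline, assuming Compose from batchgenerators.transforms.abstract_transforms and purely illustrative parameter values:

from batchgenerators.transforms.abstract_transforms import Compose

# SpatialTransformPeaks is the class defined above; the values below are
# illustrative assumptions, not TractSeg's actual configuration.
transform = Compose([
    SpatialTransformPeaks(
        patch_size=(144, 144, 144),
        patch_center_dist_from_border=72,   # roughly patch_size // 2, as the docstring recommends
        do_elastic_deform=True,
        alpha=(90.0, 120.0),
    ),
])

# Applied like any other transform, with the batch dict unpacked as kwargs:
# augmented = transform(data=batch["data"], seg=batch["seg"])
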
github MIC-DKFZ / RegRCNN / utils / dataloader_utils.py View on Github
for name in roi_item_keys:
                roi_items[name].append(np.array([]))

    if get_rois_from_seg:
        data_dict.pop('class_targets', None)

    data_dict['bb_target'] = np.array(bb_target)
    data_dict['roi_masks'] = np.array(roi_masks)
    data_dict['seg'] = out_seg
    for name in roi_item_keys:
        data_dict[name] = np.array(roi_items[name])


    return data_dict

class ConvertSegToBoundingBoxCoordinates(AbstractTransform):
    """ Converts segmentation masks into bounding box coordinates.
    """

    def __init__(self, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False):
        self.dim = dim
        self.roi_item_keys = roi_item_keys
        self.get_rois_from_seg = get_rois_from_seg
        self.class_specific_seg = class_specific_seg

    def __call__(self, **data_dict):
        return convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.roi_item_keys, self.get_rois_from_seg,
                                                       self.class_specific_seg)
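
A hedged usage sketch (not taken from RegRCNN): the keys, shapes and label values below are illustrative assumptions; only the roi_item_keys handling and the added 'bb_target' and 'roi_masks' entries are confirmed by the code above.

import numpy as np

batch = {
    "data": np.random.rand(2, 1, 64, 64).astype(np.float32),
    "seg": np.zeros((2, 1, 64, 64), dtype=np.int32),  # instance segmentation map (assumed layout)
    "class_targets": [[1], [1]],                      # one class label per ROI and sample (assumed)
}
batch["seg"][0, 0, 10:20, 10:20] = 1                  # a single rectangular ROI in sample 0

transform = ConvertSegToBoundingBoxCoordinates(dim=2, roi_item_keys=["class_targets"])
batch = transform(**batch)
# The returned dict now also contains 'bb_target' and 'roi_masks',
# as assembled by convert_seg_to_bounding_box_coordinates above.
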
github MIC-DKFZ / TractSeg / tractseg / data / custom_transformations.py View on Github
            elif dim == 2:
                # cut if dimension got too long
                img_up = img_up[:img.shape[0], :img.shape[1]]

                # pad with 0 if dimension too small
                img_padded = np.zeros((img.shape[0], img.shape[1]))
                img_padded[:img_up.shape[0], :img_up.shape[1]] = img_up

                data[sample_idx, channel_idx] = img_padded
            else:
                raise ValueError("Invalid dimension size")

    return data


class ResampleTransformLegacy(AbstractTransform):
    """
    This is no longer part of batchgenerators, so we have an implementation here.
    CPU usage is always at 100% when using this, but the batch time on the cluster is not longer (~1 s).

    Downsamples each sample (linearly) by a random factor and upsamples to original resolution again (nearest neighbor)
    Info:
    * Uses scipy zoom for resampling.
    * Resamples all dimensions (channels, x, y, z) with same downsampling factor
      (like isotropic=True from linear_downsampling_generator_nilearn)

    Args:
        zoom_range (tuple of float): Random downscaling factor sampled from this range (e.g. 0.5 halves the resolution)
    """

    def __init__(self, zoom_range=(0.5, 1)):
        self.zoom_range = zoom_range
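
The excerpt ends with the constructor; the actual resampling happens in __call__, which is not shown. A rough sketch of the operation the docstring describes, using scipy.ndimage.zoom as stated (the function name and the crop/pad handling are illustrative, not TractSeg's code):

import numpy as np
from scipy.ndimage import zoom


def resample_once(img, zoom_range=(0.5, 1)):
    """Downsample by a random factor (linear), then upsample back to the original shape (nearest neighbor)."""
    factor = np.random.uniform(zoom_range[0], zoom_range[1])
    low = zoom(img, factor, order=1)                                    # linear downsampling, isotropic
    up = zoom(low, np.array(img.shape) / np.array(low.shape), order=0)  # nearest-neighbor upsampling
    # zoom can be off by a voxel, so crop or zero-pad back to the original shape,
    # mirroring the crop/pad logic in the custom_transformations.py excerpt further above
    out = np.zeros(img.shape, dtype=img.dtype)
    sl = tuple(slice(0, min(a, b)) for a, b in zip(img.shape, up.shape))
    out[sl] = up[sl]
    return out
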
github MIC-DKFZ / TractSeg / tractseg / data / custom_transformations.py View on Github
                data[id, 0] *= -1   # x components of the three vectors live in channels 0, 3, 6
                data[id, 3] *= -1
                data[id, 6] *= -1
            elif axis == "y":
                data[id, 1] *= -1   # y components in channels 1, 4, 7
                data[id, 4] *= -1
                data[id, 7] *= -1
            elif axis == "z":
                data[id, 2] *= -1   # z components in channels 2, 5, 8
                data[id, 5] *= -1
                data[id, 8] *= -1

    return data


class FlipVectorAxisTransform(AbstractTransform):
    """
    Expects as input an image with three 3D vectors at each voxel, encoded as a nine-channel image. Will randomly
    flip the sign of one dimension (x, y or z) of all three vectors.
    """
    def __init__(self, axes=(2, 3, 4), data_key="data"):
        self.data_key = data_key
        self.axes = axes

    def __call__(self, **data_dict):
        data_dict[self.data_key] = flip_vector_axis(data=data_dict[self.data_key])
        return data_dict
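
A short usage sketch; the array shape and channel layout are assumptions based on the docstring and the channel indices in the excerpt above.

import numpy as np

peaks = np.random.randn(2, 9, 32, 32, 32).astype(np.float32)  # (bs, 9, x, y, z), assumed layout
out = FlipVectorAxisTransform()(data=peaks)
# The sign of one spatial component (x, y or z) of all three vectors is flipped,
# i.e. channels (0, 3, 6), (1, 4, 7) or (2, 5, 8).
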
github MIC-DKFZ / TractSeg / tractseg / libs / AugmentationGenerators.py View on Github
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from warnings import warn
from batchgenerators.transforms.abstract_transforms import AbstractTransform

class ReorderSegTransform(AbstractTransform):
    """
    Yields a reordered seg array. For data augmentation, x and y have to be the last two dims and nr_classes has to
    come before them; here we move seg back to (bs, x, y, nr_classes) to make calculating the F1 score easy.
    """
    def __init__(self):
        pass

    def __call__(self, **data_dict):
        seg = data_dict.get("seg")

        if seg is None:
            warn("You used ReorderSegTransform but there is no 'seg' key in your data_dict, returning data_dict unmodified", Warning)
        else:
            seg = data_dict["seg"]  # (bs, nr_of_classes, x, y)
            data_dict["seg"] = seg.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
        return data_dict
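
A short usage sketch based on the shape comments in __call__ above:

import numpy as np

seg = np.zeros((4, 3, 80, 80), dtype=np.float32)  # (bs, nr_of_classes, x, y)
out = ReorderSegTransform()(seg=seg)
print(out["seg"].shape)                           # (4, 80, 80, 3) -> (bs, x, y, nr_of_classes)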