How to use batchgenerators - 10 common examples

To help you get started, we’ve selected a few batchgenerators examples, based on the most common ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github MIC-DKFZ / basic_unet_example / datasets / three_dim / data_augmentation.py View on Github external
if mode == "train":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          MirrorTransform(axes=(2,)),
                          SpatialTransform(patch_size=(target_size, target_size, target_size), random_crop=False,
                                           patch_center_dist_from_border=target_size // 2,
                                           do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                                           do_rotation=True,
                                           angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                                           scale=(0.9, 1.4),
                                           border_mode_data="nearest", border_mode_seg="nearest"),
                          ]

    elif mode == "val":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    elif mode == "test":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    transform_list.append(NumpyToTensor())

    return Compose(transform_list)
github MIC-DKFZ / medicaldetectiontoolkit / experiments / lidc_exp / data_loader.py View on Github external
min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                    max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                    data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                    seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)

            batch_data.append(data)
            batch_segs.append(seg[np.newaxis])

        data = np.array(batch_data)
        seg = np.array(batch_segs).astype(np.uint8)
        class_target = np.array(batch_targets)
        return {'data': data, 'seg': seg, 'pid': batch_pids, 'class_target': class_target}



class PatientBatchIterator(SlimDataLoaderBase):
    """
    Test generator that iterates over the entire given dataset, returning 1 patient per batch.

    Can be used for monitoring if cf.val_mode = 'patient_val', for a monitoring closer to the
    actual evaluation (done in 3D), if willing to accept speed-loss during training.

    :param data: mapping whose values each provide a 'pid' entry.
    :param cf: config object; only ``cf.patch_size`` is read in the constructor.
    :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
    batch_size = n_2D_patches in 2D.
    """
    def __init__(self, data, cf):
        # NOTE(review): the literal 0 is forwarded as the base class' second positional
        # argument -- presumably the batch size; confirm against SlimDataLoaderBase.
        # Threading is presumably left to the augmenter wrapping this loader -- TODO confirm.
        super(PatientBatchIterator, self).__init__(data, 0)

        self.cf = cf
        # Index of the next patient to serve; starts at the first one.
        self.patient_ix = 0
        # One patient id per dataset entry, in the mapping's iteration order.
        self.dataset_pids = [entry['pid'] for (_, entry) in data.items()]

        # A two-element patch size gets a trailing 1 appended so it always has three entries.
        patch_size = cf.patch_size
        if len(patch_size) == 2:
            patch_size = patch_size + [1]
        self.patch_size = patch_size
github huangmozhilv / u2net_torch / u2net_torch_src / ccToolkits / cc_augment.py View on Github external
if patch_center_dist_from_border[d] == data.shape[d + 2] - patch_center_dist_from_border[d]:
                                ctr = int(patch_center_dist_from_border[d])
                            elif patch_center_dist_from_border[d] < data.shape[d + 2] - patch_center_dist_from_border[d]:
                                ctr = int(np.random.randint(patch_center_dist_from_border[d], data.shape[d + 2] - patch_center_dist_from_border[d]))
                            else:
                                logger.error('low should be <= upper. patch_type:{}, patch_center_dist_from_border:{}, data.shape:{}, ctr_list:{}'.format(str(patch_type), str(patch_center_dist_from_border), str(data.shape), str(ctr_list)))
                    else: # center crop
                        ctr = int(np.round(data.shape[d + 2] / 2.))
                    ctr_list.append(ctr)

                # extracting patch
                if n < 10 and modified_coords:
                    for d in range(dim):
                        coords[d] += ctr_list[d]
                    for channel_id in range(data.shape[1]):
                        data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data, border_mode_data, cval=border_cval_data)
                    if seg is not None:
                        for channel_id in range(seg.shape[1]):
                            seg_result[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg, border_mode_seg, cval=border_cval_seg, is_seg=True)
                else:
                    augs = list()
                    if seg is None:
                        s = None
                    else:
                        s = seg[sample_id:sample_id + 1]
                    if random_crop:
                        # margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                        # d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
                        d_tmps = list()
                        for channel_id in range(data.shape[1]):
                            d_tmp = utils.extract_roi_from_volume(data[sample_id, channel_id], ctr_list, patch_size, fill="zero")
                            d_tmps.append(d_tmp)
github MIC-DKFZ / TractSeg / tractseg / data / spatial_transform_peaks.py View on Github external
# now find a nice center location
        if modified_coords:
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data,
                                                                     border_mode_data, cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg,
                                                                        border_mode_seg, cval=border_cval_seg,
                                                                        is_seg=True)
        else:
            if seg is None:
                s = None
            else:
                s = seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
            else:
                d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s)
            data_result[sample_id] = d[0]
            if seg is not None:
                seg_result[sample_id] = s[0]
github huangmozhilv / u2net_torch / u2net_torch_src / ccToolkits / cc_augment.py View on Github external
ctr = int(np.random.randint(patch_center_dist_from_border[d], data.shape[d + 2] - patch_center_dist_from_border[d]))
                            else:
                                logger.error('low should be <= upper. patch_type:{}, patch_center_dist_from_border:{}, data.shape:{}, ctr_list:{}'.format(str(patch_type), str(patch_center_dist_from_border), str(data.shape), str(ctr_list)))
                    else: # center crop
                        ctr = int(np.round(data.shape[d + 2] / 2.))
                    ctr_list.append(ctr)

                # extracting patch
                if n < 10 and modified_coords:
                    for d in range(dim):
                        coords[d] += ctr_list[d]
                    for channel_id in range(data.shape[1]):
                        data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data, border_mode_data, cval=border_cval_data)
                    if seg is not None:
                        for channel_id in range(seg.shape[1]):
                            seg_result[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg, border_mode_seg, cval=border_cval_seg, is_seg=True)
                else:
                    augs = list()
                    if seg is None:
                        s = None
                    else:
                        s = seg[sample_id:sample_id + 1]
                    if random_crop:
                        # margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                        # d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
                        d_tmps = list()
                        for channel_id in range(data.shape[1]):
                            d_tmp = utils.extract_roi_from_volume(data[sample_id, channel_id], ctr_list, patch_size, fill="zero")
                            d_tmps.append(d_tmp)
                        d = np.asarray(d_tmps)
                        if seg is not None:
                            s_tmps = list()
github MIC-DKFZ / TractSeg / tractseg / data / spatial_transform_peaks.py View on Github external
else:
                sc = np.random.uniform(max(scale[0], 1), scale[1])
            coords = scale_coords(coords, sc)
            modified_coords = True

        # now find a nice center location
        if modified_coords:
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data,
                                                                     border_mode_data, cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg,
                                                                        border_mode_seg, cval=border_cval_seg,
                                                                        is_seg=True)
        else:
            if seg is None:
                s = None
            else:
                s = seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
            else:
                d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s)
github jenspetersen / probabilistic-unet / probunet / experiment / probabilistic_unet_segmentation.py View on Github external
generator_val=data.LinearBatchGenerator,
        transforms_train={
            0: {
                "type": SpatialTransform,
                "kwargs": {
                    "patch_size": patch_size,
                    "patch_center_dist_from_border": patch_size[0] // 2,
                    "do_elastic_deform": False,
                    "p_el_per_sample": 0.2,
                    "p_rot_per_sample": 0.3,
                    "p_scale_per_sample": 0.3
                },
                "active": True
            },
            1: {
                "type": MirrorTransform,
                "kwargs": {"axes": (0, 1, 2)},
                "active": True
            },
            2: {
                "type": SegLabelSelectionBinarizeTransform,
                "kwargs": {"label": [1, 2, 3]},
                "active": False
            }
        },
        transforms_val={
            0: {
                "type": CenterCropTransform,
                "kwargs": {"crop_size": patch_size},
                "active": False
            },
            1: {
github MIC-DKFZ / basic_unet_example / datasets / three_dim / data_augmentation.py View on Github external
def get_transforms(mode="train", target_size=128):
    transform_list = []

    if mode == "train":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          MirrorTransform(axes=(2,)),
                          SpatialTransform(patch_size=(target_size, target_size, target_size), random_crop=False,
                                           patch_center_dist_from_border=target_size // 2,
                                           do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                                           do_rotation=True,
                                           angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                                           scale=(0.9, 1.4),
                                           border_mode_data="nearest", border_mode_seg="nearest"),
                          ]

    elif mode == "val":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    elif mode == "test":
        transform_list = [CenterCropTransform(crop_size=target_size),
github MIC-DKFZ / basic_unet_example / datasets / two_dim / data_augmentation.py View on Github external
def get_transforms(mode="train", target_size=128):
    tranform_list = []

    if mode == "train":
        tranform_list = [# CenterCropTransform(crop_size=target_size),
                         ResizeTransform(target_size=(target_size,target_size), order=1),
                         MirrorTransform(axes=(1,)),
                         SpatialTransform(patch_size=(target_size, target_size), random_crop=False,
                                          patch_center_dist_from_border=target_size // 2,
                                          do_elastic_deform=True, alpha=(0., 900.), sigma=(20., 30.),
                                          do_rotation=True, p_rot_per_sample=0.8,
                                          angle_x=(-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                                          scale=(0.85, 1.25), p_scale_per_sample=0.8,
                                          border_mode_data="nearest", border_mode_seg="nearest"),
                         ]


    elif mode == "val":
        tranform_list = [# CenterCropTransform(crop_size=target_size),
                         ResizeTransform(target_size=target_size, order=1),
                         ]

    elif mode == "test":
github delira-dev / delira / tests / data_loading / test_numba_transforms.py View on Github external
def setUp(self) -> None:
        """Create plain and numba-wrapped zoom/pad/compose transforms plus a random input batch."""
        from delira.data_loading.numba_transform import (
            NumbaTransform,
            NumbaCompose,
        )

        zoom_factor = 3
        pad_kwargs = {"new_size": (30, 30)}
        input_shape = (10, 1, 24, 24)

        # Each plain transform is paired with a numba-wrapped one built from the same settings.
        self._basic_zoom_trafo = ZoomTransform(zoom_factor)
        self._numba_zoom_trafo = NumbaTransform(ZoomTransform,
                                                zoom_factors=zoom_factor)
        self._basic_pad_trafo = PadTransform(**pad_kwargs)
        self._numba_pad_trafo = NumbaTransform(PadTransform, **pad_kwargs)

        # Both composed pipelines receive the plain transforms: pad first, then zoom.
        # Separate list objects are handed over, one per pipeline.
        pipeline = (self._basic_pad_trafo, self._basic_zoom_trafo)
        self._basic_compose_trafo = Compose(list(pipeline))
        self._numba_compose_trafo = NumbaCompose(list(pipeline))

        # Random batch stored under the "data" key.
        self._input = {"data": np.random.rand(*input_shape)}