# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_DATAGENERATOR_runTraining(self):
    """Training batches have the expected count and (img, seg) shapes.

    Covers both full-image analysis and patchwise-crop analysis with a
    5x5x5 patch shape.
    """
    # Full-image analysis with batch_size=4 -> 3 batches expected
    preprocessor = Preprocessor(self.data_io, batch_size=4,
                                data_aug=self.data_aug,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                analysis="fullimage")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=True, shuffle=False,
                              iterations=None)
    self.assertEqual(len(generator), 3)
    for img_seg_pair in generator:
        self.assertIsInstance(img_seg_pair, tuple)
        self.assertEqual(img_seg_pair[0].shape, (4,16,16,16,1))
        self.assertEqual(img_seg_pair[1].shape, (4,16,16,16,3))
    # Patchwise-crop analysis with batch_size=3 -> 4 batches expected
    preprocessor = Preprocessor(self.data_io, batch_size=3,
                                data_aug=self.data_aug,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                patch_shape=(5,5,5),
                                analysis="patchwise-crop")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=True, shuffle=False,
                              iterations=None)
    self.assertEqual(len(generator), 4)
    for img_seg_pair in generator:
        self.assertIsInstance(img_seg_pair, tuple)
        self.assertEqual(img_seg_pair[0].shape, (3,5,5,5,1))
        self.assertEqual(img_seg_pair[1].shape, (3,5,5,5,3))
def test_DATAGENERATOR_prepareData(self):
    """Batches pre-built by the Preprocessor keep correct shapes.

    With prepare_subfunctions and prepare_batches enabled and shuffling
    on, the last batch may be smaller, so only the non-batch dimensions
    are pinned exactly; the batch dimension must be either 2 or 4.
    """
    preprocessor = Preprocessor(self.data_io, batch_size=4, data_aug=None,
                                prepare_subfunctions=True,
                                prepare_batches=True,
                                analysis="fullimage")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=True, shuffle=True,
                              iterations=None)
    self.assertEqual(len(generator), 3)
    for img_seg_pair in generator:
        self.assertIsInstance(img_seg_pair, tuple)
        self.assertEqual(img_seg_pair[0].shape[1:], (16,16,16,1))
        self.assertEqual(img_seg_pair[1].shape[1:], (16,16,16,3))
        # 10 samples / batch_size 4 -> full batches of 4 plus a rest of 2
        self.assertIn(img_seg_pair[0].shape[0], [2,4])
def test_DATAGENERATOR_iterations(self):
    """len(DataGenerator) follows the iterations argument.

    iterations=None yields one batch per sample (10 with batch_size=1);
    an explicit iterations value overrides the length directly.
    """
    preprocessor = Preprocessor(self.data_io, batch_size=1, data_aug=None,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                analysis="fullimage")
    # (iterations argument, expected generator length)
    cases = [(None, 10), (5, 5), (50, 50), (100, 100)]
    for iterations, expected_length in cases:
        generator = DataGenerator(self.sample_list, preprocessor,
                                  training=True, shuffle=False,
                                  iterations=iterations)
        self.assertEqual(expected_length, len(generator))
def test_DATAGENERATOR_create(self):
    """A DataGenerator can be instantiated in inference configuration."""
    preprocessor = Preprocessor(self.data_io, batch_size=4,
                                data_aug=self.data_aug,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                analysis="fullimage")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=False, validation=False,
                              shuffle=False, iterations=None)
    self.assertIsInstance(generator, DataGenerator)
def test_DATAGENERATOR_runPrediction(self):
    """Inference mode yields image-only batches (no segmentation tuple).

    Full-image analysis produces one batch per sample; patchwise-crop
    produces many patch batches, the last of which may be partial.
    """
    # Full-image analysis: one single-image batch per sample
    preprocessor = Preprocessor(self.data_io, batch_size=4,
                                data_aug=self.data_aug,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                analysis="fullimage")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=False, shuffle=False,
                              iterations=None)
    self.assertEqual(len(generator), 10)
    for img_batch in generator:
        self.assertNotIsInstance(img_batch, tuple)
        self.assertEqual(img_batch.shape, (1,16,16,16,1))
    # Patchwise-crop analysis: 220 patch batches over all samples
    preprocessor = Preprocessor(self.data_io, batch_size=3,
                                data_aug=self.data_aug,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                patch_shape=(5,5,5),
                                analysis="patchwise-crop")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=False, shuffle=False,
                              iterations=None)
    self.assertEqual(len(generator), 220)
    for img_batch in generator:
        self.assertNotIsInstance(img_batch, tuple)
        # Final batch per sample may hold fewer than batch_size patches
        self.assertIn(img_batch.shape, [(3,5,5,5,1), (1,5,5,5,1)])
def test_DATAGENERATOR_augcyling(self):
    """Augmentation cycling multiplies the number of training batches.

    With cycles=20 on the data augmentation, the generator length grows
    from 3 (10 samples, batch_size=4) to 50.
    """
    augmentation = Data_Augmentation(cycles=20)
    preprocessor = Preprocessor(self.data_io, batch_size=4,
                                data_aug=augmentation,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                analysis="fullimage")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=True, shuffle=False,
                              iterations=None)
    self.assertEqual(50, len(generator))
def test_DATAGENERATOR_consistency(self):
    """Without shuffling, yielded batches match samples loaded from disk.

    Each batch's image must equal the sample's img_data, and each
    batch's segmentation must equal the one-hot encoded seg_data.
    """
    preprocessor = Preprocessor(self.data_io, batch_size=1, data_aug=None,
                                prepare_subfunctions=False,
                                prepare_batches=False,
                                analysis="fullimage")
    generator = DataGenerator(self.sample_list, preprocessor,
                              training=True, shuffle=False,
                              iterations=None)
    for index, img_seg_pair in enumerate(generator):
        sample = self.data_io.sample_loader(self.sample_list[index],
                                            load_seg=True)
        self.assertTrue(np.array_equal(img_seg_pair[0][0], sample.img_data))
        expected_seg = to_categorical(sample.seg_data, num_classes=3)
        self.assertTrue(np.array_equal(img_seg_pair[1][0], expected_seg))
# NOTE(review): this method is a byte-identical duplicate of the earlier
# test_DATAGENERATOR_runTraining definition in this file. Inside a class,
# the later definition silently shadows the earlier one, so the test only
# runs once. Harmless here (both bodies are identical), but one of the two
# should be removed or renamed to avoid confusion — confirm with the
# original test module before deleting.
def test_DATAGENERATOR_runTraining(self):
    """Training batches have the expected count and (img, seg) shapes
    for both full-image and patchwise-crop analysis."""
    # Full-image analysis with batch_size=4 -> 3 batches expected
    pp_fi = Preprocessor(self.data_io, batch_size=4, data_aug=self.data_aug,
                         prepare_subfunctions=False, prepare_batches=False,
                         analysis="fullimage")
    data_gen = DataGenerator(self.sample_list, pp_fi, training=True,
                             shuffle=False, iterations=None)
    self.assertEqual(len(data_gen), 3)
    for batch in data_gen:
        # Training mode yields (image, segmentation) tuples
        self.assertIsInstance(batch, tuple)
        self.assertEqual(batch[0].shape, (4,16,16,16,1))
        self.assertEqual(batch[1].shape, (4,16,16,16,3))
    # Patchwise-crop analysis with 5x5x5 patches and batch_size=3
    pp_pc = Preprocessor(self.data_io, batch_size=3, data_aug=self.data_aug,
                         prepare_subfunctions=False, prepare_batches=False,
                         patch_shape=(5,5,5), analysis="patchwise-crop")
    data_gen = DataGenerator(self.sample_list, pp_pc, training=True,
                             shuffle=False, iterations=None)
    self.assertEqual(len(data_gen), 4)
    for batch in data_gen:
        self.assertIsInstance(batch, tuple)
        self.assertEqual(batch[0].shape, (3,5,5,5,1))
        self.assertEqual(batch[1].shape, (3,5,5,5,3))
def evaluate(self, training_samples, validation_samples, epochs=20,
iterations=None, callbacks=[]):
# Initialize a Keras Data Generator for generating Training data
dataGen_training = DataGenerator(training_samples, self.preprocessor,
training=True, validation=False,
shuffle=self.shuffle_batches,
iterations=iterations)
# Initialize a Keras Data Generator for generating Validation data
dataGen_validation = DataGenerator(validation_samples,
self.preprocessor,
training=True, validation=True,
shuffle=self.shuffle_batches)
# Run training & validation process with the Keras fit
history = self.model.fit(dataGen_training,
validation_data=dataGen_validation,
callbacks=callbacks,
epochs=epochs,
workers=self.workers,
max_queue_size=self.batch_queue_size)
# Clean up temporary files if necessary