Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
# Train this component end to end.
#
# Builds the transform and vocabularies from ``trn_data``, builds the model,
# optimizer, loss and metrics via ``self.build``, persists config / vocabs /
# meta into ``save_dir``, builds the train/dev datasets and callbacks, and
# then hands everything to ``self.train_loop``.
#
# Args:
#   trn_data: training data; passed to ``build_vocab`` and ``build_train_dataset``.
#   dev_data: development data; passed to ``build_valid_dataset``.
#   save_dir: output directory; when falsy, ``tempdir_human()`` supplies one.
#   batch_size: batch size for both datasets and for the step-count arithmetic.
#   epochs: number of training epochs.
#   run_eagerly: only reaches the config via ``_capture_config(locals())``;
#       not otherwise read in the visible lines.
#   logger: logger to use; when falsy, ``init_logger`` creates one named
#       'train' rooted at ``save_dir``.
#   verbose: picks INFO (True) or WARN (False) for a newly created logger.
#   **kwargs: presumably captured into ``self.config`` with the other locals
#       by ``_capture_config`` — TODO confirm.
#
# NOTE(review): this fragment is truncated inside the ``self.train_loop(...)``
# call; the remainder of the method is not visible here.
def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
**kwargs):
# Presumably records the call arguments as this component's config
# (``self.config`` is expanded right below).
self._capture_config(locals())
self.transform = self.build_transform(**self.config)
if not save_dir:
save_dir = tempdir_human()
if not logger:
logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
logger.info('Hyperparameter:\n' + self.config.to_json())
# build_vocab is expected to return the number of training examples; a falsy
# result leaves the per-epoch / total step counts below as None.
num_examples = self.build_vocab(trn_data, logger)
# assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
logger.info('Building...')
train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
# build() returns (model, optimizer, loss, metrics); training=True is merged
# into the config it receives.
model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
logger.info('Model built:\n' + summary_of_model(self.model))
# Persist everything needed to reload the component before training starts.
self.save_config(save_dir)
self.save_vocabs(save_dir)
self.save_meta(save_dir)
trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
dev_data = self.build_valid_dataset(dev_data, batch_size)
callbacks = self.build_callbacks(save_dir, logger, **self.config)
# need to know #batches, otherwise progbar crashes
dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
# Presumably None when no ModelCheckpoint callback was configured — TODO confirm.
checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
timer = Timer()
# Delegate the actual epoch loop to train_loop; every local needed for
# training is merged over the config (overwrite=True, per the full call).
try:
history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
num_examples=num_examples,
train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
# NOTE(review): this fragment is textually identical to the beginning of the
# ``build`` method that follows it in this file and stops at a dangling
# ``else:``; it looks like a paste/scrape duplicate. Consider deleting it.
#
# Build the model, optimizer, loss and metrics for this component.
#
# Args:
#   logger: forwarded to ``build_metrics``.
#   **kwargs: ``training`` and ``loss`` (default None) and ``metrics``
#       (default 'accuracy') are read from here.
def build(self, logger, **kwargs):
self.transform.build_config()
self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
loss=kwargs.get('loss', None)))
# Vocabularies are frozen once the model has been built from them.
self.transform.lock_vocabs()
optimizer = self.build_optimizer(**self.config)
# Guarantee a 'loss' key (default None) before expanding the config.
loss = self.build_loss(
**self.config if 'loss' in self.config else dict(list(self.config.items()) + [('loss', None)]))
# Allow callers to request different metrics via kwargs['metrics'] (default 'accuracy').
metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
logger=logger, overwrite=True))
# A single Metric is wrapped in a list; other non-list values pass through unchanged.
if not isinstance(metrics, list):
if isinstance(metrics, tf.keras.metrics.Metric):
metrics = [metrics]
# Create the model's weights if they do not exist yet.
if not self.model.built:
sample_inputs = self.sample_data
if sample_inputs is not None:
self.model(sample_inputs)
else:
# Build the model, optimizer, loss and metrics for this component.
#
# Calls ``self.transform.build_config()``, builds ``self.model`` via
# ``build_model`` (forwarding ``training`` and ``loss`` from kwargs), locks
# the transform's vocabularies, then builds optimizer, loss and metrics from
# ``self.config``. If the model is not built yet, its weights are created
# either by a forward pass on ``self.sample_data`` or by ``self.model.build``
# with input shapes derived from ``self.transform.output_shapes[0]``.
#
# Args:
#   logger: forwarded to ``build_metrics``.
#   **kwargs: ``training`` and ``loss`` (default None) and ``metrics``
#       (default 'accuracy') are read; callers in this file pass the whole
#       config merged with ``logger``/``training`` here.
#
# NOTE(review): the visible code ends after ``self.model.build(...)``; callers
# unpack (model, optimizer, loss, metrics) from ``build``, so a return
# statement presumably follows outside the visible lines.
def build(self, logger, **kwargs):
self.transform.build_config()
self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
loss=kwargs.get('loss', None)))
# Vocabularies are frozen once the model has been built from them.
self.transform.lock_vocabs()
optimizer = self.build_optimizer(**self.config)
# Guarantee a 'loss' key (default None) before expanding the config.
loss = self.build_loss(
**self.config if 'loss' in self.config else dict(list(self.config.items()) + [('loss', None)]))
# Allow callers to request different metrics via kwargs['metrics'] (default 'accuracy').
metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
logger=logger, overwrite=True))
# A single Metric is wrapped in a list; other non-list values pass through unchanged.
if not isinstance(metrics, list):
if isinstance(metrics, tf.keras.metrics.Metric):
metrics = [metrics]
# Create the model's weights if they do not exist yet.
if not self.model.built:
sample_inputs = self.sample_data
if sample_inputs is not None:
# Prefer a real forward pass when sample data is available.
self.model(sample_inputs)
else:
# A lone None dimension is passed through unchanged; otherwise a None
# batch dimension is prepended to every input shape.
if len(self.transform.output_shapes[0]) == 1 and self.transform.output_shapes[0][0] is None:
x_shape = self.transform.output_shapes[0]
else:
x_shape = list(self.transform.output_shapes[0])
# NOTE(review): ``[None] + shape`` assumes each shape is a list; a tuple
# would raise TypeError here — TODO confirm what output_shapes holds.
for i, shape in enumerate(x_shape):
x_shape[i] = [None] + shape # batch dimension + per-input shape
self.model.build(input_shape=x_shape)
train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
logger.info('Model built:\n' + summary_of_model(self.model))
self.save_config(save_dir)
self.save_vocabs(save_dir)
self.save_meta(save_dir)
trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
dev_data = self.build_valid_dataset(dev_data, batch_size)
callbacks = self.build_callbacks(save_dir, logger, **self.config)
# need to know #batches, otherwise progbar crashes
dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
timer = Timer()
try:
history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
num_examples=num_examples,
train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
loss=loss,
metrics=metrics, overwrite=True))
except KeyboardInterrupt:
print()
if not checkpoint or checkpoint.best in (np.Inf, -np.Inf):
self.save_weights(save_dir)
logger.info('Aborted with model saved')
else:
logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
# noinspection PyTypeChecker
history: tf.keras.callbacks.History() = get_callback_by_class(callbacks, tf.keras.callbacks.History)
delta_time = timer.stop()
best_epoch_ago = 0
def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
    """Restore a previously saved component.

    The identifier exactly as passed in is remembered in
    ``self.meta['load_path']``; it is then resolved through ``get_resource``
    (presumably mapping an identifier/URL to a local directory — TODO confirm)
    and config, vocabularies, model, weights and meta are restored from the
    resolved location, in that order.

    Args:
        save_dir: identifier or path of the saved component.
        logger: logger forwarded to ``build``.
        **kwargs: merged into the config used for ``build`` and also forwarded
            to ``load_weights``.

    NOTE(review): passing ``training``, ``logger``, ``overwrite`` or ``inplace``
    in ``**kwargs`` raises TypeError, because those names are already given
    explicitly to ``merge_dict``.
    """
    # Keep the caller's original identifier, not the resolved location.
    self.meta['load_path'] = save_dir
    resolved_dir = get_resource(save_dir)
    self.load_config(resolved_dir)
    self.load_vocabs(resolved_dir)
    # inplace=True presumably folds these settings into self.config itself.
    build_kwargs = merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True)
    self.build(**build_kwargs)
    self.load_weights(resolved_dir, **kwargs)
    self.load_meta(resolved_dir)