Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
:param num_workers: number of Spark workers/executors.
:param batch_size: batch size used for model training
:param shake_frequency:
:param min_threshold:
:param update_threshold:
:param workers_per_node:
:param num_batches_prefetch:
:param step_delay:
:param step_trigger:
:param threshold_step:
:param collect_stats:
:param save_file:
:param args:
:param kwargs:
"""
SparkModel.__init__(self, model=model, num_workers=num_workers, batch_size=batch_size, mode='asynchronous',
shake_frequency=shake_frequency, min_threshold=min_threshold,
update_threshold=update_threshold, workers_per_node=workers_per_node,
num_batches_prefetch=num_batches_prefetch, step_delay=step_delay, step_trigger=step_trigger,
threshold_step=threshold_step, collect_stats=collect_stats, *args, **kwargs)
self.save(save_file)
model_file = java_classes.File(save_file)
keras_model_type = model.__class__.__name__
self.java_spark_model = dl4j_import(
java_spark_context, model_file, keras_model_type)
def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
             num_workers=4, elephas_optimizer=None, custom_objects=None, batch_size=32, *args, **kwargs):
    """Set up a SparkMLlibModel by delegating to ``SparkModel.__init__``.

    The Spark MLlib model takes RDDs of LabeledPoints for training.

    :param model: Compiled Keras model
    :param mode: String, choose from `asynchronous`, `synchronous` and `hogwild`
    :param frequency: String, either `epoch` or `batch`
    :param parameter_server_mode: String, either `http` or `socket`
    :param num_workers: int, number of workers used for training (defaults to 4)
    :param elephas_optimizer: Elephas optimizer
    :param custom_objects: Keras custom objects
    :param batch_size: int, forwarded unchanged to ``SparkModel`` (defaults to 32)
    :param args: extra positional arguments, forwarded to ``SparkModel.__init__``
    :param kwargs: extra keyword arguments, forwarded to ``SparkModel.__init__``
    """
    # Unpacked positionals are always bound directly after ``self``, ahead of
    # every keyword below, regardless of where ``*args`` is written in the call.
    SparkModel.__init__(
        self,
        *args,
        model=model,
        mode=mode,
        frequency=frequency,
        parameter_server_mode=parameter_server_mode,
        num_workers=num_workers,
        elephas_optimizer=elephas_optimizer,
        custom_objects=custom_objects,
        batch_size=batch_size,
        **kwargs
    )
def __init__(self, java_spark_context, model, num_workers, batch_size, averaging_frequency=5,
             num_batches_prefetch=0, collect_stats=False, save_file='temp.h5', *args, **kwargs):
    """Set up a ParameterAveragingModel.

    Initialises the underlying ``SparkModel`` in `synchronous` mode, writes the
    model to ``save_file`` via ``self.save`` and hands that file, together with
    the Keras model's class name, to ``dl4j_import``. The result is kept on
    ``self.java_spark_model``.

    :param java_spark_context: JavaSparkContext, initialized through pyjnius
    :param model: compiled Keras model
    :param num_workers: number of Spark workers/executors.
    :param batch_size: batch size used for model training
    :param averaging_frequency: int, after how many batches of training averaging takes place
    :param num_batches_prefetch: int, how many batches to pre-fetch, deactivated if 0.
    :param collect_stats: boolean, if statistics get collected during training
    :param save_file: where to store elephas model temporarily.
    :param args: extra positional arguments, forwarded to ``SparkModel.__init__``
    :param kwargs: extra keyword arguments, forwarded to ``SparkModel.__init__``
    """
    # Unpacked positionals are always bound directly after ``self``, ahead of
    # every keyword below, regardless of where ``*args`` is written in the call.
    SparkModel.__init__(
        self,
        *args,
        model=model,
        batch_size=batch_size,
        mode='synchronous',
        averaging_frequency=averaging_frequency,
        num_batches_prefetch=num_batches_prefetch,
        num_workers=num_workers,
        collect_stats=collect_stats,
        **kwargs
    )

    # Persist the model first so the Java side has a file to read from.
    self.save(save_file)
    saved_model = java_classes.File(save_file)
    model_kind = model.__class__.__name__
    self.java_spark_model = dl4j_import(java_spark_context, saved_model, model_kind)