How to use the elephas.spark_model.SparkModel.__init__ function in elephas

To help you get started, we’ve selected a few elephas examples based on popular ways the library is used in public projects.

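At its core, SparkModel wraps a compiled Keras model and distributes its training over a Spark RDD. Before the project snippets below, here is a minimal sketch of that pattern, following the elephas README; `sc`, `x_train`, and `y_train` are placeholders for a live SparkContext and your NumPy training arrays:

from keras.models import Sequential
from keras.layers import Dense

from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd

# Compile an ordinary Keras model first; elephas serializes it and ships it to the workers.
model = Sequential([Dense(128, activation='relu', input_shape=(784,)),
                    Dense(10, activation='softmax')])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

# Turn the training arrays into an RDD of (features, label) pairs.
rdd = to_simple_rdd(sc, x_train, y_train)

# mode='asynchronous' starts a parameter server that each worker updates independently.
spark_model = SparkModel(model, frequency='epoch', mode='asynchronous', num_workers=4)
spark_model.fit(rdd, epochs=10, batch_size=32, verbose=0, validation_split=0.1)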

Source: maxpumperla/elephas · elephas/dl4j.py (view on GitHub)
    def __init__(self, java_spark_context, model, num_workers, batch_size,
                 shake_frequency=0, min_threshold=1e-5, update_threshold=1e-3,
                 workers_per_node=-1, num_batches_prefetch=0, step_delay=50,
                 step_trigger=0.05, threshold_step=1e-5, collect_stats=False,
                 save_file='temp.h5', *args, **kwargs):
        """ParameterSharingModel

        :param java_spark_context: JavaSparkContext, initialized through pyjnius
        :param model: compiled Keras model
        :param num_workers: number of Spark workers/executors
        :param batch_size: batch size used for model training
        :param shake_frequency: int, how often (in iterations) a dense "shake-up" update is sent instead of a thresholded one
        :param min_threshold: float, lower bound for the adaptive update-encoding threshold
        :param update_threshold: float, initial threshold used to encode parameter updates
        :param workers_per_node: int, number of workers per Spark node (-1 to use all available cores)
        :param num_batches_prefetch: int, how many batches to pre-fetch, deactivated if 0
        :param step_delay: int, number of iterations to wait before the encoding threshold is adapted again
        :param step_trigger: float, sparsity level that triggers a threshold adaptation step
        :param threshold_step: float, step size used when adapting the encoding threshold
        :param collect_stats: boolean, whether statistics are collected during training
        :param save_file: where to store the elephas model temporarily
        :param args: additional positional arguments, passed on to SparkModel.__init__
        :param kwargs: additional keyword arguments, passed on to SparkModel.__init__
        """
        SparkModel.__init__(self, model=model, num_workers=num_workers, batch_size=batch_size, mode='asynchronous',
                            shake_frequency=shake_frequency, min_threshold=min_threshold,
                            update_threshold=update_threshold, workers_per_node=workers_per_node,
                            num_batches_prefetch=num_batches_prefetch, step_delay=step_delay, step_trigger=step_trigger,
                            threshold_step=threshold_step, collect_stats=collect_stats, *args, **kwargs)

        self.save(save_file)
        model_file = java_classes.File(save_file)
        keras_model_type = model.__class__.__name__
        self.java_spark_model = dl4j_import(
            java_spark_context, model_file, keras_model_type)
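
A construction sketch for this parameter-sharing variant. It assumes `jsc` is a placeholder for a JavaSparkContext already initialized through pyjnius with the DL4J jars on the classpath:

from keras.models import Sequential
from keras.layers import Dense

from elephas.dl4j import ParameterSharingModel

model = Sequential([Dense(10, activation='softmax', input_shape=(784,))])
model.compile(optimizer='sgd', loss='categorical_crossentropy')

# The constructor fixes mode='asynchronous'; worker updates are exchanged
# through DL4J's threshold-encoded parameter sharing.
spark_model = ParameterSharingModel(jsc, model, num_workers=4, batch_size=32,
                                    update_threshold=1e-3, save_file='temp.h5')
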
Source: maxpumperla/elephas · elephas/spark_model.py (view on GitHub)
    def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
                 num_workers=4, elephas_optimizer=None, custom_objects=None, batch_size=32, *args, **kwargs):
        """SparkMLlibModel

        The Spark MLlib model takes RDDs of LabeledPoints for training.

        :param model: Compiled Keras model
        :param mode: String, choose from `asynchronous`, `synchronous` and `hogwild`
        :param frequency: String, either `epoch` or `batch`
        :param parameter_server_mode: String, either `http` or `socket`
        :param num_workers: int, number of workers used for training (defaults to 4)
        :param elephas_optimizer: Elephas optimizer
        :param custom_objects: Keras custom objects
        :param batch_size: int, batch size used for training (defaults to 32)
        :param args: additional positional arguments, passed on to SparkModel.__init__
        :param kwargs: additional keyword arguments, passed on to SparkModel.__init__
        """
        SparkModel.__init__(self, model=model, mode=mode, frequency=frequency,
                            parameter_server_mode=parameter_server_mode, num_workers=num_workers,
                            elephas_optimizer=elephas_optimizer, custom_objects=custom_objects,
                            batch_size=batch_size, *args, **kwargs)
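
A training sketch for this class, again following the README pattern; `sc`, `x_train`, `y_train`, and `model` are placeholders, and the keyword names below follow recent elephas releases:

from elephas.spark_model import SparkMLlibModel
from elephas.utils.rdd_utils import to_labeled_point

# MLlib-style training data: an RDD of LabeledPoint instances.
lp_rdd = to_labeled_point(sc, x_train, y_train, categorical=True)

spark_model = SparkMLlibModel(model, frequency='batch', mode='hogwild', num_workers=4)
spark_model.train(lp_rdd, epochs=10, batch_size=32, verbose=0,
                  validation_split=0.1, categorical=True, nb_classes=10)
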
Source: maxpumperla/elephas · elephas/dl4j.py (view on GitHub)
    def __init__(self, java_spark_context, model, num_workers, batch_size, averaging_frequency=5,
                 num_batches_prefetch=0, collect_stats=False, save_file='temp.h5', *args, **kwargs):
        """ParameterAveragingModel

        :param java_spark_context: JavaSparkContext, initialized through pyjnius
        :param model: compiled Keras model
        :param num_workers: number of Spark workers/executors
        :param batch_size: batch size used for model training
        :param averaging_frequency: int, number of training batches between parameter averaging steps
        :param num_batches_prefetch: int, how many batches to pre-fetch, deactivated if 0
        :param collect_stats: boolean, whether statistics are collected during training
        :param save_file: where to store the elephas model temporarily
        :param args: additional positional arguments, passed on to SparkModel.__init__
        :param kwargs: additional keyword arguments, passed on to SparkModel.__init__
        """
        SparkModel.__init__(self, model=model, batch_size=batch_size, mode='synchronous',
                            averaging_frequency=averaging_frequency, num_batches_prefetch=num_batches_prefetch,
                            num_workers=num_workers, collect_stats=collect_stats, *args, **kwargs)

        self.save(save_file)
        model_file = java_classes.File(save_file)
        keras_model_type = model.__class__.__name__
        self.java_spark_model = dl4j_import(
            java_spark_context, model_file, keras_model_type)
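
For contrast with the asynchronous parameter-sharing model above, here is a construction sketch for this synchronous variant; `jsc` and `model` are the same placeholders as before:

from elephas.dl4j import ParameterAveragingModel

# mode is fixed to 'synchronous': every averaging_frequency batches the
# workers' parameters are averaged into one global model.
spark_model = ParameterAveragingModel(jsc, model, num_workers=4, batch_size=32,
                                      averaging_frequency=5, collect_stats=False,
                                      save_file='temp.h5')

Averaging trades more frequent synchronization for deterministic updates, whereas the parameter-sharing model tolerates stale, asynchronous updates in exchange for less coordination.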