How to use functions from elephas

To help you get started, we’ve selected a few elephas examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github maxpumperla / elephas / elephas / View on Github external
:param step_delay:
        :param step_trigger:
        :param threshold_step:
        :param collect_stats:
        :param save_file:
        :param args:
        :param kwargs:
        SparkModel.__init__(self, model=model, num_workers=num_workers, batch_size=batch_size, mode='asynchronous',
                            shake_frequency=shake_frequency, min_threshold=min_threshold,
                            update_threshold=update_threshold, workers_per_node=workers_per_node,
                            num_batches_prefetch=num_batches_prefetch, step_delay=step_delay, step_trigger=step_trigger,
                            threshold_step=threshold_step, collect_stats=collect_stats, *args, **kwargs)
        model_file = java_classes.File(save_file)
        keras_model_type = model.__class__.__name__
        self.java_spark_model = dl4j_import(
            java_spark_context, model_file, keras_model_type)
github maxpumperla / elephas / examples / View on Github external
def main():
    """Prepare a DL4J-backed Elephas model and the MNIST data for it.

    Creates a local Java Spark context, compiles a small Keras Sequential
    classifier, wraps it in a ``ParameterAveragingModel`` and loads MNIST
    as flat 784-feature float64 rows.

    NOTE(review): the visible code stops after data preparation; nothing
    here trains or evaluates ``spark_model`` -- confirm against the full
    example.
    """
    # Set Java Spark context
    # 'local[*]' is passed as the Spark master for this example.
    conf = java_classes.SparkConf().setMaster('local[*]').setAppName("elephas_dl4j")
    jsc = java_classes.JavaSparkContext(conf)

    # Define Keras model
    # 784 inputs -> Dense(128) -> Dense(10, softmax).
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(128, input_dim=784))
    model.add(keras.layers.Dense(units=10, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

    # Define DL4J Elephas model
    spark_model = ParameterAveragingModel(java_spark_context=jsc, model=model, num_workers=4, batch_size=32)

    # Load data and build DL4J DataSet RDD under the hood
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    # Flatten each image to a 784-element row to match the model's input_dim.
    x_train = x_train.reshape(60000, 784)
    x_test = x_test.reshape(10000, 784)
    x_train = x_train.astype("float64")
    x_test = x_test.astype("float64")
github maxpumperla / elephas / elephas / utils / View on Github external
def to_java_rdd(jsc, features, labels, batch_size):
    """Convert numpy features and labels into a JavaRDD of
    DL4J DataSet type.

    The inputs are cut into consecutive batches of ``batch_size`` rows and
    each batch becomes one DataSet. Trailing rows that do not fill a
    complete batch are dropped.

    :param jsc: JavaSparkContext from pyjnius
    :param features: numpy array with features
    :param labels: numpy array with labels
    :param batch_size: number of rows placed in each DataSet
    :return: JavaRDD of DL4J DataSet objects
    """
    data_sets = java_classes.ArrayList()
    num_batches = len(features) // batch_size
    for batch_index in range(num_batches):
        start = batch_index * batch_size
        stop = start + batch_size
        # Copy each slice so the wrapper gets its own buffer instead of a
        # view into the caller's array.
        xi = ndarray(features[start:stop].copy())
        yi = ndarray(labels[start:stop].copy())
        # Each batch must be appended; otherwise an empty list is
        # parallelized.
        data_sets.add(java_classes.DataSet(xi.array, yi.array))

    return jsc.parallelize(data_sets)
github maxpumperla / elephas / elephas / View on Github external
def dl4j_import(jsc, model_file, keras_model_type):
    """Import a saved Keras model into DL4J through ElephasModelImport.

    :param jsc: Java Spark context passed through to the DL4J importer
    :param model_file: file object exposing ``absolutePath`` for the saved model
    :param keras_model_type: Keras class name, either "Sequential" or "Model"
    :return: the imported DL4J model, or None if the import failed
    :raises Exception: if ``keras_model_type`` is neither "Sequential" nor "Model"
    """
    # Reject unsupported types up front so the error is not masked by the
    # best-effort import handling below.
    if keras_model_type not in ("Sequential", "Model"):
        raise Exception(
            "Keras model not understood, got: {}".format(keras_model_type))

    emi = java_classes.ElephasModelImport
    try:
        if keras_model_type == "Sequential":
            return emi.importElephasSequentialModelAndWeights(
                jsc, model_file.absolutePath)
        return emi.importElephasModelAndWeights(jsc, model_file.absolutePath)
    except Exception:
        # Best effort: report the failure and hand back None rather than
        # propagating the underlying import error.
        print("Couldn't load Keras model into DL4J")
        return None