How to use the elephas.java.java_classes.File class in elephas

To help you get started, we’ve selected a few elephas examples that show popular ways java_classes.File is used in public projects.


Example from maxpumperla/elephas: examples/basic_import.py (view on GitHub)
from elephas.java import java_classes, adapter
from keras.models import Sequential
from keras.layers import Dense


# Build and compile a small Keras model.
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=100))
model.add(Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

# Save the model to HDF5 so it can be read from the JVM side.
model.save('test.h5')


# java_classes.File exposes java.io.File through pyjnius; KerasModelImport is
# DL4J's Keras import utility.
kmi = java_classes.KerasModelImport
file = java_classes.File("test.h5")

# Import the saved model into DL4J, then copy its weights back into the Keras model.
java_model = kmi.importKerasSequentialModelAndWeights(file.absolutePath)

weights = adapter.retrieve_keras_weights(java_model)
model.set_weights(weights)
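
The example above shows the typical pattern: java_classes.File is the pyjnius handle to java.io.File, so a path string goes in and the standard java.io.File API comes back out. Below is a minimal sketch of using that handle on its own, for example to check that the saved HDF5 file exists and to clean it up afterwards. This is an illustration, not part of the elephas example: exists, getAbsolutePath and delete are plain java.io.File methods, assumed to be reachable through the pyjnius wrapper in the same way absolutePath is above.

from elephas.java import java_classes

# Minimal sketch: treat java_classes.File as a pyjnius handle to java.io.File
# (as the example above suggests) and use the standard File API on it.
f = java_classes.File("test.h5")

if f.exists():                     # java.io.File#exists
    print(f.getAbsolutePath())     # same value the example reads as f.absolutePath
    f.delete()                     # remove the temporary HDF5 file
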
Example from maxpumperla/elephas: elephas/dl4j.py (view on GitHub)
    :param java_spark_context: JavaSparkContext, initialized through pyjnius
    :param model: compiled Keras model
    :param num_workers: number of Spark workers/executors
    :param batch_size: batch size used for model training
    :param averaging_frequency: int, after how many batches of training averaging takes place
    :param num_batches_prefetch: int, how many batches to pre-fetch; deactivated if 0
    :param collect_stats: boolean, whether statistics get collected during training
    :param save_file: where to store the elephas model temporarily
    """
    SparkModel.__init__(self, model=model, batch_size=batch_size, mode='synchronous',
                        averaging_frequency=averaging_frequency, num_batches_prefetch=num_batches_prefetch,
                        num_workers=num_workers, collect_stats=collect_stats, *args, **kwargs)

    # Save the compiled Keras model to disk, wrap the path in a java.io.File via
    # java_classes.File, and import it into DL4J as a Spark model.
    self.save(save_file)
    model_file = java_classes.File(save_file)
    keras_model_type = model.__class__.__name__
    self.java_spark_model = dl4j_import(
        java_spark_context, model_file, keras_model_type)
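
For orientation, here is a hedged sketch of how the constructor excerpted above might be called. The class name ParameterAveragingModel and the way the JavaSparkContext is obtained are assumptions not shown in the snippet; what the excerpt does establish is that the compiled Keras model is saved to save_file and handed to DL4J through java_classes.File.

from keras.models import Sequential
from keras.layers import Dense


def build_spark_model(java_spark_context, model):
    # Hypothetical helper: java_spark_context must already be a JavaSparkContext
    # initialized through pyjnius, as the docstring above requires.
    from elephas.dl4j import ParameterAveragingModel  # assumed class owning the __init__ shown above
    return ParameterAveragingModel(java_spark_context, model,
                                   num_workers=4, batch_size=32,
                                   save_file='temp_model.h5')


model = Sequential()
model.add(Dense(units=10, activation='softmax', input_dim=100))
model.compile(loss='categorical_crossentropy', optimizer='sgd')
# spark_model = build_spark_model(java_spark_context, model)
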