Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
from elephas.java import java_classes, adapter
from keras.models import Sequential
from keras.layers import Dense
# Round-trip example: build a small Keras model, save it to HDF5, import the
# saved file into DL4J through elephas' Java bridge, then load the weights
# read back from the DL4J model into the original Keras model.
#
# NOTE(review): 'test.h5' is a relative path. Keras writes it relative to the
# Python process's working directory, and the java File below resolves it
# relative to the JVM's working directory. These are presumably the same when
# the JVM is embedded via pyjnius — confirm.
h5_path = 'test.h5'

# Two Dense layers: 100 input features -> 64 ReLU units -> 10-way softmax.
model = Sequential()
for layer in (
    Dense(units=64, activation='relu', input_dim=100),
    Dense(units=10, activation='softmax'),
):
    model.add(layer)
model.compile(
    optimizer='sgd',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.save(h5_path)

# Hand the saved file to DL4J's KerasModelImport (reached through
# elephas.java.java_classes).
kmi = java_classes.KerasModelImport
file = java_classes.File(h5_path)
# NOTE(review): `absolutePath` relies on the Java bridge exposing
# File.getAbsolutePath() as an attribute; if pyjnius does not, this needs
# `file.getAbsolutePath()` — confirm.
java_model = kmi.importKerasSequentialModelAndWeights(file.absolutePath)

# Read the weights out of the imported DL4J model via elephas' adapter
# (presumably in the same list layout Keras' get_weights() uses — confirm)
# and load them into the Keras model.
weights = adapter.retrieve_keras_weights(java_model)
model.set_weights(weights)
:param java_spark_context JavaSparkContext, initialized through pyjnius
:param model: compiled Keras model
:param num_workers: number of Spark workers/executors.
:param batch_size: batch size used for model training
:param averaging_frequency: int, after how many batches of training averaging takes place
:param num_batches_prefetch: int, how many batches to pre-fetch, deactivated if 0.
:param collect_stats: boolean, if statistics get collected during training
:param save_file: where to store elephas model temporarily.
"""
SparkModel.__init__(self, model=model, batch_size=batch_size, mode='synchronous',
averaging_frequency=averaging_frequency, num_batches_prefetch=num_batches_prefetch,
num_workers=num_workers, collect_stats=collect_stats, *args, **kwargs)
self.save(save_file)
model_file = java_classes.File(save_file)
keras_model_type = model.__class__.__name__
self.java_spark_model = dl4j_import(
java_spark_context, model_file, keras_model_type)