import numpy as np

# Note: `java_classes`, `ndarray` and `to_numpy` are helpers from the
# surrounding package that wrap DL4J/pyjnius classes and handle
# INDArray <-> numpy conversion; their imports are not shown in this snippet.


def to_java_rdd(jsc, features, labels, batch_size):
    """Convert numpy features and labels into a JavaRDD of
    DL4J DataSet type.

    :param jsc: JavaSparkContext from pyjnius
    :param features: numpy array with features
    :param labels: numpy array with labels
    :param batch_size: number of examples per DataSet in the resulting RDD
    :return: JavaRDD of DL4J DataSet objects
    """
    data_sets = java_classes.ArrayList()
    num_batches = int(len(features) / batch_size)
    for i in range(num_batches):
        # Wrap the next mini-batch as DL4J INDArrays and bundle them in a DataSet.
        xi = ndarray(features[:batch_size].copy())
        yi = ndarray(labels[:batch_size].copy())
        data_set = java_classes.DataSet(xi.array, yi.array)
        data_sets.add(data_set)
        # Drop the examples that were just consumed.
        features = features[batch_size:]
        labels = labels[batch_size:]

    return jsc.parallelize(data_sets)
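

# Hedged usage sketch (not part of the original module): given a
# JavaSparkContext `jsc` obtained through pyjnius and numpy training data,
# build the batched JavaRDD. Note that any trailing examples that do not fill
# a complete batch are dropped by to_java_rdd, since it only iterates over
# int(len(features) / batch_size) full batches.
def _example_build_rdd(jsc, x_train, y_train, batch_size=32):
    assert len(x_train) == len(y_train), "features and labels must align"
    # Each element of the returned JavaRDD is a DL4J DataSet of one mini-batch.
    return to_java_rdd(jsc, x_train, y_train, batch_size)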


def retrieve_keras_weights(java_model):
    """For a previously imported Keras model, after training it with DL4J Spark,
    we want to set the resulting weights back to the original Keras model.

    :param java_model: DL4J model (MultiLayerNetwork or ComputationGraph)
    :return: list of numpy arrays in the correct order for model.set_weights(...)
        of the corresponding Keras model
    """
    weights = []
    layers = java_model.getLayers()
    for layer in layers:
        # Each layer exposes its parameters (e.g. weights and biases) as a Java map.
        params = layer.paramTable()
        keys = params.keySet()
        key_list = java_classes.ArrayList(keys)
        for key in key_list:
            weight = params.get(key)
            # Convert the DL4J INDArray to numpy and drop singleton dimensions.
            np_weight = np.squeeze(to_numpy(weight))
            weights.append(np_weight)
    return weights
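

# Hedged usage sketch (not part of the original module): after distributed
# training with DL4J, copy the learned parameters back into the original Keras
# model. `keras_model` and `trained_java_model` are placeholders for the
# caller's objects; the only API assumed beyond this module is Keras's
# standard model.set_weights(...).
def _example_sync_weights(keras_model, trained_java_model):
    new_weights = retrieve_keras_weights(trained_java_model)
    # Per the docstring above, the arrays come back in the order expected by set_weights.
    keras_model.set_weights(new_weights)
    return keras_model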