import os
import zipfile
import numpy as np
from bigdl.dataset import base  # provides maybe_download (import path assumed)

SOURCE_URL = 'http://files.grouplens.org/datasets/movielens/'  # assumed constant

def read_data_sets(data_dir):
    """
    Parse the MovieLens 1M data, downloading and extracting it first if
    data_dir is empty.
    :param data_dir: the directory storing the MovieLens data
    :return: a 2D numpy array with user index, item index, rating and
             timestamp in each row
    """
    WHOLE_DATA = 'ml-1m.zip'
    local_file = base.maybe_download(WHOLE_DATA, data_dir, SOURCE_URL + WHOLE_DATA)
    extracted_to = os.path.join(data_dir, "ml-1m")
    if not os.path.exists(extracted_to):
        print("Extracting %s to %s" % (local_file, data_dir))
        with zipfile.ZipFile(local_file, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    rating_file = os.path.join(extracted_to, "ratings.dat")
    with open(rating_file, "r") as f:
        rating_list = [line.strip().split("::") for line in f]
    movielens_data = np.array(rating_list).astype(int)
    return movielens_data
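# A minimal usage sketch; the download directory is an assumption:
if __name__ == "__main__":
    movielens_data = read_data_sets("/tmp/movielens")
    user_ids, item_ids = movielens_data[:, 0], movielens_data[:, 1]
    print("loaded %d ratings" % len(movielens_data))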
# Imports reconstructed for runnability; the module paths assume the
# Analytics Zoo tfpark MNIST example layout:
import tensorflow as tf
from zoo import init_nncontext
from zoo.tfpark import KerasModel
from bigdl.dataset import mnist

def main(max_epoch):
    _ = init_nncontext()
    # load MNIST and standardize both splits with the training-set statistics
    (training_images_data, training_labels_data) = mnist.read_data_sets("/tmp/mnist", "train")
    (testing_images_data, testing_labels_data) = mnist.read_data_sets("/tmp/mnist", "test")
    training_images_data = (training_images_data - mnist.TRAIN_MEAN) / mnist.TRAIN_STD
    testing_images_data = (testing_images_data - mnist.TRAIN_MEAN) / mnist.TRAIN_STD
    model = tf.keras.Sequential(
        [tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
         tf.keras.layers.Dense(64, activation='relu'),
         tf.keras.layers.Dense(64, activation='relu'),
         tf.keras.layers.Dense(10, activation='softmax'),
         ])
    model.compile(optimizer='rmsprop',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    keras_model = KerasModel(model)
    # the fit call was truncated in the original snippet; the arguments after
    # the training images are assumptions that complete it
    keras_model.fit(training_images_data,
                    training_labels_data,
                    validation_data=(testing_images_data, testing_labels_data),
                    epochs=max_epoch,
                    batch_size=320,
                    distributed=True)
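    # Hedged follow-up sketch: evaluate the distributed model on the test set
    # (the evaluate signature follows tfpark's KerasModel; an assumption):
    result = keras_model.evaluate(testing_images_data, testing_labels_data)
    print(result)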
# Fragment from inside a TFOptimizer-style setup method: pick the training /
# validation RDDs, then build the distributed BigDL Optimizer.
batch_size = self.dataset.batch_size
sample_rdd = self.dataset.get_training_data()
if val_outputs is not None and val_labels is not None:
    val_rdd = self.dataset.get_validation_data()
    if val_rdd is not None:
        val_method = [TFValidationMethod(m, len(val_outputs), len(val_labels))
                      for m in to_list(val_method)]
        training_rdd = sample_rdd
    elif val_split != 0.0:
        training_rdd, val_rdd = sample_rdd.randomSplit([1 - val_split, val_split])
        val_method = [TFValidationMethod(m, len(val_outputs), len(val_labels))
                      for m in to_list(val_method)]
    else:
        raise ValueError("Validation data is not specified. Please set "
                         "val rdd in TFDataset, or set val_split larger than zero")
    self.optimizer = Optimizer.create(self.training_helper_layer,
                                      training_rdd,
                                      IdentityCriterion(),
                                      batch_size=batch_size,
                                      optim_method=self.optim_method)
    self.optimizer.set_validation(self.dataset.batch_size,
                                  val_rdd,
                                  EveryEpoch(),
                                  val_method)
else:
    training_rdd = sample_rdd
    # this Optimizer.create call was truncated in the original snippet; the
    # remaining arguments below mirror the validation branch above
    self.optimizer = Optimizer.create(self.training_helper_layer,
                                      training_rdd,
                                      IdentityCriterion(),
                                      batch_size=batch_size,
                                      optim_method=self.optim_method)
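# Hedged sketch of how training is then launched from this setup (the
# end-trigger calls follow BigDL's Optimizer API; treat them as assumptions):
# self.optimizer.set_end_when(MaxEpoch(num_epochs))
# self.optimizer.optimize()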
# Train a KNRM question-answer ranker, evaluating NDCG and MAP after each epoch.
train_set = TextSet.from_relation_pairs(train_relations, q_set, a_set)
validate_relations = Relations.read(options.data_path + "/relation_valid.csv",
                                    sc, int(options.partition_num))
validate_set = TextSet.from_relation_lists(validate_relations, q_set, a_set)
if options.model:
    knrm = KNRM.load_model(options.model)
else:
    word_index = a_set.get_word_index()
    knrm = KNRM(int(options.question_length), int(options.answer_length),
                options.embedding_file, word_index)
model = Sequential().add(
    TimeDistributed(
        knrm,
        input_shape=(2, int(options.question_length) + int(options.answer_length))))
model.compile(optimizer=SGD(learningrate=float(options.learning_rate)),
              loss="rank_hinge")
for i in range(0, int(options.nb_epoch)):
    model.fit(train_set, batch_size=int(options.batch_size), nb_epoch=1)
    knrm.evaluate_ndcg(validate_set, 3)
    knrm.evaluate_ndcg(validate_set, 5)
    knrm.evaluate_map(validate_set)
if options.output_path:
    knrm.save_model(options.output_path + "/knrm.model")
    a_set.save_word_index(options.output_path + "/word_index.txt")
    print("Trained model and word dictionary saved")
sc.stop()
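# Hedged reload sketch: the saved artifacts can be restored in a later
# session (load_model mirrors the call used earlier in this snippet):
# knrm = KNRM.load_model(options.output_path + "/knrm.model")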
# Imports and parser construction reconstructed for runnability; the zoo and
# bigdl module paths are assumptions based on Analytics Zoo's Keras-style API:
import sys
import numpy as np
from optparse import OptionParser
from bigdl.optim.optimizer import SGD
from zoo.common.nncontext import init_nncontext, init_spark_conf
from zoo.pipeline.api.autograd import Lambda
from zoo.pipeline.api.keras.layers import Input, Dense
from zoo.pipeline.api.keras.models import Model

parser = OptionParser()
parser.add_option("--nb_epoch", dest="nb_epoch", default="500")
(options, args) = parser.parse_args(sys.argv)
sc = init_nncontext(init_spark_conf().setMaster("local[4]"))
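# The graph below references add_one_func and mean_absolute_error, which the
# original excerpt never defines. Minimal sketches follow, built on Analytics
# Zoo's autograd ops (module path and op names are assumptions):
from zoo.pipeline.api import autograd as A

def add_one_func(x):
    # custom Lambda body: shift every element of the input by one
    return x + 1.0

def mean_absolute_error(y_true, y_pred):
    # custom loss expressed with autograd primitives
    return A.mean(A.abs(y_true - y_pred), axis=1)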
data_len = 1000
X_ = np.random.uniform(0, 1, (data_len, 2))
Y_ = ((2 * X_).sum(1) + 0.4).reshape([data_len, 1])
a = Input(shape=(2,))
b = Dense(1)(a)
c = Lambda(function=add_one_func)(b)
model = Model(input=a, output=c)
model.compile(optimizer=SGD(learningrate=1e-2),
              loss=mean_absolute_error)
model.set_tensorboard('./log', 'customized layer and loss')
model.fit(x=X_,
          y=Y_,
          batch_size=32,
          nb_epoch=int(options.nb_epoch),
          distributed=False)
model.save_graph_topology('./log')
w = model.get_weights()
print(w)
pred = model.predict_local(X_)
"""
Convert tensorflow model to bigdl model
:param input_ops: operation list used for input, should be placeholders
:param output_ops: operations list used for output
:return: bigdl model
"""
input_names = map(lambda x: x.name.split(":")[0], input_ops)
output_names = map(lambda x: x.name.split(":")[0], output_ops)
temp = tempfile.mkdtemp()
dump_model(path=temp)
model_path = temp + '/model.pb'
bin_path = temp + '/model.bin'
model = Model.load_tensorflow(model_path, input_names, output_names,
byte_order, bin_path, bigdl_type)
try:
shutil.rmtree(temp)
except OSError as e:
if e.errno != errno.ENOENT:
raise
return model
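# Hedged usage sketch (graph construction is elided; the op names and byte
# order value are placeholders):
# inputs = [tf.placeholder(tf.float32, [None, 28 * 28])]
# bigdl_model = convert(inputs, [logits_op], "little_endian", "float")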
def set_seed(self, seed=123):
    """
    Set the random seed used to initialize the weights of this model.
    :param seed: random seed
    :return: the model itself
    """
    callBigDlFunc(self.bigdl_type, "setModelSeed", seed)
    return self
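# Because set_seed returns self, it chains at construction time; a minimal
# sketch (the constructor arguments are placeholders):
# model = Model(inputs, outputs).set_seed(42)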
def __init__(self, label_map, clses, probs, bigdl_type="float"):
    # construct the backing JVM object through the BigDL py4j bridge
    self.value = callBigDlFunc(
        bigdl_type, JavaValue.jvm_class_constructor(self), label_map, clses, probs)
def unfreeze(self, names=None):
    """
    "Unfreeze" the module, i.e. make its parameters (weight/bias, if any)
    trainable so they are updated during training.
    If names is a non-empty list, only layers matching the given names are
    unfrozen.
    :param names: list of module names to unfreeze; default is None
    :return: the current graph model
    """
    callBigDlFunc(self.bigdl_type, "unFreeze", self.value, names)
    return self  # return the graph itself, matching the documented contract
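# Hedged fine-tuning sketch (the companion freeze call and the layer names
# are assumptions):
# model.freeze(["conv_1"])      # stop updating the matching layers
# model.unfreeze(["conv_1"])    # make them trainable again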
def get_image(self, float_key="floats", to_chw=True):
    """
    Get the image RDD from a distributed ImageFrame.
    :param float_key: the key that stores the float representation of each image
    :param to_chw: whether to transpose the images to CHW layout
    :return: an RDD of numpy ndarrays
    """
    tensor_rdd = callBigDlFunc(self.bigdl_type,
                               "distributedImageFrameToImageTensorRdd",
                               self.value, float_key, to_chw)
    return tensor_rdd.map(lambda tensor: tensor.to_ndarray())
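# Hedged usage sketch: read a distributed ImageFrame and collect the decoded
# images locally (the read API and path are assumptions):
# image_frame = ImageFrame.read("/tmp/images", sc)
# images = image_frame.get_image().collect()   # list of numpy ndarrays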