Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_tf_keras_autolog_ends_auto_created_run(random_train_data, random_one_hot_labels,
                                                fit_variant):
    """Enable TensorFlow autologging, fit a Keras model, and check no run stays active.

    :param random_train_data: Training inputs passed to ``model.fit``.
    :param random_one_hot_labels: Training labels passed to ``model.fit``.
    :param fit_variant: Accepted by the signature but not used in this body.
    """
    mlflow.tensorflow.autolog()

    keras_model = create_tf_keras_model()
    keras_model.fit(random_train_data, random_one_hot_labels, epochs=10)

    # The test never starts a run itself, so nothing may be left active after fitting.
    leftover_run = mlflow.active_run()
    assert leftover_run is None
def test_delete_tag():
    """
    Confirm that fluent API delete tags actually works.

    Sets tag ``a``, verifies it is visible through ``MlflowClient``, deletes it and
    verifies it is gone. Deleting the already-deleted tag ``a`` or the tag ``b`` (which
    this test never sets as a key) must raise ``MlflowException``.
    """
    try:
        mlflow.set_tag('a', 'b')
        run = MlflowClient().get_run(mlflow.active_run().info.run_id)
        print(run.info.run_id)
        assert 'a' in run.data.tags

        mlflow.delete_tag('a')
        # Fetch the run again so the assertion sees the state after deletion.
        run = MlflowClient().get_run(mlflow.active_run().info.run_id)
        assert 'a' not in run.data.tags

        with pytest.raises(MlflowException):
            mlflow.delete_tag('a')
        with pytest.raises(MlflowException):
            mlflow.delete_tag('b')
    finally:
        # End the run even when an assertion above fails, so a failure here does not
        # leave an active run behind for whatever runs next.
        mlflow.end_run()
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        sklearn_knn_model, sklearn_custom_env):
    """Log a scikit-learn model with a custom conda env and check the env is persisted.

    The saved environment file must exist inside the downloaded model directory, be a
    different path than the one supplied, and parse to the same YAML content.

    :param sklearn_knn_model: Fixture whose ``model`` attribute is the model to log.
    :param sklearn_custom_env: Path to the conda environment file passed to ``log_model``.
    """
    def parse_yaml(path):
        # Read and parse one YAML file.
        with open(path, "r") as handle:
            return yaml.safe_load(handle)

    artifact_path = "model"
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sk_model=sklearn_knn_model.model,
            artifact_path=artifact_path,
            conda_env=sklearn_custom_env)
        # Resolve the run id while the run opened above is still the active one.
        run_id = mlflow.active_run().info.run_id
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=run_id, artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])

    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != sklearn_custom_env

    expected_env = parse_yaml(sklearn_custom_env)
    persisted_env = parse_yaml(saved_conda_env_path)
    assert persisted_env == expected_env
def test_log_model_no_registered_model_name(sklearn_knn_model, main_scoped_model_class):
    """Log models without a registered model name and check nothing gets registered.

    A scikit-learn model is logged inside an explicit run; a pyfunc model wrapping it is
    then logged with no run active. Neither call passes a registered model name, so the
    patched ``mlflow.register_model`` must never be called.

    :param sklearn_knn_model: Model object passed as ``sk_model`` to ``log_model``.
    :param main_scoped_model_class: Callable building the pyfunc ``python_model`` from a
                                    predict function.
    """
    register_model_patch = mock.patch("mlflow.register_model")
    with register_model_patch:
        try:
            sklearn_artifact_path = "sk_model_no_run"
            with mlflow.start_run():
                mlflow.sklearn.log_model(sk_model=sklearn_knn_model,
                                         artifact_path=sklearn_artifact_path)
                sklearn_model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id,
                    artifact_path=sklearn_artifact_path)

            def test_predict(sk_model, model_input):
                return sk_model.predict(model_input) * 2

            pyfunc_artifact_path = "pyfunc_model"
            # The explicit run above has been closed by its ``with`` block.
            assert mlflow.active_run() is None
            mlflow.pyfunc.log_model(artifact_path=pyfunc_artifact_path,
                                    artifacts={"sk_model": sklearn_model_uri},
                                    python_model=main_scoped_model_class(test_predict))
            mlflow.register_model.assert_not_called()
        finally:
            # Close any run left active (presumably started implicitly by the pyfunc
            # ``log_model`` call) even when an assertion above fails.
            mlflow.end_run()
# Round-trip a Spark model through log_model/load_model for every combination of
# "run started explicitly or not" and "default or custom DFS temp directory".
# NOTE(review): `tmpdir`, `cnt`, `old_tracking_uri`, `sparkm` and `spark_model_estimator`
# are defined outside this fragment.
for should_start_run in (False, True):
    for dfs_tmp_dir in (None, os.path.join(str(tmpdir), "test")):
        print("should_start_run =", should_start_run, "dfs_tmp_dir =", dfs_tmp_dir)
        try:
            tracking_dir = os.path.abspath(str(tmpdir.join("mlruns")))
            mlflow.set_tracking_uri("file://%s" % tracking_dir)
            if should_start_run:
                mlflow.start_run()

            # Use a distinct artifact path for every iteration.
            artifact_path = "model%d" % cnt
            cnt += 1
            sparkm.log_model(
                artifact_path=artifact_path,
                spark_model=spark_model_estimator.model,
                dfs_tmpdir=dfs_tmp_dir)
            active_run_id = mlflow.active_run().info.run_id
            model_uri = "runs:/{run_id}/{artifact_path}".format(
                run_id=active_run_id, artifact_path=artifact_path)

            # test reloaded model
            reloaded_model = sparkm.load_model(model_uri=model_uri, dfs_tmpdir=dfs_tmp_dir)
            preds_df = reloaded_model.transform(spark_model_estimator.spark_df)
            prediction_rows = preds_df.select("prediction").collect()
            preds = [row.prediction for row in prediction_rows]
            assert spark_model_estimator.predictions == preds
        finally:
            # Restore global tracking state and remove the directories used by this
            # iteration, whether or not the checks above passed.
            mlflow.end_run()
            mlflow.set_tracking_uri(old_tracking_uri)
            dfs_dir_to_remove = dfs_tmp_dir or sparkm.DFS_TMP
            shutil.rmtree(dfs_dir_to_remove)
            shutil.rmtree(tracking_dir)
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(onnx_model,
                                                                          onnx_custom_env):
    """Log an ONNX model with a custom conda env and check the env is persisted.

    The saved environment file must exist inside the downloaded model directory, be a
    different path than the one supplied, and parse to the same YAML content.

    :param onnx_model: Model passed as ``onnx_model`` to ``mlflow.onnx.log_model``.
    :param onnx_custom_env: Path to the conda environment file passed to ``log_model``.
    """
    import mlflow.onnx

    artifact_path = "model"
    with mlflow.start_run():
        mlflow.onnx.log_model(
            onnx_model=onnx_model,
            artifact_path=artifact_path,
            conda_env=onnx_custom_env)
        # Build the URI and download while the run opened above is still active.
        logged_model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path)
        model_path = _download_artifact_from_uri(logged_model_uri)

    pyfunc_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    env_file_name = pyfunc_conf[pyfunc.ENV]
    saved_conda_env_path = os.path.join(model_path, env_file_name)

    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != onnx_custom_env

    with open(onnx_custom_env, "r") as source_file:
        expected_env = yaml.safe_load(source_file)
    with open(saved_conda_env_path, "r") as saved_file:
        persisted_env = yaml.safe_load(saved_file)

    assert persisted_env == expected_env
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        sequential_model, pytorch_custom_env):
    """Log a PyTorch model with a custom conda env and check the env is persisted.

    The saved environment file must exist inside the downloaded model directory, be a
    different path than the one supplied, and have exactly the same text content.

    :param sequential_model: Model passed as ``pytorch_model`` to ``log_model``.
    :param pytorch_custom_env: Path to the conda environment file passed to ``log_model``.
    """
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.pytorch.log_model(
            pytorch_model=sequential_model,
            artifact_path=artifact_path,
            conda_env=pytorch_custom_env)
        # Build the URI and download while the run opened above is still active.
        logged_model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path)
        model_path = _download_artifact_from_uri(logged_model_uri)

    pyfunc_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])

    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != pytorch_custom_env

    # Compare raw file text (not parsed YAML): supplied file first, then the saved copy.
    contents = []
    for env_path in (pytorch_custom_env, saved_conda_env_path):
        with open(env_path, "r") as env_file:
            contents.append(env_file.read())
    expected_text, persisted_text = contents

    assert persisted_text == expected_text
def __init__(self, tracking_uri=None):
    """Bind this object to an MLflow run, starting one when none is active.

    :param tracking_uri: Optional tracking URI; when not ``None`` it is passed to
                         ``mlflow.set_tracking_uri`` before the run is looked up.
    :raises RuntimeError: If the ``mlflow`` package cannot be imported.
    """
    missing_dependency_message = ("This contrib module requires mlflow to be installed. "
                                  "Please install it with command: \n pip install mlflow")
    try:
        import mlflow
    except ImportError:
        raise RuntimeError(missing_dependency_message)

    if tracking_uri is not None:
        mlflow.set_tracking_uri(tracking_uri)

    # Reuse the run that is already active; otherwise open a new one.
    current_run = mlflow.active_run()
    self.active_run = current_run
    if current_run is None:
        self.active_run = mlflow.start_run()
# Score the fitted estimator on both splits (presumably R^2, going by the variable and
# metric names -- confirm against the estimator's `score` implementation).
# NOTE(review): `xgbr`, `test_rmse`, the feature/label sets and `sklearn` are defined
# outside this fragment.
r2_score_training = xgbr.score(trainingFeatures, trainingLabels)
r2_score_test = xgbr.score(testFeatures, testLabels)

print("Test RMSE:", test_rmse)
print("Training set score:", r2_score_training)
print("Test set score:", r2_score_test)

# Logging the RMSE and r2 scores.
mlflow.log_metric("Test RMSE", test_rmse)
mlflow.log_metric("Train R2", r2_score_training)
mlflow.log_metric("Test R2", r2_score_test)

# Saving the model as an artifact.
sklearn.log_model(xgbr, "model")

# Read `info.run_id` rather than the deprecated `info.run_uuid`, matching how the run
# id is obtained elsewhere in this file.
run_id = mlflow.active_run().info.run_id
print("Run with id %s finished" % run_id)