Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_sparkml_model_save_persists_specified_conda_env_in_mlflow_model_directory(
        spark_model_iris, model_path, spark_custom_env):
    """Saving a Spark ML model with a custom conda env file persists that env in the model dir.

    The env file referenced by the pyfunc flavor configuration must exist, must not be the
    user-supplied file path itself, and must parse to the same YAML content.
    """
    def _read_yaml(yaml_path):
        # Parse a YAML file with the safe loader.
        with open(yaml_path) as stream:
            return yaml.safe_load(stream)

    sparkm.save_model(spark_model=spark_model_iris.model,
                      path=model_path,
                      conda_env=spark_custom_env)

    flavor_conf = _get_flavor_configuration(model_path=model_path,
                                            flavor_name=pyfunc.FLAVOR_NAME)
    persisted_env_file = os.path.join(model_path, flavor_conf[pyfunc.ENV])

    assert os.path.exists(persisted_env_file)
    assert persisted_env_file != spark_custom_env

    expected_env = _read_yaml(spark_custom_env)
    persisted_env = _read_yaml(persisted_env_file)
    assert persisted_env == expected_env
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        xgb_model, model_path):
    """Saving an XGBoost model with ``conda_env=None`` records the flavor's default conda env.

    The env file referenced by the pyfunc flavor configuration must parse to exactly
    ``mlflow.xgboost.get_default_conda_env()``.
    """
    mlflow.xgboost.save_model(xgb_model=xgb_model.model, path=model_path, conda_env=None)

    flavor_conf = _get_flavor_configuration(model_path=model_path,
                                            flavor_name=pyfunc.FLAVOR_NAME)
    env_file = os.path.join(model_path, flavor_conf[pyfunc.ENV])
    with open(env_file) as handle:
        recorded_env = yaml.safe_load(handle)

    expected_env = mlflow.xgboost.get_default_conda_env()
    assert recorded_env == expected_env
def test_model_save_accepts_conda_env_as_dict(sklearn_knn_model, model_path):
    """Saving a scikit-learn model with ``conda_env`` passed as a dict persists that dict.

    The dict is the flavor's default env plus an extra ``pytest`` dependency. The env file
    referenced by the pyfunc flavor configuration must exist and parse back to the same dict.
    """
    import copy

    # Deep copy: ``dict(...)`` copies only the top level, so appending to the nested
    # "dependencies" list would also modify the list inside the object returned by
    # get_default_conda_env() (harmful if that object is ever shared or cached).
    conda_env = copy.deepcopy(mlflow.sklearn.get_default_conda_env())
    conda_env["dependencies"].append("pytest")

    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model, path=model_path, conda_env=conda_env)

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)

    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == conda_env
def test_model_save_accepts_conda_env_as_dict(h2o_iris_model, model_path):
    """Saving an H2O model with ``conda_env`` passed as a dict persists that dict.

    The dict is the flavor's default env plus an extra ``pytest`` dependency. The env file
    referenced by the pyfunc flavor configuration must exist and parse back to the same dict.

    NOTE(review): a scikit-learn test with this exact name also appears in this listing. If both
    live in one module, this later definition shadows the earlier one and pytest collects only
    this one. Presumably they come from separate test files; confirm.
    """
    import copy

    # Deep copy: ``dict(...)`` copies only the top level, so appending to the nested
    # "dependencies" list would also modify the list inside the object returned by
    # get_default_conda_env() (harmful if that object is ever shared or cached).
    conda_env = copy.deepcopy(mlflow.h2o.get_default_conda_env())
    conda_env["dependencies"].append("pytest")

    mlflow.h2o.save_model(h2o_model=h2o_iris_model.model, path=model_path, conda_env=conda_env)

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)

    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == conda_env
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        h2o_iris_model, model_path):
    """Saving an H2O model with ``conda_env=None`` records the flavor's default conda env.

    The env file referenced by the pyfunc flavor configuration must parse to exactly
    ``mlflow.h2o.get_default_conda_env()``.

    NOTE(review): other tests with this exact name appear in this listing. If they share one
    module, only the last definition is collected by pytest. Confirm they belong to separate
    test files.
    """
    mlflow.h2o.save_model(h2o_model=h2o_iris_model.model, path=model_path, conda_env=None)

    env_relpath = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)[pyfunc.ENV]
    with open(os.path.join(model_path, env_relpath)) as env_stream:
        recorded_env = yaml.safe_load(env_stream)

    assert recorded_env == mlflow.h2o.get_default_conda_env()
def test_save_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        saved_tf_iris_model, model_path):
    """Saving a TensorFlow SavedModel with ``conda_env=None`` records the default conda env.

    The env file referenced by the pyfunc flavor configuration must parse to exactly
    ``mlflow.tensorflow.get_default_conda_env()``.
    """
    save_kwargs = dict(
        tf_saved_model_dir=saved_tf_iris_model.path,
        tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
        tf_signature_def_key=saved_tf_iris_model.signature_def_key,
        path=model_path,
        conda_env=None,
    )
    mlflow.tensorflow.save_model(**save_kwargs)

    flavor_conf = _get_flavor_configuration(model_path=model_path,
                                            flavor_name=pyfunc.FLAVOR_NAME)
    env_file = os.path.join(model_path, flavor_conf[pyfunc.ENV])
    with open(env_file) as handle:
        recorded_env = yaml.safe_load(handle)

    expected_env = mlflow.tensorflow.get_default_conda_env()
    assert recorded_env == expected_env
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        sklearn_knn_model, sklearn_custom_env):
    """Logging a scikit-learn model with a custom conda env file persists that env.

    The model is logged under a run, downloaded back via its ``runs:/`` URI, and the env file
    referenced by its pyfunc flavor must exist, differ from the user-supplied path, and
    parse to the same YAML content.
    """
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.sklearn.log_model(sk_model=sklearn_knn_model.model,
                                 artifact_path=artifact_path,
                                 conda_env=sklearn_custom_env)
        # The run id must be read while the run is still active.
        current_run_id = mlflow.active_run().info.run_id
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=current_run_id,
            artifact_path=artifact_path)

    local_model_dir = _download_artifact_from_uri(artifact_uri=model_uri)
    flavor_conf = _get_flavor_configuration(model_path=local_model_dir,
                                            flavor_name=pyfunc.FLAVOR_NAME)
    persisted_env_file = os.path.join(local_model_dir, flavor_conf[pyfunc.ENV])

    assert os.path.exists(persisted_env_file)
    assert persisted_env_file != sklearn_custom_env

    with open(sklearn_custom_env) as source_stream:
        expected_env = yaml.safe_load(source_stream)
    with open(persisted_env_file) as persisted_stream:
        persisted_env = yaml.safe_load(persisted_stream)
    assert persisted_env == expected_env
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        sklearn_knn_model, model_path):
    """Saving a pickled scikit-learn model with ``conda_env=None`` records the default env.

    The env file referenced by the pyfunc flavor configuration must parse to exactly
    ``mlflow.sklearn.get_default_conda_env()``.

    NOTE(review): other tests with this exact name appear in this listing. If they share one
    module, only the last definition is collected by pytest. Confirm they belong to separate
    test files.
    """
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model,
        path=model_path,
        conda_env=None,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE,
    )

    env_relpath = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)[pyfunc.ENV]
    with open(os.path.join(model_path, env_relpath)) as env_stream:
        recorded_env = yaml.safe_load(env_stream)

    assert recorded_env == mlflow.sklearn.get_default_conda_env()
def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
        xgb_model, model_path, xgb_custom_env):
    """Saving an XGBoost model with a custom conda env file persists that env in the model dir.

    The env file referenced by the pyfunc flavor configuration must exist, must not be the
    user-supplied file path itself, and must parse to the same YAML content.
    """
    def _read_yaml(yaml_path):
        # Parse a YAML file with the safe loader.
        with open(yaml_path) as stream:
            return yaml.safe_load(stream)

    mlflow.xgboost.save_model(xgb_model=xgb_model.model,
                              path=model_path,
                              conda_env=xgb_custom_env)

    env_relpath = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)[pyfunc.ENV]
    persisted_env_file = os.path.join(model_path, env_relpath)

    assert os.path.exists(persisted_env_file)
    assert persisted_env_file != xgb_custom_env

    expected_env = _read_yaml(xgb_custom_env)
    persisted_env = _read_yaml(persisted_env_file)
    assert persisted_env == expected_env
def _load_model_env(path):
    """
    Look up the conda environment entry of a model saved in Python Function format.

    :param path: Path to the saved model directory.
    :return: The ``ENV`` value from the model's pyfunc flavor configuration, i.e. a
             model-relative path to a Conda environment file, or ``None`` when the
             configuration has no ``ENV`` entry (presumably because no environment was
             specified at model save time).
    """
    flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
    return flavor_conf.get(ENV)