Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_tf_keras_autolog_persists_manually_created_run(random_train_data, random_one_hot_labels,
                                                        fit_variant):
    """Autologging must leave an explicitly started run active after ``fit()``.

    ``fit_variant`` is accepted as a fixture but is not used by this body.
    """
    mlflow.tensorflow.autolog()
    with mlflow.start_run() as run:
        keras_model = create_tf_keras_model()
        keras_model.fit(random_train_data, random_one_hot_labels, epochs=10)
        # The run opened above must still be the active one once fit() returns.
        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id
def test_prepare_env_passes(sk_model):
    """`mlflow models prepare-env` succeeds with and without conda, and when repeated.

    Logs ``sk_model`` under a fresh run inside a temporary working directory,
    then invokes the CLI three times against the resulting ``runs:/`` URI:
    with ``--no-conda``, with conda, and with conda again (idempotency).
    """
    if no_conda:
        pytest.skip("This test requires conda.")

    def _prepare_env(uri, *extra_args):
        """Run ``mlflow models prepare-env -m <uri> [extra_args]``; return (returncode, stderr)."""
        proc = subprocess.Popen(
            ["mlflow", "models", "prepare-env", "-m", uri] + list(extra_args),
            stderr=subprocess.PIPE)
        # communicate() drains stderr while waiting; Popen.wait() with stderr=PIPE
        # can deadlock once the child fills the OS pipe buffer.
        _, err = proc.communicate()
        return proc.returncode, err

    with TempDir(chdr=True):
        with mlflow.start_run() as active_run:
            mlflow.sklearn.log_model(sk_model, "model")
            model_uri = "runs:/{run_id}/model".format(run_id=active_run.info.run_id)

        # Test with no conda
        returncode, err = _prepare_env(model_uri, "--no-conda")
        assert returncode == 0, err
        # With conda
        returncode, err = _prepare_env(model_uri)
        assert returncode == 0, err
        # Should be idempotent: a second conda run must also succeed.
        returncode, err = _prepare_env(model_uri)
        assert returncode == 0, err
def test_tf_keras_autolog_ends_auto_created_run(random_train_data, random_one_hot_labels,
                                                fit_variant):
    """With autologging on and no explicitly started run, no run stays active after fit().

    ``fit_variant`` is accepted as a fixture but is not used by this body.
    """
    mlflow.tensorflow.autolog()
    keras_model = create_tf_keras_model()
    keras_model.fit(random_train_data, random_one_hot_labels, epochs=10)
    assert mlflow.active_run() is None
@staticmethod
def load_model(file, **kwargs):
    """Rebuild a ``MyModel`` from the value stored under key ``"x"`` in ``file``.

    Any extra keyword arguments are accepted and ignored.
    """
    # NOTE(review): if ``file`` is an h5py File, ``Dataset.value`` was removed in
    # h5py 3.0 (``dataset[()]`` is the replacement) — confirm the pinned h5py version.
    stored = file.get("x")
    return MyModel(stored.value)
# Keep a handle on the real importer now, before ``importlib.import_module`` is
# replaced by a mock whose side effect is ``_import_module``.
_real_import_module = importlib.import_module


def _import_module(name, **kwargs):
    """Stand-in for ``importlib.import_module`` used as a mock side effect.

    Returns ``FakeKerasModule`` for any module name that starts with
    ``FakeKerasModule.__name__``; every other name is delegated to the real
    importer with the same keyword arguments.
    """
    if name.startswith(FakeKerasModule.__name__):
        return FakeKerasModule
    # Resolving ``importlib.import_module`` at call time would hit the patched
    # mock, which calls back into this function and recurses without end.
    return _real_import_module(name, **kwargs)
# Exercise save/load and log/load with keras_module supplied both as an object and
# by name, while importlib.import_module is routed through ``_import_module``.
with mock.patch("importlib.import_module") as import_module_mock:
    import_module_mock.side_effect = _import_module
    x = MyModel("x123")
    path0 = os.path.join(model_path, "0")
    # Without an explicit keras_module, save_model is expected to raise for this object.
    with pytest.raises(MlflowException):
        mlflow.keras.save_model(x, path0)
    mlflow.keras.save_model(x, path0, keras_module=FakeKerasModule)
    y = mlflow.keras.load_model(path0)
    assert x == y
    path1 = os.path.join(model_path, "1")
    mlflow.keras.save_model(x, path1, keras_module=FakeKerasModule.__name__)
    z = mlflow.keras.load_model(path1)
    assert x == z
    # Tests model log
    with mlflow.start_run() as active_run:
        with pytest.raises(MlflowException):
            mlflow.keras.log_model(x, "model0")
        mlflow.keras.log_model(x, "model0", keras_module=FakeKerasModule)
        a = mlflow.keras.load_model("runs:/{}/model0".format(active_run.info.run_id))
        assert x == a
        mlflow.keras.log_model(x, "model1", keras_module=FakeKerasModule.__name__)
        # Load "model1" back too, so the by-name keras_module path of log_model is verified.
        b = mlflow.keras.load_model("runs:/{}/model1".format(active_run.info.run_id))
        assert x == b
# Run the keras_module save/load and log/load checks with importlib.import_module
# patched so that module imports are routed through ``_import_module``.
with mock.patch("importlib.import_module") as import_module_mock:
    import_module_mock.side_effect = _import_module
    x = MyModel("x123")
    # NOTE(review): presumably model_path is a per-test temporary directory fixture — confirm.
    path0 = os.path.join(model_path, "0")
    # Without an explicit keras_module, save_model is expected to raise for this object.
    with pytest.raises(MlflowException):
        mlflow.keras.save_model(x, path0)
    # keras_module passed as the module object: the model must round-trip via load_model.
    mlflow.keras.save_model(x, path0, keras_module=FakeKerasModule)
    y = mlflow.keras.load_model(path0)
    assert x == y
    path1 = os.path.join(model_path, "1")
    # keras_module passed by its __name__ string instead of the object.
    mlflow.keras.save_model(x, path1, keras_module=FakeKerasModule.__name__)
    z = mlflow.keras.load_model(path1)
    assert x == z
    # Tests model log
    with mlflow.start_run() as active_run:
        # Same contract for log_model: no keras_module -> MlflowException.
        with pytest.raises(MlflowException):
            mlflow.keras.log_model(x, "model0")
        mlflow.keras.log_model(x, "model0", keras_module=FakeKerasModule)
        # Load back through the runs:/<run_id>/<artifact_path> URI of the active run.
        a = mlflow.keras.load_model("runs:/{}/model0".format(active_run.info.run_id))
        assert x == a
        mlflow.keras.log_model(x, "model1", keras_module=FakeKerasModule.__name__)
        b = mlflow.keras.load_model("runs:/{}/model1".format(active_run.info.run_id))
        assert x == b
def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        h2o_iris_model):
    """An H2O model logged without a conda_env records the default H2O conda env.

    The logged model is downloaded back, and the environment file referenced by
    its pyfunc flavor is compared with ``mlflow.h2o.get_default_conda_env()``.
    """
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.h2o.log_model(h2o_model=h2o_iris_model.model, artifact_path=artifact_path)
        run_id = mlflow.active_run().info.run_id
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=run_id, artifact_path=artifact_path)
        model_path = _download_artifact_from_uri(model_uri)

    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    env_file = os.path.join(model_path, flavor_conf[pyfunc.ENV])
    with open(env_file, "r") as stream:
        logged_env = yaml.safe_load(stream)

    assert logged_env == mlflow.h2o.get_default_conda_env()
def test_log_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        saved_tf_iris_model):
    """A TF SavedModel logged with ``conda_env=None`` records the default TF conda env."""
    artifact_path = "model"
    with mlflow.start_run():
        log_kwargs = dict(
            tf_saved_model_dir=saved_tf_iris_model.path,
            tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
            tf_signature_def_key=saved_tf_iris_model.signature_def_key,
            artifact_path=artifact_path,
            conda_env=None,
        )
        mlflow.tensorflow.log_model(**log_kwargs)
        current_run_id = mlflow.active_run().info.run_id
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=current_run_id, artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    env_file = os.path.join(model_path, flavor_conf[pyfunc.ENV])
    with open(env_file, "r") as stream:
        logged_env = yaml.safe_load(stream)

    assert logged_env == mlflow.tensorflow.get_default_conda_env()
def test_log_and_load_model_persists_and_restores_model_successfully(saved_tf_iris_model):
    """Log a TF SavedModel under a run, then load it back through its ``runs:/`` URI."""
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.tensorflow.log_model(
            tf_saved_model_dir=saved_tf_iris_model.path,
            tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
            tf_signature_def_key=saved_tf_iris_model.signature_def_key,
            artifact_path=artifact_path)
        current_run_id = mlflow.active_run().info.run_id
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=current_run_id, artifact_path=artifact_path)
    # NOTE(review): the loaded object is not inspected here, so this body only shows
    # that load_model succeeds on the logged URI — consider asserting on its output.
    infer_fn = mlflow.tensorflow.load_model(model_uri=model_uri)
def test_metric_timestamp(tracking_uri_mock):
    """Every logged metric timestamp lies between the run's start and end times."""
    with mlflow.start_run() as active_run:
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_1", 30)
        # ``run_uuid`` is deprecated in favour of ``run_id``, which the other tests use.
        run_id = active_run.info.run_id

    # The `with` block has exited, so the run fetched below has ended.
    client = mlflow.tracking.MlflowClient()
    history = client.get_metric_history(run_id, "name_1")
    finished_run = client.get_run(run_id)
    start_time = finished_run.info.start_time
    end_time = finished_run.info.end_time

    assert len(history) == 2
    # Check each entry separately so a failure reports the offending metric.
    for metric in history:
        assert start_time <= metric.timestamp <= end_time, metric
def test_start_and_end_run(tracking_uri_mock):
    """Use the start_run() and end_run() APIs without a ``with`` block and verify they work."""
    active_run = start_run()
    try:
        mlflow.log_metric("name_1", 25)
    finally:
        # End the run explicitly, even if logging fails, so no active run leaks
        # into later tests.
        mlflow.end_run()
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["name_1"] == 25