Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _conda_env():
    """
    Build a Conda environment spec whose pip section installs this MLflow checkout.

    MLflow is not installed through ``install_mlflow``. Instead it is installed
    in editable mode (``-e``) from the parent directory of the imported ``mlflow``
    package. ``cloudpickle`` and ``scikit-learn`` are pinned to the versions
    importable in the current interpreter.

    :return: Whatever ``_mlflow_conda_env`` builds for these dependencies.
    """
    # NB: We need mlflow as a dependency in the environment.
    mlflow_source_root = os.path.dirname(mlflow.__path__[0])
    pinned_pip_packages = [
        "-e " + mlflow_source_root,
        "cloudpickle=={}".format(cloudpickle.__version__),
        "scikit-learn=={}".format(sklearn.__version__),
    ]
    return _mlflow_conda_env(
        additional_conda_deps=None,
        install_mlflow=False,
        additional_pip_deps=pinned_pip_packages,
        additional_conda_channels=None,
    )
def driver():
    """
    Train ``Net`` for ``args.epochs`` epochs inside an MLflow run.

    When ``args.save_model`` is set, it also builds a Conda environment for
    deploying the model.

    NOTE(review): this excerpt appears truncated. ``run`` and ``model_env`` are
    created below but not used in the visible lines; presumably the full function
    goes on to log the model with them. ``Net``, ``device``, ``args``,
    ``train_loader``, ``test_loader``, ``train``, ``test`` and ``optim`` are
    names defined outside this excerpt.

    NOTE(review): indentation was reconstructed from a flattened paste. The
    following nesting is assumed and should be confirmed against the original:
    ``test()`` runs once per epoch inside the loop, and the save block sits inside
    the ``mlflow.start_run()`` context.
    """
    # Installs a global "ignore" filter: every warning is suppressed from here on.
    warnings.filterwarnings("ignore")
    # Dependencies for deploying the model
    # Pip requirements for deployment. The torch entry is a direct wheel URL for a
    # CPU-only torch 1.1.0 build whose filename tags are cp36 / linux_x86_64.
    pytorch_index = "https://download.pytorch.org/whl/"
    pytorch_version = "cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl"
    deps = [
        "cloudpickle=={}".format(cloudpickle.__version__),
        pytorch_index + pytorch_version,
        "torchvision=={}".format(torchvision.__version__),
        # Pinned to the literal "6.0.0", not to the locally installed Pillow version.
        "Pillow=={}".format("6.0.0")
    ]
    with mlflow.start_run() as run:
        model = Net().to(device)
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum)
        # Epochs are numbered 1..args.epochs inclusive.
        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
            test(args, model, device, test_loader)
        # Log model to run history using MLflow
        if args.save_model:
            model_env = _mlflow_conda_env(additional_pip_deps=deps)
model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME)
# NOTE(review): excerpt fragment. The line above is the tail of a call whose opening
# (presumably the assignment to `pyfunc_config`) is not visible. `model_path` is
# presumably a parameter of an enclosing loader function whose header was lost, so
# confirm against the full file.
#
# The code below reloads a pickled Python model from `model_path` and gives it its
# artifacts:
# - It warns, but keeps loading, when the CloudPickle version recorded at save time
#   is missing or differs from the running one.
# - It unpickles the model object.
# - It resolves artifact paths relative to `model_path`.
# - It calls the model's `load_context`.
python_model_cloudpickle_version = pyfunc_config.get(CONFIG_KEY_CLOUDPICKLE_VERSION, None)
if python_model_cloudpickle_version is None:
    mlflow.pyfunc._logger.warning(
        "The version of CloudPickle used to save the model could not be found in the MLmodel"
        " configuration")
elif python_model_cloudpickle_version != cloudpickle.__version__:
    # CloudPickle does not have a well-defined cross-version compatibility policy. Micro version
    # releases have been known to cause incompatibilities. Therefore, we match on the full
    # library version
    mlflow.pyfunc._logger.warning(
        "The version of CloudPickle that was used to save the model, `CloudPickle %s`, differs"
        " from the version of CloudPickle that is currently running, `CloudPickle %s`, and may"
        " be incompatible",
        python_model_cloudpickle_version, cloudpickle.__version__)
# Unlike the version mismatch above, a missing model path is fatal.
python_model_subpath = pyfunc_config.get(CONFIG_KEY_PYTHON_MODEL, None)
if python_model_subpath is None:
    raise MlflowException(
        "Python model path was not specified in the model configuration")
# SECURITY: unpickling can execute arbitrary code, so only load models from trusted sources.
with open(os.path.join(model_path, python_model_subpath), "rb") as f:
    python_model = cloudpickle.load(f)
# Map each saved artifact name to an absolute-ish path under model_path; the MLmodel
# config stores those paths relative to the model directory.
artifacts = {}
for saved_artifact_name, saved_artifact_info in\
        pyfunc_config.get(CONFIG_KEY_ARTIFACTS, {}).items():
    artifacts[saved_artifact_name] = os.path.join(
        model_path, saved_artifact_info[CONFIG_KEY_ARTIFACT_RELATIVE_PATH])
context = PythonModelContext(artifacts=artifacts)
python_model.load_context(context=context)
# NOTE(review): excerpt ends here. Whatever returns or wraps `python_model` and
# `context` is not visible.
def get_default_conda_env(include_cloudpickle=False):
    """
    :param include_cloudpickle: When true, also pin ``cloudpickle`` as a pip
                                dependency, at the version importable here.
    :return: The default Conda environment for MLflow Models produced by calls to
             :func:`save_model()` and :func:`log_model()`. ``scikit-learn`` is
             pinned as a conda dependency to the locally installed version.
    """
    import sklearn

    if include_cloudpickle:
        import cloudpickle
        pip_deps = ["cloudpickle=={}".format(cloudpickle.__version__)]
    else:
        pip_deps = None

    # Conda match specs use a single "=", unlike pip's "==".
    conda_deps = ["scikit-learn={}".format(sklearn.__version__)]
    return _mlflow_conda_env(
        additional_conda_deps=conda_deps,
        additional_pip_deps=pip_deps,
        additional_conda_channels=None,
    )
"""
import tensorflow as tf
conda_deps = [] # if we use tf.keras we only need to declare dependency on tensorflow
pip_deps = []
if keras_module is None:
import keras
keras_module = keras
if keras_module.__name__ == "keras":
# Temporary fix: the created conda environment has issues installing keras >= 2.3.1
if LooseVersion(keras_module.__version__) < LooseVersion('2.3.1'):
conda_deps.append("keras=={}".format(keras_module.__version__))
else:
pip_deps.append("keras=={}".format(keras_module.__version__))
if include_cloudpickle:
import cloudpickle
pip_deps.append("cloudpickle=={}".format(cloudpickle.__version__))
# Temporary fix: conda-forge currently does not have tensorflow > 1.14
# The Keras pyfunc representation requires the TensorFlow
# backend for Keras. Therefore, the conda environment must
# include TensorFlow
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
conda_deps.append("tensorflow=={}".format(tf.__version__))
else:
if pip_deps is not None:
pip_deps.append("tensorflow=={}".format(tf.__version__))
else:
pip_deps.append("tensorflow=={}".format(tf.__version__))
return _mlflow_conda_env(
additional_conda_deps=conda_deps,
additional_pip_deps=pip_deps,
additional_conda_channels=None)
:return: The default Conda environment for MLflow Models produced by calls to
:func:`save_model()` and :func:`log_model()`. It pins ``pytorch`` and
``torchvision`` as conda dependencies and ``cloudpickle`` as a pip dependency,
each at the version importable in the current interpreter, and passes
``pytorch`` as an additional conda channel.
"""
# NOTE(review): excerpt fragment. The enclosing `def` line and the opening `"""` of
# the docstring above are not visible. `cloudpickle` is not imported in the visible
# lines; presumably it is a module-level import, so confirm.
import torch
import torchvision
return _mlflow_conda_env(
    # The conda package is named "pytorch" even though the Python module is `torch`.
    additional_conda_deps=[
        "pytorch={}".format(torch.__version__),
        "torchvision={}".format(torchvision.__version__),
    ],
    additional_pip_deps=[
        # We include CloudPickle in the default environment because
        # it's required by the default pickle module used by `save_model()`
        # and `log_model()`: `mlflow.pytorch.pickle_module`.
        "cloudpickle=={}".format(cloudpickle.__version__)
    ],
    additional_conda_channels=[
        "pytorch",
    ])
def get_args(key=None, default=None):
    """
    Return the argument payload passed to this process, or a single entry of it.

    The payload is ``__get_arg_config().args_data``, base64-decoded and then
    unpickled. Unpickling uses ``cloudpickle`` when ``use_cloudpickle`` is set,
    otherwise the stdlib ``pickle``.

    :param key: If not None, return ``data.get(key, default)`` instead of the
                whole payload.
    :param default: Value returned when ``key`` is not in the payload.
    :return: The decoded payload (``{}`` when ``args_data`` is empty), or the
             value looked up by ``key``.
    :raises AssertionError: If cloudpickle is requested and the local cloudpickle
                            version differs from ``args.cloudpickle_version``.
    """
    args = __get_arg_config()
    if args.args_data:
        if args.use_cloudpickle:
            import cloudpickle
            # Pickles are not guaranteed to be portable across cloudpickle versions.
            # Raise explicitly rather than using `assert`, which `python -O` strips
            # and which would let a mismatch through silently.
            if args.cloudpickle_version != cloudpickle.__version__:
                raise AssertionError(
                    "Cloudpickle versions do not match! (host) %s vs (remote) %s" % (args.cloudpickle_version, cloudpickle.__version__))
            loads = cloudpickle.loads
        else:
            loads = pickle.loads
        # SECURITY: unpickling can execute arbitrary code. This is only acceptable
        # when args_data comes from a trusted launcher; never pass externally
        # supplied data here.
        data = loads(base64.b64decode(args.args_data))
    else:
        data = {}
    if key is not None:
        return data.get(key, default)
    return data
def get_args(key=None, default=None):
    """
    Return the argument payload passed to this process, or a single entry of it.

    The payload is ``__get_arg_config().args_data``, base64-decoded and then
    unpickled. Unpickling uses ``cloudpickle`` when ``use_cloudpickle`` is set,
    otherwise the stdlib ``pickle``.

    :param key: If not None, return ``data.get(key, default)`` instead of the
                whole payload.
    :param default: Value returned when ``key`` is not in the payload.
    :return: The decoded payload (``{}`` when ``args_data`` is empty), or the
             value looked up by ``key``.
    :raises AssertionError: If cloudpickle is requested and the local cloudpickle
                            version differs from ``args.cloudpickle_version``.
    """
    args = __get_arg_config()
    if args.args_data:
        if args.use_cloudpickle:
            import cloudpickle
            # Pickles are not guaranteed to be portable across cloudpickle versions.
            # Raise explicitly rather than using `assert`, which `python -O` strips
            # and which would let a mismatch through silently.
            if args.cloudpickle_version != cloudpickle.__version__:
                raise AssertionError(
                    "Cloudpickle versions do not match! (host) %s vs (remote) %s" % (args.cloudpickle_version, cloudpickle.__version__))
            loads = cloudpickle.loads
        else:
            loads = pickle.loads
        # SECURITY: unpickling can execute arbitrary code. This is only acceptable
        # when args_data comes from a trusted launcher; never pass externally
        # supplied data here.
        data = loads(base64.b64decode(args.args_data))
    else:
        data = {}
    if key is not None:
        return data.get(key, default)
    return data
def get_default_conda_env():
    """
    :return: The default Conda environment for MLflow Models produced by calls to
             :func:`save_model()` and :func:`log_model()` when a user-defined
             subclass of :class:`PythonModel` is provided. Its only extra
             dependency is ``cloudpickle``, pinned through pip to the version
             importable in the current interpreter.
    """
    cloudpickle_pin = "cloudpickle=={}".format(cloudpickle.__version__)
    return _mlflow_conda_env(
        additional_conda_deps=None,
        additional_pip_deps=[cloudpickle_pin],
        additional_conda_channels=None,
    )