Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_calculate_scores(self):
    """Scores computed on the iris fixture cover every configured score
    name and reach a reasonable accuracy.

    Ground truth and prediction differ in exactly one of six records,
    so accuracy is 5/6 and must clear the 0.8 bar.
    """
    fixture_dir = 'tests/fixtures/test_predict_by_model/iris'
    model_options = fsclient.read_json_file(os.path.join(fixture_dir, "options.json"))

    # Encode the true labels through the model's target preprocessing.
    actual, _ = ModelHelper.preprocess_target(
        fixture_dir,
        records=[["setosa"], ["versicolor"], ["virginica"], ["setosa"], ["versicolor"], ["virginica"]],
        features=["species"],
    )
    # Same preprocessing for the predicted labels (one record differs).
    predicted, _ = ModelHelper.preprocess_target(
        fixture_dir,
        records=[["setosa"], ["versicolor"], ["versicolor"], ["setosa"], ["versicolor"], ["virginica"]],
        features=["species"],
    )

    computed = ModelHelper.calculate_scores(model_options, y_test=actual, y_pred=predicted)

    # One score per configured score name, and accuracy above threshold.
    self.assertEqual(len(computed), len(model_options['scoreNames']))
    self.assertTrue(computed['accuracy'] > 0.8)
def _predict_locally(self, predict_data, model_id, threshold):
# Run a prediction against a locally deployed model.
# NOTE(review): this snippet is truncated — the trailing `else:` has no body
# in the visible source, so the no-threshold prediction path is missing here.
is_loaded, model_path = self.verify_local_model(model_id)
if not is_loaded:
# A model must be downloaded/deployed locally before predicting.
raise Exception("Model should be deployed before predict.")
# Deserialize the fitted estimator from its local path.
fitted_model = fsclient.load_object_from_file(model_path)
try:
# Restrict the incoming frame to the feature columns recorded at train time.
options = fsclient.read_json_file(os.path.join(self.ctx.config.get_model_path(model_id), "options.json"))
model_features = options.get("originalFeatureColumns")
predict_data = predict_data[model_features]
# NOTE(review): writing "test_options.csv" into the working directory looks
# like leftover debug code — confirm this side effect is intentional.
predict_data.to_csv("test_options.csv", index=False, compression=None, encoding='utf-8')
except Exception as e:
# Best-effort fallback: keep the original columns if options can't be read.
self.ctx.log('Cannot get columns from model.Use original columns from predicted data: %s'%e)
results_proba = None
proba_classes = None
results = None
if threshold is not None:
# A threshold requires class probabilities so it can be applied downstream.
results_proba = fitted_model.predict_proba(predict_data)
proba_classes = list(fitted_model.classes_)
else:
def _upload_to_multi_tenant(self, file_to_upload):
# Upload a local file into the multi-tenant project workspace.
# NOTE(review): truncated — the FileUploader(...) call is cut off mid-argument
# list in the visible source; the actual upload step is missing here.
# Unique remote path: project name + short uuid prefix + original basename.
file_path = 'workspace/projects/%s/files/%s-%s' % \
(self.parent_api.object_name, shortuuid.uuid(),
os.path.basename(file_to_upload))
# Ask the backend for upload credentials/URL for this file.
res = self.rest_api.call('create_project_file_url', {
'project_id': self.parent_api.object_id,
'file_path': file_path,
'file_size': fsclient.get_file_size(file_to_upload)
})
if res is None:
raise AugerException(
'Error while uploading file to Auger Cloud...')
if 'multipart' in res:
# Multipart response carries bucket + S3-style credentials for direct upload.
upload_details = res['multipart']
config = upload_details['config']
uploader = FileUploader(
upload_details['bucket'],
config['endpoint'],
config['access_key'],
config['secret_key'],
config['security_token']
def download_file(remote_path, local_dir, file_name, force_download=False):
# Download remote_path into local_dir, reusing an existing local copy when one
# matches file_name (any extension).
# NOTE(review): truncated — the visible source ends at an `if` with no body,
# and `download_file`/`force_download` are read but the logic using them is
# missing from this view.
local_file_path = ""
download_file = True
remote_file_info = {}
logging.info("download_file: %s, %s, %s, %s"%(remote_path, local_dir, file_name, force_download))
if file_name:
# Look for an already-downloaded copy with any extension.
all_local_files = fsclient.list_folder(os.path.join(local_dir, file_name+".*"), wild=True, remove_folder_name=True)
if all_local_files:
local_file_path = os.path.join( local_dir, all_local_files[0])
if not local_file_path:
# No local copy: derive the target name from the remote file's metadata.
remote_file_info = get_remote_file_info(remote_path)
if not remote_file_info:
raise Exception("Remote path does not exist or unaccessible: %s"%(remote_path))
if file_name:
# Caller-supplied base name + remote extension.
local_file_path = os.path.join(local_dir,
file_name+remote_file_info.get('file_ext'))
else:
# Fall back to the remote file's own name and extension.
local_file_path = os.path.join(local_dir,
remote_file_info.get('file_name') + remote_file_info.get('file_ext'))
if fsclient.isFileExists(local_file_path):
# NOTE(review): headless fragment — the enclosing method's `def` line is not
# visible here; `source`, `self`, and `ws` come from the missing context.
if source is None:
# Fall back to the configured data source when none was passed in.
source = self.ctx.config.get('source', None)
if source is None:
raise AzureException('Please specify data source file...')
if source.startswith("http:") or source.startswith("https:"):
# Remote URL: build the tabular dataset straight from the URL.
url_info = get_remote_file_info(source)
if self.ctx.config.get('source_format', "") == "parquet" or \
url_info.get('file_ext', "").endswith(".parquet"):
dataset = Dataset.Tabular.from_parquet_files(path=source)
else:
dataset = Dataset.Tabular.from_delimited_files(path=source)
dataset_name = url_info.get('file_name')+url_info.get('file_ext')
else:
# Local (or S3-backed) file: upload it to the workspace's default datastore
# and build the dataset from the datastore path.
with fsclient.with_s3_downloaded_or_local_file(source) as local_path:
ds = self.ws.get_default_datastore()
if self.ctx.config.path and not local_path.startswith("/"):
# Resolve relative paths against the configured project path.
local_path = os.path.join(self.ctx.config.path, local_path)
ds.upload_files(files=[local_path], relative_root=None,
target_path=None, overwrite=True, show_progress=True)
dataset_name = os.path.basename(local_path)
if dataset_name.endswith(".parquet") or self.ctx.config.get('source_format', "") == "parquet":
dataset = Dataset.Tabular.from_parquet_files(path=ds.path(dataset_name))
else:
dataset = Dataset.Tabular.from_delimited_files(
path=ds.path(dataset_name))
# NOTE(review): `workspace = ws` — elsewhere this fragment uses `self.ws`;
# confirm `ws` is bound in the missing context or this is a NameError.
dataset.register(workspace = ws, name = dataset_name,
create_new_version = True)