Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_import_data(self):
    """Import data through the Azure provider using a default Context.

    Prints the current working directory and the configured project name,
    stores the new Context on ``self.ctx``, then runs ``import_data`` on an
    ``'azure'`` copy of that context.
    """
    cwd = os.getcwd()
    print(f"Current directory: {cwd}")

    self.ctx = Context()
    project_name = self.ctx.config.get('name')
    print(f"Project name: {project_name}")

    azure_ctx = self.ctx.copy('azure')
    AzureA2ML(azure_ctx).import_data()
def test_list_projects_server(self):
    """List Azure projects with server mode (``use_server``) enabled.

    The project directory is ``cli-integration-test`` joined onto the
    ``A2ML_PROJECT_PATH`` environment variable (a relative path when the
    variable is unset).
    """
    from a2ml.api.a2ml_project import A2MLProject

    project_root = os.environ.get('A2ML_PROJECT_PATH', '')
    project_dir = os.path.join(project_root, 'cli-integration-test')
    ctx = Context(path=project_dir, debug=True)

    provider = "azure"
    for key, value in (('providers', [provider]), ('use_server', True)):
        ctx.config.set(key, value)

    A2MLProject(ctx, provider).list()
def test_train_server(self):
    """Train on the Azure provider with server mode (``use_server``) enabled.

    The project directory is ``cli-integration-test`` joined onto the
    ``A2ML_PROJECT_PATH`` environment variable (a relative path when the
    variable is unset).
    """
    from a2ml.api.a2ml import A2ML

    project_root = os.environ.get('A2ML_PROJECT_PATH', '')
    project_dir = os.path.join(project_root, 'cli-integration-test')
    ctx = Context(path=project_dir, debug=True)

    provider = "azure"
    for key, value in (('providers', [provider]), ('use_server', True)):
        ctx.config.set(key, value)

    A2ML(ctx, provider).train()
def _get_hub_context():
    """Return a new Context whose debug flag follows ``task_config.debug``.

    An ``AugerProject`` is constructed on the context before it is returned.
    """
    from a2ml.api.auger.project import AugerProject

    hub_ctx = Context(debug=task_config.debug)
    # NOTE(review): the AugerProject instance is never used afterwards; it
    # appears to be built only for its constructor's effect on the context.
    # Confirm before removing this call.
    AugerProject(hub_ctx)
    return hub_ctx
def create_context(params, new_project=False):
    """Build the Context used to execute a task request.

    Args:
        params (dict): request parameters. When ``params['context']`` is
            truthy it is decoded with jsonpickle, and ``params['_request_id']``
            must then be present. Otherwise a new Context is created from
            ``project_path`` and ``debug_log`` (default False). ``provider``
            and ``source_path`` are optional overrides.
        new_project (bool): when True, the ``provider``/``source_path``
            overrides are not written to the ``config`` config.

    Returns:
        Context: the decoded (server mode) or newly created context.
    """
    provider_param = params.get('provider')

    if params.get('context'):
        # SECURITY(review): jsonpickle.decode can instantiate arbitrary Python
        # objects (and thereby execute code) described by its input. This is
        # only safe if params['context'] comes from a trusted producer; do not
        # feed client-supplied data here without authentication/validation.
        ctx = jsonpickle.decode(params['context'])

        # Server mode supports only one provider in request.
        # NOTE(review): if no provider is passed and none is configured, the
        # [0] index below fails — confirm callers always supply one.
        provider = provider_param or to_list(ctx.config.get('providers'))[0]
        ctx.name = provider
        ctx.config.name = provider

        ctx.set_runs_on_server(True)
        # The decoded context already runs on the server; don't let it
        # forward requests to a server again.
        ctx.config.set('use_server', False, config_name='config')

        ctx.notificator = notificator
        ctx.request_id = params['_request_id']
        ctx.setup_logger(format='')
    else:
        ctx = Context(
            path=params.get('project_path'),
            debug=params.get('debug_log', False)
        )

    if not new_project:
        if provider_param:
            ctx.config.set('providers', [provider_param], config_name='config')

        source_path = params.get('source_path')
        if source_path:
            ctx.config.set('source', source_path, config_name='config')

    return ctx
def _create_provider_context(params):
    """Create a server-side Context for a single provider from hub params.

    Args:
        params (dict): request parameters. Reads ``provider`` (default
            ``'auger'``), ``hub_info`` (``projectPath``, ``project_name``,
            ``experiment_name``) and ``provider_info[provider]``
            (``project`` -> ``name``, ``experiment`` -> ``name``).

    Returns:
        Context: the configured context, with ``providers`` set to
        ``[provider]`` and, when known, the project and experiment names
        stored under that provider's config.
    """
    provider = params.get('provider', 'auger')
    # `or {}` also covers keys that are present but explicitly None, which
    # previously raised AttributeError on the chained `.get`.
    hub_info = params.get('hub_info') or {}
    provider_info = (params.get('provider_info') or {}).get(provider) or {}

    ctx = Context(
        name=provider,
        path=hub_info.get('projectPath'),
        debug=task_config.debug
    )
    ctx.set_runs_on_server(True)
    ctx.config.set('providers', [provider])

    # Provider-specific names take precedence over the hub-level ones.
    project_name = dict_dig(provider_info, 'project', 'name') or hub_info.get('project_name')
    experiment_name = dict_dig(provider_info, 'experiment', 'name') or hub_info.get('experiment_name')
    if project_name:
        ctx.config.set('name', project_name, provider)
    if experiment_name:
        ctx.config.set('experiment/name', experiment_name, provider)

    # Previously the function fell off the end and returned None, discarding
    # the context it had just configured.
    return ctx
def copy(self, name='config'):
    """Create a copy of this Context under the given config name.

    Args:
        name (str): The name of the config file for the new context.
            Default is 'config'.

    Returns:
        object: a new Context with the same config path and debug flag,
        carrying over server mode, notificator, request id and config parts
        (and credentials when running on the server).

    Example:
        .. code-block:: python

            ctx = Context()
            new_ctx = ctx.copy()
            azure_ctx = ctx.copy('azure')
    """
    new = Context(name, self.config.path, self.debug)
    new.set_runs_on_server(self._runs_on_server)
    new.notificator = self.notificator
    new.request_id = self.request_id
    # NOTE(review): this assigns the same `parts` object rather than a copy,
    # so the new context and this one share it; if `parts` is mutable,
    # changes through either are visible through both — confirm intended.
    new.config.parts = self.config.parts
    if self._runs_on_server:
        new.credentials = self.credentials
    return new
def _update_hub_leaderboad(params, leaderboard):
    """Queue a 'Leaderboard' result message for the hub.

    Args:
        params (dict): request parameters; ``provider`` and ``hub_info`` are
            required and copied into the message.
        leaderboard (dict): leaderboard data; ``trials`` (default empty list)
            is formatted for the hub and ``evaluate_status`` is copied.
    """
    from a2ml.api.auger.experiment import AugerExperiment

    # NOTE(review): the AugerExperiment instance is never used; it appears to
    # be built only for its constructor's effect. Confirm before removing.
    AugerExperiment(Context(debug=task_config.debug))
    _log(leaderboard, level=logging.DEBUG)

    trials = leaderboard.get('trials', [])
    payload = dict(
        type='Leaderboard',
        provider=params['provider'],
        hub_info=params['hub_info'],
        trials=_format_leaderboard_for_hub(trials),
        evaluate_status=leaderboard.get('evaluate_status'),
    )
    send_result_to_hub.delay(json.dumps(payload))