Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def test_init(self):
    """Loading the test_google project registers three providers, Google first.

    ``Context()`` is constructed without a path, so it presumably picks up
    the project config from the current working directory; the test
    therefore switches into ``tests/test_google`` first. The original working
    directory is restored afterwards (even on failure) so that other tests
    relying on ``tests/...``-relative paths are not affected.
    """
    original_dir = os.getcwd()
    os.chdir(os.path.join(original_dir, "tests", "test_google"))
    try:
        # load config(s) from test app
        ctx = Context()
        a2ml = A2ML(ctx)
        assert len(a2ml.runner.providers) == 3
        assert isinstance(a2ml.runner.providers[0], GoogleA2ML)
    finally:
        # os.chdir is process-wide; undo it so later tests see the old cwd.
        os.chdir(original_dir)
def test_calculate_scores(self):
    """calculate_scores returns one entry per configured score name.

    Six iris labels are encoded with the model's target preprocessing. The
    predicted labels differ from the actual labels only at index 2, so
    accuracy is presumably 5/6 and must exceed 0.8.
    """
    model_path = 'tests/fixtures/test_predict_by_model/iris'
    options = fsclient.read_json_file(os.path.join(model_path, "options.json"))

    actual_labels = [["setosa"], ["versicolor"], ["virginica"],
                     ["setosa"], ["versicolor"], ["virginica"]]
    # Identical to actual_labels except index 2 (virginica -> versicolor).
    predicted_labels = [["setosa"], ["versicolor"], ["versicolor"],
                        ["setosa"], ["versicolor"], ["virginica"]]

    y_test, _ = ModelHelper.preprocess_target(
        model_path, records=actual_labels, features=["species"]
    )
    y_pred, _ = ModelHelper.preprocess_target(
        model_path, records=predicted_labels, features=["species"]
    )

    scores = ModelHelper.calculate_scores(options, y_test=y_test, y_pred=y_pred)
    self.assertEqual(len(scores), len(options['scoreNames']))
    # assertGreater shows the actual accuracy on failure; assertTrue(a > b)
    # would only report "False is not true".
    self.assertGreater(scores['accuracy'], 0.8)
def test_distribution_chart_stats_for_categorical_target():
    """Check distribution_chart_stats output for the 'adult' fixture.

    Builds a ModelReview from the task params loaded for the fixture and
    requests per-feature stats for 2020-02-16 .. 2020-02-20. The result must
    be a dict keyed by ``str(date)``, and the entry for ``date_to`` is
    compared against a literal expected mapping.
    """
    model_path = 'tests/fixtures/test_distribution_chart_stats/adult'
    date_from = datetime.date(2020, 2, 16)
    date_to = datetime.date(2020, 2, 20)
    res = ModelReview(load_metric_task_params(model_path)).distribution_chart_stats(date_from, date_to)
    assert type(res) is dict
    # Per-day results are keyed by the date's string form (YYYY-MM-DD).
    assert type(res[str(date_to)]) is dict
    # Each feature entry carries either a 'dist' mapping or 'avg'/'std_dev',
    # plus 'imp'. The 'imp' values appear to be rounded feature importances
    # (they match test_get_feature_importances_general_metrics_cache, which
    # uses the same fixture) -- confirm.
    # NOTE(review): this expected-value literal is cut off here (remaining
    # features and the closing braces are missing); restore it from the
    # original source before relying on this test.
    assert res[str(date_to)] == {
        'income': {'dist': {' <=50K': 1}, 'imp': 0},
        'age': {'avg': 0.0, 'std_dev': 0, 'imp': 0.716105},
        'workclass': {'dist': {0: 1}, 'imp': 0.120064},
        'fnlwgt': {'avg': 0.0, 'std_dev': 0, 'imp': 1.0},
        'education': {'dist': {0: 1}, 'imp': 0.299958},
        'education-num': {'avg': 0.0, 'std_dev': 0, 'imp': 0},
        'marital-status': {'dist': {0: 1}, 'imp': 0.143296},
        'occupation': {'dist': {0: 1}, 'imp': 0.209677},
        'relationship': {'dist': {0: 1}, 'imp': 0.086982},
        'race': {'dist': {0: 1}, 'imp': 0.041806},
        'sex': {'dist': {0: 1}, 'imp': 0.039482},
        'capital-gain': {'avg': 0.0, 'std_dev': 0, 'imp': 0.313237},
def test_set_support_review_model_flag():
    """set_support_review_model_flag(True) persists the flag to options.json.

    A fresh options.json is copied from options_original.json, the flag is
    set through ModelReview, and the file is re-read to check that a JSON
    ``true`` was stored. The copied file is removed in ``finally`` so a
    failing assertion does not leave a modified fixture behind.
    """
    # setup
    model_path = 'tests/fixtures/test_set_support_review_model_flag'
    options_path = os.path.join(model_path, 'options.json')
    shutil.copyfile(os.path.join(model_path, 'options_original.json'), options_path)
    try:
        # test
        ModelReview({'model_path': model_path}).set_support_review_model_flag(True)
        with open(options_path, 'r') as f:
            res = json.load(f)
        # `is True` requires a real boolean, not just a truthy/1 value.
        assert res['support_review_model'] is True
    finally:
        # teardown: always runs, even when the assertion above fails.
        os.remove(options_path)
def test_count_actuals_by_prediction_id():
    """count_actuals_by_prediction_id maps each prediction id to its number of actuals."""
    fixture_dir = 'tests/fixtures/test_count_actuals_by_prediction_id/adult'

    review = ModelReview({'model_path': fixture_dir})
    counts = review.count_actuals_by_prediction_id()

    expected_counts = {
        'ffa89d52-5300-412d-b7a4-d21b3c9b7d16': 2,
        '5d9f640d-529a-42bd-be85-172107249a01': 1,
        '066f3c25-80ee-4c75-af15-38cda8a4ad57': 1,
    }
    assert type(counts) is dict
    assert len(counts) > 0
    assert counts == expected_counts
def test_distribution_chart_stats():
    """Check distribution_chart_stats output for the 'bikesharing' fixture.

    Builds a ModelReview from the task params loaded for the fixture and
    requests per-feature stats for 2020-02-16 .. 2020-02-19. The result must
    be a dict keyed by ``str(date)``, and the entry for ``date_to`` is
    compared against a literal expected mapping.
    """
    model_path = 'tests/fixtures/test_distribution_chart_stats/bikesharing'
    date_from = datetime.date(2020, 2, 16)
    date_to = datetime.date(2020, 2, 19)
    res = ModelReview(load_metric_task_params(model_path)).distribution_chart_stats(date_from, date_to)
    assert type(res) is dict
    # Per-day results are keyed by the date's string form (YYYY-MM-DD).
    assert type(res[str(date_to)]) is dict
    # Each feature entry carries either 'avg'/'std_dev' or a 'dist' mapping,
    # plus 'imp' (all 0 for this fixture).
    # NOTE(review): this expected-value literal is cut off here (remaining
    # features and the closing braces are missing); restore it from the
    # original source before relying on this test.
    assert res[str(date_to)] == {
        'cnt': { 'avg': 483.18357849636016, 'std_dev': 0.0, 'imp': 0 },
        'dteday': { 'avg': 0.0, 'std_dev': 0.0, 'imp': 0 },
        'season': { 'dist': { 0: 2 }, 'imp': 0},
        'yr': { 'dist': { 0: 2 }, 'imp': 0 },
        'mnth': { 'avg': 0.0, 'std_dev': 0.0, 'imp': 0 },
        'holiday': { 'dist': { 0: 2 }, 'imp': 0},
        'weekday': { 'dist': { 0: 2 }, 'imp': 0},
        'workingday': { 'dist': { 0: 2 }, 'imp': 0},
        'weathersit': { 'dist': { 0: 2 }, 'imp': 0},
        'temp': { 'avg': 0.0, 'std_dev': 0.0, 'imp': 0 },
        'atemp': { 'avg': 0.0, 'std_dev': 0.0, 'imp': 0 },
        'hum': { 'avg': 0.0, 'std_dev': 0.0, 'imp': 0 },
def test_get_feature_importances_general_metrics_cache():
    """_get_feature_importances returns the adult fixture's importances.

    The test name suggests the values come from a cached general-metrics
    file in the fixture directory -- confirm.
    """
    task_params = load_metric_task_params('tests/fixtures/test_distribution_chart_stats/adult')
    importances = ModelReview(task_params)._get_feature_importances()

    expected_importances = {
        'workclass': 0.12006373015194421,
        'sex': 0.039481754114499897,
        'occupation': 0.20967661413259162,
        'education': 0.2999579889231273,
        'relationship': 0.08698243068672135,
        'marital-status': 0.14329620992107325,
        'race': 0.04180630794271793,
        'native-country': 0.02072552576600564,
        'capital-loss': 0.2571256791934569,
        'capital-gain': 0.31323744185565716,
        'hours-per-week': 0.4246393312722869,
        'age': 0.7161049235052714,
        'fnlwgt': 1.0,
    }
    assert importances == expected_importances
def test_score_model_performance_daily_no_matching_actuals_and_predictions():
    """With no actuals matching any prediction, the first day's score is 0."""
    first_day = datetime.date(2020, 2, 16)
    last_day = datetime.date(2020, 2, 18)
    fixture_dir = 'tests/fixtures/test_score_model_performance_daily/iris_no_matches'

    review = ModelReview({'model_path': fixture_dir})
    # NOTE(review): the end date is passed as a string while the start date
    # is a date object -- presumably to exercise mixed input types; confirm.
    daily_scores = review.score_model_performance_daily(first_day, str(last_day))

    assert type(daily_scores) is dict
    assert daily_scores[str(first_day)] == 0
def test_import_data(self):
    """Import project data through the Azure provider.

    There are no assertions; the test passes as long as import_data raises
    nothing.
    """
    working_dir = os.getcwd()
    print("Current directory: {}".format(working_dir))

    self.ctx = Context()
    project_name = self.ctx.config.get('name')
    print("Project name: {}".format(project_name))

    azure_ctx = self.ctx.copy('azure')
    AzureA2ML(azure_ctx).import_data()
def test_list_projects_server(self):
    """List Azure projects for 'cli-integration-test' with use_server enabled.

    The project directory is resolved under $A2ML_PROJECT_PATH (or relative
    to the cwd when that variable is unset).
    """
    from a2ml.api.a2ml_project import A2MLProject

    provider = "azure"
    project_root = os.environ.get('A2ML_PROJECT_PATH', '')
    project_dir = os.path.join(project_root, 'cli-integration-test')

    ctx = Context(path=project_dir, debug=True)
    ctx.config.set('providers', [provider])
    ctx.config.set('use_server', True)

    A2MLProject(ctx, provider).list()