Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_structured_data_from_csv_partial_col_type_classifier(tmp_dir):
    """Fit from file paths using ``common.PARTIAL_COLUMN_TYPES_FROM_CSV``.

    Smoke test: there are no assertions, so it passes when construction
    and ``fit`` complete without raising.
    """
    classifier = ak.StructuredDataClassifier(
        column_types=common.PARTIAL_COLUMN_TYPES_FROM_CSV,
        directory=tmp_dir,
        max_trials=1,
    )
    classifier.fit(
        x=common.TRAIN_FILE_PATH,
        y='survived',
        epochs=2,
        validation_data=common.TEST_FILE_PATH,
    )
def test_structured_data_from_numpy_classifier(tmp_dir):
    """Fit on in-memory data with default column settings.

    Features come from ``common.structured_data``; labels are random
    integers drawn from ``[0, 3)``. The same arrays are reused as the
    validation set. Smoke test: passes when nothing raises.
    """
    sample_count = 500
    features = common.structured_data(sample_count)
    labels = np.random.randint(0, 3, sample_count)

    classifier = ak.StructuredDataClassifier(directory=tmp_dir, max_trials=1)
    classifier.fit(
        features,
        labels,
        epochs=2,
        validation_data=(features, labels),
    )
def test_structured_data_from_csv_less_col_name_classifier(tmp_dir):
    """Expect a ``ValueError`` with ``common.LESS_COLUMN_NAMES_FROM_CSV``.

    Construction and ``fit`` both run inside ``pytest.raises``, so the
    error may come from either call; the message is checked afterwards.
    """
    with pytest.raises(ValueError) as raised:
        classifier = ak.StructuredDataClassifier(
            column_names=common.LESS_COLUMN_NAMES_FROM_CSV,
            directory=tmp_dir,
            max_trials=1,
        )
        classifier.fit(
            x=common.TRAIN_FILE_PATH,
            y='survived',
            epochs=2,
            validation_data=common.TEST_FILE_PATH,
        )

    message = str(raised.value)
    assert 'Expect column_names to have length' in message
def test_structured_data_from_numpy_col_name_classifier(tmp_dir):
    """Fit on in-memory data with ``common.COLUMN_NAMES_FROM_NUMPY``.

    Features come from ``common.structured_data``; labels are random
    integers drawn from ``[0, 3)``. The same arrays are reused as the
    validation set. Smoke test: passes when nothing raises.
    """
    sample_count = 500
    features = common.structured_data(sample_count)
    labels = np.random.randint(0, 3, sample_count)

    classifier = ak.StructuredDataClassifier(
        column_names=common.COLUMN_NAMES_FROM_NUMPY,
        directory=tmp_dir,
        max_trials=1,
    )
    classifier.fit(
        features,
        labels,
        epochs=2,
        validation_data=(features, labels),
    )
def test_structured_data_from_csv_false_col_type_classifier(tmp_dir):
    """Expect a ``ValueError`` with ``common.FALSE_COLUMN_TYPES_FROM_CSV``.

    Construction and ``fit`` both run inside ``pytest.raises``, so the
    error may come from either call; the message is checked afterwards.
    """
    with pytest.raises(ValueError) as raised:
        classifier = ak.StructuredDataClassifier(
            column_types=common.FALSE_COLUMN_TYPES_FROM_CSV,
            directory=tmp_dir,
            max_trials=1,
        )
        classifier.fit(
            x=common.TRAIN_FILE_PATH,
            y='survived',
            epochs=2,
            validation_data=common.TEST_FILE_PATH,
        )

    message = str(raised.value)
    assert 'Column_types should be either "categorical"' in message
def test_structured_data_from_numpy_col_type_classifier(tmp_dir):
    """Expect a ``ValueError`` when only ``column_types`` is supplied.

    ``common.COLUMN_TYPES_FROM_NUMPY`` is passed without ``column_names``.
    Construction and ``fit`` both run inside ``pytest.raises``; the full
    error message is compared for equality afterwards.
    """
    sample_count = 500
    features = common.structured_data(sample_count)
    labels = np.random.randint(0, 3, sample_count)

    with pytest.raises(ValueError) as raised:
        classifier = ak.StructuredDataClassifier(
            column_types=common.COLUMN_TYPES_FROM_NUMPY,
            directory=tmp_dir,
            max_trials=1,
        )
        classifier.fit(
            features,
            labels,
            epochs=2,
            validation_data=(features, labels),
        )

    message = str(raised.value)
    assert message == 'Column names must be specified.'
def test_structured_data_from_csv_col_name_classifier(tmp_dir):
    """Fit from file paths using ``common.COLUMN_NAMES_FROM_CSV``.

    Smoke test: there are no assertions, so it passes when construction
    and ``fit`` complete without raising.
    """
    classifier = ak.StructuredDataClassifier(
        column_names=common.COLUMN_NAMES_FROM_CSV,
        directory=tmp_dir,
        max_trials=1,
    )
    classifier.fit(
        x=common.TRAIN_FILE_PATH,
        y='survived',
        epochs=2,
        validation_data=common.TEST_FILE_PATH,
    )
def test_structured_data_from_numpy_col_name_type_classifier(tmp_dir):
    """Fit on in-memory data with both column names and column types.

    Passes ``common.COLUMN_NAMES_FROM_NUMPY`` and
    ``common.COLUMN_TYPES_FROM_NUMPY`` together. Labels are random
    integers drawn from ``[0, 3)`` and the training arrays double as the
    validation set. Smoke test: passes when nothing raises.
    """
    sample_count = 500
    features = common.structured_data(sample_count)
    labels = np.random.randint(0, 3, sample_count)

    classifier = ak.StructuredDataClassifier(
        column_names=common.COLUMN_NAMES_FROM_NUMPY,
        column_types=common.COLUMN_TYPES_FROM_NUMPY,
        directory=tmp_dir,
        max_trials=1,
    )
    classifier.fit(
        features,
        labels,
        epochs=2,
        validation_data=(features, labels),
    )
def test_structured_data_from_csv_classifier(tmp_dir):
    """Fit from file paths with default column settings.

    Smoke test: there are no assertions, so it passes when construction
    and ``fit`` complete without raising.
    """
    classifier = ak.StructuredDataClassifier(directory=tmp_dir, max_trials=1)
    classifier.fit(
        x=common.TRAIN_FILE_PATH,
        y='survived',
        epochs=2,
        validation_data=common.TEST_FILE_PATH,
    )