Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_classification(self):
    """Fit NGBoost with a Bernoulli distribution and check held-out ROC AUC.

    Loads the breast-cancer dataset, holds out 20% of it with a fixed
    random seed, fits the model on the remainder, and asserts that the
    ROC AUC of the predicted distribution's ``prob`` on the held-out
    split is at least 0.95.
    """
    # `return_X_y` is passed by keyword: scikit-learn deprecated positional
    # use of this argument in 0.23 and rejects it with a TypeError from 1.0.
    data, target = load_breast_cancer(return_X_y=True)
    x_train, x_test, y_train, y_test = train_test_split(
        data, target, test_size=0.2, random_state=42
    )
    ngb = NGBoost(
        Base=default_tree_learner, Dist=Bernoulli, Score=MLE, verbose=False
    )
    ngb.fit(x_train, y_train)
    preds = ngb.pred_dist(x_test)
    score = roc_auc_score(y_test, preds.prob)
    assert score >= 0.95
def __init__(
self,
Dist=Bernoulli,
Score=LogScore,
Base=default_tree_learner,
natural_gradient=True,
n_estimators=500,
learning_rate=0.01,
minibatch_frac=1.0,
col_sample=1.0,
verbose=True,
verbose_eval=100,
tol=1e-4,
random_state=None,
):
assert issubclass(
Dist, ClassificationDistn
), f"{Dist.__name__} is not useable for classification."
super().__init__(
def __init__(self, *args, **kwargs):
    """Initialise via the parent class with ``Dist`` fixed to ``Bernoulli``.

    All positional and keyword arguments are forwarded unchanged to the
    parent initialiser. Because ``Dist`` is always supplied here, a caller
    that also passes ``Dist`` gets a ``TypeError`` for the duplicate
    keyword argument.
    """
    parent = super(NGBClassifier, self)
    parent.__init__(*args, Dist=Bernoulli, **kwargs)