How to use the fastai.train.ClassificationInterpretation.from_learner function in fastai

To help you get started, we’ve selected a few fastai examples based on popular ways `ClassificationInterpretation.from_learner` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: fastai/fastai — tests/test_vision_train.py (view on GitHub)
def test_ClassificationInterpretation(learn):
    "Check the confusion matrix and `most_confused` output of a classifier's `ClassificationInterpretation`."
    this_tests(ClassificationInterpretation)
    interp = ClassificationInterpretation.from_learner(learn)
    # Compute the confusion matrix once and reuse it for both checks instead of
    # rebuilding it from the predictions for every assertion.
    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    # The cell counts are expected to add up to the size of the validation set.
    assert cm.sum() == len(learn.data.valid_ds)
    conf = interp.most_confused()
    # The test assumes a two-class ('3' / '7') fixture: there may be no confusions,
    # or up to two entries whose first two fields are exactly those two labels.
    expect = {'3', '7'}
    assert len(conf) <= 2 and all(set(c[:2]) == expect for c in conf), f"conf={conf}"
GitHub: fastai/fastai — tests/test_tabular_train.py (view on GitHub)
def test_confusion_tabular(learn):
    "Check that a tabular classifier's confusion matrix is an array covering the whole validation set."
    interp = ClassificationInterpretation.from_learner(learn)
    # Register the tested API up front, as the other interpretation tests do.
    this_tests(interp.confusion_matrix)
    # Build the matrix once and reuse it for both assertions.
    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    assert cm.sum() == len(learn.data.valid_ds)
GitHub: fastai/fastai — tests/test_vision_train.py (view on GitHub)
def test_interp(learn):
    "`top_losses` should return one loss and one index per validation item."
    this_tests(ClassificationInterpretation.from_learner)
    interpretation = ClassificationInterpretation.from_learner(learn)
    loss_values, loss_indices = interpretation.top_losses()
    n_valid = len(learn.data.valid_ds)
    assert n_valid == len(loss_values) == len(loss_indices)
GitHub: fastai/fastai — fastai/vision/learner.py (view on GitHub)
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):
    "Build a `ClassificationInterpretation` for `learn` on the `ds_type` dataset, forwarding `tta` to `from_learner`."
    options = dict(ds_type=ds_type, tta=tta)
    return ClassificationInterpretation.from_learner(learn, **options)
Learner.interpret = _learner_interpret
GitHub: fastai/fastai — fastai/tabular/learner.py (view on GitHub)
def _learner_interpret(learn:Learner, ds_type:DatasetType = DatasetType.Valid):
    "Build a `ClassificationInterpretation` for the tabular `learn` on the `ds_type` dataset."
    interp_cls = ClassificationInterpretation
    return interp_cls.from_learner(learn, ds_type=ds_type)
GitHub: fastai/fastai — fastai/vision/learner.py (view on GitHub)
# NOTE(review): fragment of a method whose `def` line and setup (mismatches, infolist,
# ordlosses_idxs, classes_ids, mismatchescontainer, ...) are not visible in this excerpt.
# Report how many items were misclassified relative to the validation-set size.
print(f'{str(len(mismatches))} misclassified samples over {str(len(self.data.valid_ds))} samples in the validation set.')
    # Never try to display more samples than there are mismatches.
    samples = min(samples, len(mismatches))
    # Collect element [0] of every entry of `mismatches_ordered_byloss`
    # (presumably the image — TODO confirm); returned when `save_misclassified` is set.
    for ima in range(len(mismatches_ordered_byloss)):
        mismatchescontainer.append(mismatches_ordered_byloss[ima][0])
    # Display the first `samples` entries in `mismatches_ordered_byloss` order
    # (ordered by loss, going by the name — confirm against the setup code).
    for sampleN in range(samples):
        # Join the names of all actual classes of this item, each prefixed with ' -- '.
        # Assumes infolist entry field [2] is an iterable of class ids — TODO confirm.
        actualclasses = ''
        for clas in infolist[ordlosses_idxs[sampleN]][2]:
            actualclasses = f'{actualclasses} -- {str(classes_ids[clas][1])}'
        imag = mismatches_ordered_byloss[sampleN][0]
        # `show_image` is rebound to `imag`; its return value is then used as the plot axes.
        imag = show_image(imag, figsize=figsize)
        # Title fields come from the infolist entry: [1] predicted class id, [3] loss,
        # [4] probability (per the labels in the title string).
        imag.set_title(f"""Predicted: {classes_ids[infolist[ordlosses_idxs[sampleN]][1]][1]} \nActual: {actualclasses}\nLoss: {infolist[ordlosses_idxs[sampleN]][3]}\nProbability: {infolist[ordlosses_idxs[sampleN]][4]}""",
                        loc='left')
        plt.show()
        # NOTE(review): this `return` is inside the `for` loop, so with save_misclassified=True
        # only the first sample is shown before returning. It likely belongs after the loop.
        # Also, when `samples` is 0 this statement is never reached.
        if save_misclassified: return mismatchescontainer

# Install the implementations from this module as methods on `ClassificationInterpretation`.
ClassificationInterpretation.plot_multi_top_losses = _cl_int_plot_multi_top_losses
ClassificationInterpretation.plot_top_losses       = _cl_int_plot_top_losses
ClassificationInterpretation.from_learner          = _cl_int_from_learner
 

def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):
    "Return a `ClassificationInterpretation` of `learn` on `ds_type`, passing `tta` through to `from_learner`."
    interp = ClassificationInterpretation.from_learner(learn, ds_type=ds_type, tta=tta)
    return interp
Learner.interpret = _learner_interpret
GitHub: fastai/fastai — fastai/train.py (view on GitHub)
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid):
    "Create a `ClassificationInterpretation` object from `learn` on `ds_type`."
    # This generic version takes and forwards no `tta` argument, so the docstring no longer mentions it.
    return ClassificationInterpretation.from_learner(learn, ds_type=ds_type)
Learner.interpret = _learner_interpret