How to use the torchcam.saliency.explainer.get_explainer function in torchcam

To help you get started, we’ve selected a few torchcam examples, based on popular ways it is used in public projects.


Source: Tramac/pytorch-cam, torchcam/saliency/saliency.py (GitHub)
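import torch

from torchcam.saliency.explainer import get_explainer


def get_saliency(model, raw_input, input, label, method='gradcam', layer_path=None):
    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    input = input.to(device)
    if label is not None:
        label = label.to(device)

    # Clear stale gradients so they do not leak into this explanation.
    if input.grad is not None:
        input.grad.zero_()
    if label is not None and label.grad is not None:
        label.grad.zero_()
    model.eval()
    model.zero_grad()

    # Build the explainer for the chosen method and compute the raw attribution.
    exp = get_explainer(method, model, layer_path)
    saliency = exp.explain(input, label, raw_input)

    if saliency is None:
        return None

    # Sum absolute attributions over the channel dimension, keep the first
    # sample in the batch, then min-max scale the map to [0, 1].
    saliency = saliency.abs().sum(dim=1)[0].squeeze()
    saliency -= saliency.min()
    saliency /= (saliency.max() + 1e-20)  # epsilon guards against division by zero
    return saliency.detach().cpu().numpy()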
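
The post-processing at the end of get_saliency (absolute value, channel sum, min-max scaling) can be tried without a model or a GPU. The following is a minimal sketch of that step in NumPy, not code from the repository: the helper name normalize_attribution and the random input are assumptions made for illustration, and the (N, C, H, W) layout is inferred from the sum over dim=1 in the snippet above.

```python
import numpy as np


def normalize_attribution(attribution: np.ndarray, eps: float = 1e-20) -> np.ndarray:
    """Collapse an (N, C, H, W) attribution array into one 2D map scaled to [0, 1].

    Hypothetical helper that mirrors the tensor operations in get_saliency.
    """
    # Sum absolute values over the channel axis, then keep the first sample.
    heatmap = np.abs(attribution).sum(axis=1)[0].squeeze()
    # Shift so the minimum is 0, then divide by the (epsilon-padded) maximum.
    heatmap = heatmap - heatmap.min()
    heatmap = heatmap / (heatmap.max() + eps)
    return heatmap


# Made-up attribution for a single 3-channel, 4x5 input (illustrative data only).
rng = np.random.default_rng(0)
attribution = rng.normal(size=(1, 3, 4, 5))
heatmap = normalize_attribution(attribution)

# A constant attribution has max == min, so the epsilon keeps the result at 0
# instead of producing NaN from 0 / 0.
flat = normalize_attribution(np.ones((1, 3, 2, 2)))
```

Because the returned map lies in [0, 1], it can be passed directly to an image colormap or blended with the raw input for display.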

torchcam

Class activation maps for your PyTorch CNN models

Apache-2.0
Latest version published 6 months ago
