How to use torchcam - 4 common examples

To help you get started, we've selected a few torchcam examples based on popular ways it is used in public projects. All of the snippets below come from the Tramac/pytorch-cam project.


Tramac/pytorch-cam: torchcam/saliency/saliency.py
import torch

# `get_explainer` is defined elsewhere in torchcam.saliency and returns the
# requested saliency/CAM explainer for the given model and method.
def get_saliency(model, raw_input, input, label, method='gradcam', layer_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    input = input.to(device)
    if label is not None:
        label = label.to(device)

    # Clear any stale gradients before running the explainer.
    if input.grad is not None:
        input.grad.zero_()
    if label is not None and label.grad is not None:
        label.grad.zero_()
    model.eval()
    model.zero_grad()

    exp = get_explainer(method, model, layer_path)
    saliency = exp.explain(input, label, raw_input)

    if saliency is not None:
        # Collapse the channel dimension, take the first batch element,
        # then normalise to [0, 1] for visualisation.
        saliency = saliency.abs().sum(dim=1)[0].squeeze()
        saliency -= saliency.min()
        saliency /= (saliency.max() + 1e-20)
        return saliency.detach().cpu().numpy()
    else:
        return None
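
Calling this helper only needs a model, the original image, and a preprocessed input tensor. A minimal sketch, assuming a torchvision ResNet, standard ImageNet preprocessing, and a placeholder image path and class index:

import torch
from PIL import Image
from torchvision import models, transforms

model = models.resnet18(pretrained=True)   # older torchvision API; newer releases use weights=...

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

raw_image = Image.open("cat.jpg").convert("RGB")      # "cat.jpg" is a placeholder path
input_tensor = preprocess(raw_image).unsqueeze(0)     # 1 x 3 x 224 x 224 batch
input_tensor.requires_grad_()                         # gradient-based methods need gradients

label = torch.tensor([281])                           # hypothetical ImageNet class index

saliency = get_saliency(model, raw_image, input_tensor, label, method='gradcam')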

Tramac/pytorch-cam: torchcam/saliency/saliency.py
import math

import matplotlib.pyplot as plt
import skimage.transform

# `show_image` is a helper defined elsewhere in torchcam.saliency that overlays
# the saliency map (img2) on top of the raw image on the given axes.
def get_image_saliency_plot(image_saliency_results, cols: int = 2, figsize: tuple = None,
                            display=True, save_path=False):
    rows = math.ceil(len(image_saliency_results) / cols)
    figsize = figsize or (8, 3 * rows)
    figure = plt.figure(figsize=figsize)

    for i, r in enumerate(image_saliency_results):
        ax = figure.add_subplot(rows, cols, i + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(r.title, fontdict={'fontsize': 18})

        # Upsample the saliency map to the raw image resolution before overlaying it.
        saliency_upsampled = skimage.transform.resize(r.saliency,
                                                      (r.raw_image.height, r.raw_image.width),
                                                      mode='reflect')

        show_image(r.raw_image, img2=saliency_upsampled, alpha2=r.saliency_alpha,
                   cmap2=r.saliency_cmap, ax=ax)

    if display:
        figure.show()
        figure.waitforbuttonpress()
    if save_path:
        figure.savefig(save_path)

    return figure
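
Each entry in image_saliency_results only needs the attributes the loop reads: title, saliency, raw_image, saliency_alpha, and saliency_cmap. A minimal stand-in container, as a hedged sketch rather than the package's actual SaliencyImage class (whose defaults may differ):

from dataclasses import dataclass

import numpy as np
from PIL import Image

@dataclass
class SaliencyResult:
    raw_image: Image.Image        # original PIL image, provides .width / .height
    saliency: np.ndarray          # 2-D saliency map, any resolution
    title: str                    # panel title, e.g. the method name
    saliency_alpha: float = 0.5   # overlay transparency (assumed default)
    saliency_cmap: str = 'jet'    # overlay colormap (assumed default)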

Tramac/pytorch-cam: torchcam/cam.py
# Convenience wrapper: compute a Grad-CAM map for one image and render the overlay figure.
def getCAM(model, raw_image, input, label, layer_path=None, display=True, save_path=False):
    saliency_maps = get_image_saliency_result(model, raw_image, input, label,
                                              methods=['gradcam'], layer_path=layer_path)
    figure = get_image_saliency_plot(saliency_maps, display=display, save_path=save_path)

    return figure
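
With the objects from the earlier preprocessing sketch, this single-call entry point can be used roughly as follows. How layer_path=None is resolved to a target convolutional layer depends on the explainer, and the output filename is a placeholder:

# Reuses model, raw_image, input_tensor and label from the earlier sketch.
figure = getCAM(model, raw_image, input_tensor, label,
                display=False, save_path='gradcam_overlay.png')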

Tramac/pytorch-cam: torchcam/saliency/saliency.py
# `SaliencyImage` is a small container (raw image, saliency map, method title)
# defined elsewhere in this module.
def get_image_saliency_result(model, raw_image, input, label,
                              methods=['smooth_grad', 'gradcam', 'vanilla_grad', 'grad_x_input'],
                              layer_path=None):
    result = list()
    for method in methods:
        sal = get_saliency(model, raw_image, input, label, method=method, layer_path=layer_path)
        if sal is not None:
            result.append(SaliencyImage(raw_image, sal, method))

    return result
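
Combined with the plotting helper above, this makes it easy to compare several methods side by side for the same input. A sketch reusing the earlier model and tensors, with a placeholder output path:

results = get_image_saliency_result(model, raw_image, input_tensor, label,
                                    methods=['vanilla_grad', 'gradcam'])
figure = get_image_saliency_plot(results, cols=2, display=False,
                                 save_path='saliency_comparison.png')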

torchcam

Class activation maps for your PyTorch CNN models

License: Apache-2.0