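# Excerpt: the final step of a path-attribution method. Accumulated gradients
# are scaled by (input - baseline); this pattern appears in Captum's
# IntegratedGradients and DeepLift implementations.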
# attributions has the same dimensionality as inputs
attributions = tuple(
total_grad * (input - baseline)
for total_grad, input, baseline in zip(total_grads, inputs, baselines)
)
if return_convergence_delta:
start_point, end_point = baselines, inputs
# computes approximation error based on the completeness axiom
delta = self.compute_convergence_delta(
attributions,
start_point,
end_point,
additional_forward_args=additional_forward_args,
target=target,
)
return _format_attributions(is_inputs_tuple, attributions), delta
return _format_attributions(is_inputs_tuple, attributions)
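# A minimal usage sketch of the public API behind the excerpt above, assuming
# Captum is installed; the toy model, shapes, and target are illustrative.
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

model = nn.Sequential(nn.Linear(4, 3)).eval()
ig = IntegratedGradients(model)
inputs = torch.randn(2, 4)
baselines = torch.zeros(2, 4)
# delta reports the completeness-axiom approximation error computed above
attrs, delta = ig.attribute(
    inputs, baselines=baselines, target=0, return_convergence_delta=True
)

# Excerpt: NeuronGradient.attribute, which returns the gradient of a single
# neuron in a chosen layer with respect to the network inputs.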
is_inputs_tuple = isinstance(inputs, tuple)
inputs = _format_input(inputs)
additional_forward_args = _format_additional_forward_args(
    additional_forward_args
)
gradient_mask = apply_gradient_requirements(inputs)
_, input_grads = _forward_layer_eval_with_neuron_grads(
self.forward_func,
inputs,
self.layer,
additional_forward_args,
neuron_index,
device_ids=self.device_ids,
attribute_to_layer_input=attribute_to_neuron_input,
)
undo_gradient_requirements(inputs, gradient_mask)
return _format_attributions(is_inputs_tuple, input_grads)
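# A minimal usage sketch, assuming Captum is installed; the toy model and the
# selected neuron are illustrative. The selector argument is named
# `neuron_index` in older Captum releases and `neuron_selector` in newer ones,
# so it is passed positionally here.
import torch
import torch.nn as nn
from captum.attr import NeuronGradient

model = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 3)).eval()
neuron_grad = NeuronGradient(model, model[0])
inputs = torch.randn(2, 4)
attr = neuron_grad.attribute(inputs, (1,))  # gradient of neuron 1 in layer 0

# Excerpt: GuidedGradCam.attribute, which upsamples GradCAM attributions to
# each input's spatial size and multiplies them by guided-backprop gradients.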
output_attr = []
for i in range(len(inputs)):
    try:
        output_attr.append(
            guided_backprop_attr[i]
            * LayerAttribution.interpolate(
                grad_cam_attr,
                inputs[i].shape[2:],
                interpolate_mode=interpolate_mode,
            )
        )
    except (RuntimeError, NotImplementedError):
        warnings.warn(
            "Couldn't appropriately interpolate GradCAM attributions for "
            "some input tensors; returning None for the corresponding "
            "attributions."
        )
        output_attr.append(None)
return _format_attributions(is_inputs_tuple, tuple(output_attr))
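# A minimal usage sketch, assuming Captum is installed; the toy conv net,
# input size, and target class are illustrative.
import torch
import torch.nn as nn
from captum.attr import GuidedGradCam

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()
ggc = GuidedGradCam(model, model[0])  # GradCAM taken at the conv layer
inputs = torch.randn(2, 3, 32, 32)
attr = ggc.attribute(inputs, target=3)  # same shape as inputs: (2, 3, 32, 32)

# Excerpt: InputXGradient.attribute, which multiplies each input elementwise
# by its gradient.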
# Keep track of whether the original input is a tuple before
# converting it into one.
is_inputs_tuple = isinstance(inputs, tuple)
inputs = _format_input(inputs)
gradient_mask = apply_gradient_requirements(inputs)
gradients = self.gradient_func(
self.forward_func, inputs, target, additional_forward_args
)
attributions = tuple(
input * gradient for input, gradient in zip(inputs, gradients)
)
undo_gradient_requirements(inputs, gradient_mask)
return _format_attributions(is_inputs_tuple, attributions)
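# A minimal usage sketch, assuming Captum is installed; model and shapes are
# illustrative.
import torch
import torch.nn as nn
from captum.attr import InputXGradient

model = nn.Sequential(nn.Linear(4, 3)).eval()
ixg = InputXGradient(model)
inputs = torch.randn(2, 4)
attr = ixg.attribute(inputs, target=0)  # input * d(output[0]) / d(input)

# Excerpt: a ReLU-override attribution method; Captum's GuidedBackprop and
# Deconvolution share this code, which patches ReLU backward passes through
# temporary hooks.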
is_inputs_tuple = isinstance(inputs, tuple)
inputs = _format_input(inputs)
gradient_mask = apply_gradient_requirements(inputs)
warnings.warn(
    "Setting backward hooks on ReLU activations. "
    "The hooks will be removed after the attribution is finished."
)
self.model.apply(self._register_hooks)
gradients = self.gradient_func(
self.forward_func, inputs, target, additional_forward_args
)
# remove set hooks
self._remove_hooks()
undo_gradient_requirements(inputs, gradient_mask)
return _format_attributions(is_inputs_tuple, gradients)
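# A minimal usage sketch, assuming Captum is installed; the toy model is
# illustrative. GuidedBackprop overrides each ReLU's backward pass (via the
# hooks registered above) so that only non-negative gradients propagate.
import torch
import torch.nn as nn
from captum.attr import GuidedBackprop

model = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 3)).eval()
gbp = GuidedBackprop(model)
inputs = torch.randn(2, 4)
attr = gbp.attribute(inputs, target=2)

# Excerpt: Saliency.attribute, which returns plain input gradients, taking
# their absolute value when `abs` is set.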
is_inputs_tuple = isinstance(inputs, tuple)
inputs = _format_input(inputs)
gradient_mask = apply_gradient_requirements(inputs)
# No need to format additional_forward_args here.
# They are formatted in the `_run_forward` function in `common.py`.
gradients = self.gradient_func(
self.forward_func, inputs, target, additional_forward_args
)
if abs:
attributions = tuple(torch.abs(gradient) for gradient in gradients)
else:
attributions = gradients
undo_gradient_requirements(inputs, gradient_mask)
return _format_attributions(is_inputs_tuple, attributions)
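# A minimal usage sketch, assuming Captum is installed; model and shapes are
# illustrative.
import torch
import torch.nn as nn
from captum.attr import Saliency

model = nn.Sequential(nn.Linear(4, 3)).eval()
sal = Saliency(model)
inputs = torch.randn(2, 4)
attr = sal.attribute(inputs, target=0, abs=True)  # |d(output[0]) / d(input)|

# Excerpt: a GradientShap helper that attaches the convergence delta to the
# returned attributions only when deltas are supported and requested.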
def _apply_checks_and_return_attributions(
self, attributions, is_inputs_tuple, return_convergence_delta, delta
):
attributions = _format_attributions(is_inputs_tuple, attributions)
return (
(attributions, delta)
if self.is_delta_supported and return_convergence_delta
else attributions
)
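# A minimal usage sketch of GradientShap, which uses a helper like the one
# above; assumes Captum is installed, with an illustrative model and shapes.
import torch
import torch.nn as nn
from captum.attr import GradientShap

model = nn.Sequential(nn.Linear(4, 3)).eval()
gs = GradientShap(model)
inputs = torch.randn(2, 4)
baselines = torch.randn(10, 4)  # a distribution of baselines, not a single one
attrs, delta = gs.attribute(
    inputs, baselines, n_samples=20, target=0, return_convergence_delta=True
)

# Excerpt: DeepLiftShap.attribute, which runs DeepLift against a batch of
# baselines and averages the resulting attributions for each input.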
# DeepLiftShap expands inputs, baselines, and targets across the baseline
# batch and delegates to DeepLift.attribute. The opening of this call is
# reconstructed; the exp_* names are inferred from `exp_addit_args` below.
attributions = super().attribute(
    exp_inp,
    exp_base,
    target=exp_tgt,
    additional_forward_args=exp_addit_args,
    return_convergence_delta=return_convergence_delta,
    custom_attribution_func=custom_attribution_func,
)
if return_convergence_delta:
attributions, delta = attributions
attributions = tuple(
self._compute_mean_across_baselines(inp_bsz, base_bsz, attribution)
for attribution in attributions
)
if return_convergence_delta:
return _format_attributions(is_inputs_tuple, attributions), delta
else:
return _format_attributions(is_inputs_tuple, attributions)
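# A minimal usage sketch, assuming Captum is installed; the toy model and the
# baseline batch size are illustrative.
import torch
import torch.nn as nn
from captum.attr import DeepLiftShap

model = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 3)).eval()
dls = DeepLiftShap(model)
inputs = torch.randn(2, 4)
baselines = torch.randn(10, 4)  # DeepLiftShap expects multiple baselines
attrs, delta = dls.attribute(
    inputs, baselines, target=0, return_convergence_delta=True
)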