Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _saliency_base_assert(
self, model, inputs, expected, additional_forward_args=None, nt_type="vanilla"
):
saliency = Saliency(model)
if nt_type == "vanilla":
attributions = saliency.attribute(
inputs, additional_forward_args=additional_forward_args
)
else:
nt = NoiseTunnel(saliency)
attributions = nt.attribute(
inputs,
nt_type=nt_type,
n_samples=10,
stdevs=0.0000002,
additional_forward_args=additional_forward_args,
)
if isinstance(attributions, tuple):
for input, attribution, expected_attr in zip(
inputs, attributions, expected
def _gradient_matching_test_assert(self, model, output_layer, test_input):
    """Assert NeuronGradient and Saliency agree for each neuron along dim 1.

    For every index ``idx`` in ``range(layer_out.shape[1])`` the neuron
    ``(idx, 0, ..., 0)`` of ``output_layer`` is selected. Two attributions
    are then compared for that neuron:

    * ``Saliency`` (with ``abs=False``) over a forward function that returns
      only that neuron's activation from ``output_layer``, and
    * ``NeuronGradient.attribute`` for the same neuron index.

    Both must share a shape, match ``test_input.shape``, and agree
    element-wise within ``delta=0.001``.

    Args:
        model: Model passed to ``_forward_layer_eval`` / ``NeuronGradient``.
        output_layer: Layer whose output neurons are checked.
        test_input: Input on which both attributions are computed.
    """
    layer_out = _forward_layer_eval(model, test_input, output_layer)
    neuron_grad = NeuronGradient(model, output_layer)
    # The neuron index covers every non-batch dimension (len(shape) - 1
    # entries): the first is the loop index, the rest are zero.
    zero_padding = (0,) * (len(layer_out.shape) - 2)

    for idx in range(layer_out.shape[1]):
        neuron_index = (idx,) + zero_padding
        # Keep the whole batch dimension and pick the single neuron.
        selector = (slice(None),) + neuron_index

        def neuron_forward(inp):
            return _forward_layer_eval(model, inp, output_layer)[selector]

        saliency_attr = Saliency(neuron_forward).attribute(test_input, abs=False)
        gradient_attr = neuron_grad.attribute(test_input, neuron_index)

        # Shapes must line up with each other and with the input.
        self.assertEqual(gradient_attr.shape, saliency_attr.shape)
        self.assertEqual(gradient_attr.shape, test_input.shape)
        assertArraysAlmostEqual(
            saliency_attr.reshape(-1).tolist(),
            gradient_attr.reshape(-1).tolist(),
            delta=0.001,
        )
def _saliency_classification_assert(self, nt_type="vanilla"):
num_in = 5
input = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]], requires_grad=True)
target = torch.tensor(5)
# 10-class classification model
model = SoftmaxModel(num_in, 20, 10)
saliency = Saliency(model)
if nt_type == "vanilla":
attributions = saliency.attribute(input, target)
output = model(input)[:, target]
output.backward()
expected = torch.abs(input.grad)
self.assertEqual(
expected.detach().numpy().tolist(),
attributions.detach().numpy().tolist(),
)
else:
nt = NoiseTunnel(saliency)
attributions = nt.attribute(
input, nt_type=nt_type, n_samples=10, stdevs=0.0002, target=target
)