# Source code for chmncc.explanations.gradients

"""Module which deals with the explanations we can retrieve from the images"""
import torch
import tqdm


def compute_integrated_gradient(
    batch_x: torch.Tensor,
    batch_blank: torch.Tensor,
    model: torch.nn.Module,
    steps: int = 10,
) -> torch.Tensor:
    """Integrated gradients computation.

    Approximates the path integral of the input gradient of ``model`` along
    the straight line from the baseline ``batch_blank`` to ``batch_x``.
    A right Riemann sum over ``steps`` equally spaced points is used
    (alpha = 1/steps, 2/steps, ..., 1). The averaged gradient is then
    multiplied by ``batch_x - batch_blank``.

    Implementation adapted from
    https://github.com/CVxTz/IntegratedGradientsPytorch/blob/main/code/mlp_gradient.py

    Args:
        batch_x [torch.Tensor]: data instances batch
        batch_blank [torch.Tensor]: baseline batch (e.g. a batch of zeros);
            combined with ``batch_x`` by elementwise arithmetic, so the two
            must be broadcast-compatible
        model [torch.nn.Module]: network; a ones ``grad_outputs`` is used,
            i.e. the gradient of the sum of all output elements is taken
        steps [int]: number of interpolation points (default 10, the value
            that used to be hard-coded)

    Returns:
        integrated_gradients [torch.Tensor]: integrated gradients of the
            first element of the batch only (index 0 of the batched result)

    Raises:
        ValueError: if ``steps`` is smaller than 1
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    delta = batch_x - batch_blank
    grad_sum = 0
    # Input gradients are needed even if the caller runs under torch.no_grad().
    with torch.enable_grad():
        for i in tqdm.tqdm(
            range(1, steps + 1), desc="Computing the integrated gradient"
        ):
            # Make x a fresh leaf that requires grad. This works whether or not
            # batch_x itself requires grad, and it avoids linking every
            # step's graph back to batch_x.
            x = (batch_blank + i / steps * delta).detach().requires_grad_(True)
            y = model(x)
            (grad,) = torch.autograd.grad(
                y,
                x,
                grad_outputs=torch.ones_like(y),
            )
            grad_sum = grad_sum + grad

    mean_grad = grad_sum / steps
    integrated_gradients = delta * mean_grad
    # NOTE(review): only the first batch element is returned; kept for
    # backward compatibility with existing callers.
    return integrated_gradients[0]
def output_gradients(inputs: torch.Tensor, preds: torch.Tensor) -> torch.Tensor:
    """Return the gradient of ``preds`` with respect to ``inputs``.

    Every element of ``preds`` is seeded with 1, which is equivalent to
    differentiating ``preds.sum()``. The autograd graph is retained
    (``retain_graph=True``), so the same ``preds`` can be differentiated
    again afterwards.

    Args:
        inputs [torch.Tensor]: tensor the gradient is taken with respect to
        preds [torch.Tensor]: output tensor computed from ``inputs``

    Returns:
        grad [torch.Tensor]: gradients of the output with respect to the input
    """
    seed = torch.ones_like(preds)
    gradients = torch.autograd.grad(
        outputs=preds,
        inputs=inputs,
        grad_outputs=seed,
        retain_graph=True,
    )
    first_gradient = gradients[0]
    return first_gradient