import torch


def compute_feature_importance(model, inputs, target, loss_fn):
    """Compute per-feature importance as the mean absolute input gradient.

    Each floating-point (or complex) tensor in ``inputs`` is cloned into a fresh
    leaf that requires grad, the model is run on the dict of leaves, and the
    gradient of ``loss_fn(output, target)`` w.r.t. each leaf is taken with
    ``torch.autograd.grad``. Model parameters' ``.grad`` buffers are NOT
    touched, and the caller's tensors are not modified.

    Args:
        model: Callable taking a dict with the same keys as ``inputs``.
        inputs: Mapping of name -> tensor. Non-floating tensors (e.g. integer
            ids) are passed to the model unchanged but cannot be differentiated.
        target: Passed as the second argument to ``loss_fn``.
        loss_fn: Callable ``(output, target) -> loss``; the loss must be a
            scalar tensor for the gradient to be defined without grad_outputs.

    Returns:
        dict mapping each differentiable input name to a numpy array equal to
        ``|d loss / d input|`` averaged over dim 0 (presumably the batch dim —
        confirm against callers). Inputs the loss does not depend on get an
        all-zero array of that shape. Non-floating inputs are omitted.

    Note:
        The model's train/eval mode is left as-is; with dropout etc. active the
        result is stochastic.
    """
    leaves = {}
    for name, value in inputs.items():
        leaf = value.clone().detach()
        # requires_grad_ raises on integer/bool tensors, so only mark
        # tensors autograd can actually differentiate.
        if leaf.is_floating_point() or leaf.is_complex():
            leaf.requires_grad_(True)
        leaves[name] = leaf

    diff_names = [name for name, leaf in leaves.items() if leaf.requires_grad]

    # enable_grad makes this work even when called under an outer no_grad().
    with torch.enable_grad():
        output = model(leaves)
        loss = loss_fn(output, target)
        if not diff_names:
            return {}
        # autograd.grad (unlike loss.backward) does not accumulate into model
        # parameters' .grad; allow_unused yields None for unused inputs.
        grads = torch.autograd.grad(
            loss,
            [leaves[name] for name in diff_names],
            allow_unused=True,
        )

    importances = {}
    for name, grad in zip(diff_names, grads):
        if grad is None:
            grad = torch.zeros_like(leaves[name])
        importances[name] = grad.detach().abs().mean(dim=0).cpu().numpy()
    return importances


def visualize_attention(attn_weights, ax=None, *, show=True):
    """Visualize a 2-D matrix of attention weights as a heatmap.

    Args:
        attn_weights: 2-D array-like or torch tensor (any device, may require
            grad). Higher-rank inputs are not reshaped here — slice a single
            head/batch element first.
        ax: Optional matplotlib Axes to draw into; a new figure is created when
            omitted.
        show: Keyword-only. When True (the default, matching the previous
            behavior) ``plt.show()`` is called; pass False when composing the
            plot into a larger figure.

    Returns:
        The Axes that was drawn into.
    """
    import matplotlib.pyplot as plt

    # matshow cannot consume tensors that require grad or live on an
    # accelerator; float() also covers dtypes numpy lacks (e.g. bfloat16).
    if isinstance(attn_weights, torch.Tensor):
        attn_weights = attn_weights.detach().float().cpu().numpy()

    if ax is None:
        _, ax = plt.subplots()
    image = ax.matshow(attn_weights, cmap='viridis')
    # Attach the colorbar to ax's own figure, not pyplot's "current" figure.
    ax.figure.colorbar(image, ax=ax)
    ax.set_title("Attention Weights")
    if show:
        plt.show()
    return ax