"""
Explainable AI (XAI) for Physiological Signal Analysis
This module provides interpretability and explainability tools for machine learning
and deep learning models applied to physiological signals.
Supported Methods:
1. SHAP (SHapley Additive exPlanations)
2. LIME (Local Interpretable Model-agnostic Explanations)
3. GradCAM for 1D signals
4. Attention Visualization
5. Feature Importance Analysis
Author: vitalDSP
License: MIT
"""
import numpy as np
from typing import Optional, Union, Tuple, List, Dict, Any, Callable
from pathlib import Path
import warnings
# Core dependencies
from scipy.ndimage import gaussian_filter1d
# Optional dependencies
try:
import shap
SHAP_AVAILABLE = True
except ImportError:
SHAP_AVAILABLE = False
try:
import lime
from lime import lime_tabular
LIME_AVAILABLE = True
except ImportError:
LIME_AVAILABLE = False
try:
import tensorflow as tf
from tensorflow import keras
TENSORFLOW_AVAILABLE = True
except ImportError:
TENSORFLOW_AVAILABLE = False
try:
import torch
import torch.nn as nn
PYTORCH_AVAILABLE = True
except ImportError:
PYTORCH_AVAILABLE = False
try:
import matplotlib.pyplot as plt
MATPLOTLIB_AVAILABLE = True
except ImportError:
MATPLOTLIB_AVAILABLE = False
[docs]
class BaseExplainer:
    """
    Common parent for every explainer in this module.

    Stores the model together with optional feature/class labels and keeps a
    dictionary in which subclasses may cache the explanations they produce.
    The two public hooks, ``explain`` and ``plot``, are placeholders that
    concrete explainers are expected to override.
    """

    def __init__(
        self,
        model: Any,
        feature_names: Optional[List[str]] = None,
        class_names: Optional[List[str]] = None,
    ):
        """
        Store the model and its optional labelling metadata.

        Parameters
        ----------
        model : object
            Trained model to explain
        feature_names : list of str, optional
            Names of features
        class_names : list of str, optional
            Names of classes
        """
        # Cache for computed explanations; starts empty for every instance.
        self.explanations = {}
        self.class_names = class_names
        self.feature_names = feature_names
        self.model = model

    def explain(self, X: np.ndarray, **kwargs) -> Dict[str, Any]:
        """
        Produce explanations for the given samples (abstract hook).

        Parameters
        ----------
        X : np.ndarray
            Input samples to explain
        **kwargs
            Additional arguments

        Returns
        -------
        dict
            Explanation results

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses must implement explain()")

    def plot(self, explanation: Dict[str, Any], **kwargs):
        """
        Render an explanation (abstract hook).

        Parameters
        ----------
        explanation : dict
            Explanation to visualize
        **kwargs
            Additional plotting arguments

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses must implement plot()")
[docs]
class SHAPExplainer(BaseExplainer):
    """
    SHAP (SHapley Additive exPlanations) for model interpretation.

    SHAP values represent the contribution of each feature to the prediction,
    based on game-theoretic Shapley values.

    Supports:
    - TreeExplainer (for tree-based models)
    - DeepExplainer (for deep learning models)
    - KernelExplainer (model-agnostic)
    - LinearExplainer (for linear models)

    Notes
    -----
    The underlying SHAP explainer is built lazily on the first call to
    :meth:`explain` and then reused; ``background_data`` passed to later
    calls is ignored.

    Examples
    --------
    >>> from vitalDSP.ml_models.explainability import SHAPExplainer
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> import numpy as np
    >>>
    >>> # Train a model
    >>> X_train = np.random.randn(1000, 50)
    >>> y_train = np.random.randint(0, 2, 1000)
    >>> model = RandomForestClassifier()
    >>> model.fit(X_train, y_train)
    >>>
    >>> # Create explainer
    >>> explainer = SHAPExplainer(
    ...     model,
    ...     explainer_type='tree',
    ...     feature_names=[f'feature_{i}' for i in range(50)]
    ... )
    >>>
    >>> # Explain predictions
    >>> X_test = np.random.randn(10, 50)
    >>> explanations = explainer.explain(X_test, background_data=X_train)
    >>>
    >>> # Visualize
    >>> explainer.plot_summary(explanations)
    >>> explainer.plot_waterfall(explanations, instance_idx=0)
    """

    def __init__(
        self,
        model: Any,
        explainer_type: str = "kernel",
        feature_names: Optional[List[str]] = None,
        class_names: Optional[List[str]] = None,
    ):
        """
        Initialize SHAP explainer.

        Parameters
        ----------
        model : object
            Trained model to explain
        explainer_type : str, default='kernel'
            Type of SHAP explainer ('tree', 'deep', 'kernel', 'linear')
        feature_names : list of str, optional
            Names of features
        class_names : list of str, optional
            Names of classes

        Raises
        ------
        ImportError
            If the ``shap`` package is not installed.
        """
        super().__init__(model, feature_names, class_names)
        if not SHAP_AVAILABLE:
            raise ImportError("SHAP is not installed. Install with: pip install shap")
        self.explainer_type = explainer_type.lower()
        # Built lazily by explain() because some types need background data.
        self.explainer = None

    def _create_explainer(self, background_data: Optional[np.ndarray] = None):
        """
        Create the SHAP explainer matching ``self.explainer_type``.

        Parameters
        ----------
        background_data : np.ndarray, optional
            Reference data; mandatory for 'deep', 'kernel' and 'linear'.

        Raises
        ------
        ValueError
            If the explainer type is unknown, required background data is
            missing, or the model has no usable prediction method.
        ImportError
            If 'deep' is requested without TensorFlow or PyTorch.
        """
        if self.explainer_type == "tree":
            # For tree-based models (RandomForest, XGBoost, LightGBM)
            self.explainer = shap.TreeExplainer(self.model)
        elif self.explainer_type == "deep":
            # For deep learning models
            if not TENSORFLOW_AVAILABLE and not PYTORCH_AVAILABLE:
                raise ImportError("TensorFlow or PyTorch required for DeepExplainer")
            if background_data is None:
                raise ValueError("background_data required for DeepExplainer")
            self.explainer = shap.DeepExplainer(self.model, background_data)
        elif self.explainer_type == "kernel":
            # Model-agnostic explainer
            if background_data is None:
                raise ValueError("background_data required for KernelExplainer")
            # Prefer probabilities so attributions are on the probability scale.
            if hasattr(self.model, "predict_proba"):
                predict_fn = self.model.predict_proba
            elif hasattr(self.model, "predict"):
                predict_fn = self.model.predict
            else:
                raise ValueError("Model must have predict or predict_proba method")
            self.explainer = shap.KernelExplainer(predict_fn, background_data)
        elif self.explainer_type == "linear":
            # For linear models. Fail early and consistently with the other
            # data-dependent explainers instead of deep inside SHAP.
            if background_data is None:
                raise ValueError("background_data required for LinearExplainer")
            self.explainer = shap.LinearExplainer(self.model, background_data)
        else:
            raise ValueError(f"Unknown explainer type: {self.explainer_type}")

    def _resolve_explanation(
        self, explanation: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Return ``explanation`` or fall back to the last one computed."""
        if explanation is None:
            explanation = self.explanations.get("last")
        if explanation is None:
            raise ValueError("No explanation available. Run explain() first.")
        return explanation

    @staticmethod
    def _select_output(
        shap_values: Any, base_values: Any, data: Any, class_idx: int = 0
    ) -> Tuple[Any, Any]:
        """
        Reduce multi-output SHAP results to a single model output.

        Handles both multi-output layouts: the legacy list with one array per
        output, and a single array carrying one extra trailing output axis
        compared to ``data``.
        """
        if isinstance(shap_values, list):
            shap_values = shap_values[class_idx]
        elif np.ndim(shap_values) == np.ndim(data) + 1:
            # Trailing axis indexes the model outputs.
            shap_values = np.asarray(shap_values)[..., class_idx]

        if isinstance(base_values, (list, tuple, np.ndarray)):
            flat = np.asarray(base_values).ravel()
            if flat.size > 1:
                base_values = flat[class_idx]
            elif flat.size == 1:
                # Single-output model (also covers 0-d arrays).
                base_values = flat[0]
        return shap_values, base_values

    def explain(
        self,
        X: np.ndarray,
        background_data: Optional[np.ndarray] = None,
        nsamples: int = 100,
    ) -> Dict[str, Any]:
        """
        Generate SHAP explanations.

        Parameters
        ----------
        X : np.ndarray
            Samples to explain
        background_data : np.ndarray, optional
            Background dataset for kernel/deep/linear explainers. Only used
            the first time, when the SHAP explainer is created.
        nsamples : int, default=100
            Number of samples for kernel explainer

        Returns
        -------
        dict
            Dictionary containing:
            - 'shap_values': SHAP values for each sample
            - 'base_values': Base values (expected model output), or None
            - 'data': Original input data
            - 'feature_names': Feature names given at construction
        """
        # Create explainer if not already created
        if self.explainer is None:
            self._create_explainer(background_data)

        # Only the kernel explainer takes a sampling budget.
        if self.explainer_type == "kernel":
            shap_values = self.explainer.shap_values(X, nsamples=nsamples)
        else:
            shap_values = self.explainer.shap_values(X)

        explanation = {
            "shap_values": shap_values,
            "base_values": getattr(self.explainer, "expected_value", None),
            "data": X,
            "feature_names": self.feature_names,
        }
        self.explanations["last"] = explanation
        return explanation

    def plot_summary(
        self,
        explanation: Optional[Dict[str, Any]] = None,
        plot_type: str = "dot",
        max_display: int = 20,
        show: bool = True,
    ):
        """
        Create SHAP summary plot.

        Parameters
        ----------
        explanation : dict, optional
            Explanation to plot. If None, use last explanation.
        plot_type : str, default='dot'
            Type of plot ('dot', 'bar', 'violin')
        max_display : int, default=20
            Maximum number of features to display
        show : bool, default=True
            Whether to show the plot

        Raises
        ------
        ValueError
            If no explanation is given and none has been computed yet.
        """
        explanation = self._resolve_explanation(explanation)
        shap.summary_plot(
            explanation["shap_values"],
            explanation["data"],
            feature_names=self.feature_names,
            plot_type=plot_type,
            max_display=max_display,
            show=show,
        )

    def plot_waterfall(
        self,
        explanation: Optional[Dict[str, Any]] = None,
        instance_idx: int = 0,
        max_display: int = 20,
        show: bool = True,
        class_idx: int = 0,
    ):
        """
        Create SHAP waterfall plot for a single prediction.

        Parameters
        ----------
        explanation : dict, optional
            Explanation to plot. If None, use last explanation.
        instance_idx : int, default=0
            Index of instance to explain
        max_display : int, default=20
            Maximum number of features to display
        show : bool, default=True
            Whether to show the plot
        class_idx : int, default=0
            Model output (class) to plot for multi-output explanations

        Raises
        ------
        ValueError
            If no explanation is given and none has been computed yet.
        """
        explanation = self._resolve_explanation(explanation)
        data = explanation["data"]
        shap_values, base_values = self._select_output(
            explanation["shap_values"], explanation["base_values"], data, class_idx
        )

        shap_exp = shap.Explanation(
            values=shap_values[instance_idx],
            base_values=base_values,
            data=data[instance_idx],
            feature_names=self.feature_names,
        )
        shap.plots.waterfall(shap_exp, max_display=max_display, show=show)

    def plot_force(
        self,
        explanation: Optional[Dict[str, Any]] = None,
        instance_idx: int = 0,
        matplotlib: bool = True,
        class_idx: int = 0,
    ):
        """
        Create SHAP force plot.

        Parameters
        ----------
        explanation : dict, optional
            Explanation to plot. If None, use last explanation.
        instance_idx : int, default=0
            Index of instance to explain
        matplotlib : bool, default=True
            Whether to use matplotlib backend
        class_idx : int, default=0
            Model output (class) to plot for multi-output explanations

        Returns
        -------
        object
            Whatever ``shap.force_plot`` returns; with ``matplotlib=False``
            this is the object a notebook needs in order to display the plot.

        Raises
        ------
        ValueError
            If no explanation is given and none has been computed yet.
        """
        explanation = self._resolve_explanation(explanation)
        data = explanation["data"]
        shap_values, base_values = self._select_output(
            explanation["shap_values"], explanation["base_values"], data, class_idx
        )

        return shap.force_plot(
            base_values,
            shap_values[instance_idx],
            data[instance_idx],
            feature_names=self.feature_names,
            matplotlib=matplotlib,
        )

    def plot_dependence(
        self,
        feature_idx: Union[int, str],
        explanation: Optional[Dict[str, Any]] = None,
        interaction_idx: Union[int, str, None] = "auto",
        show: bool = True,
        class_idx: int = 0,
    ):
        """
        Create SHAP dependence plot showing feature interaction.

        Parameters
        ----------
        feature_idx : int or str
            Feature to plot
        explanation : dict, optional
            Explanation to plot. If None, use last explanation.
        interaction_idx : int, str, or None, default='auto'
            Feature to show interaction with
        show : bool, default=True
            Whether to show the plot
        class_idx : int, default=0
            Model output (class) to plot for multi-output explanations

        Raises
        ------
        ValueError
            If no explanation is given and none has been computed yet.
        """
        explanation = self._resolve_explanation(explanation)
        data = explanation["data"]
        shap_values, _ = self._select_output(
            explanation["shap_values"], None, data, class_idx
        )

        shap.dependence_plot(
            feature_idx,
            shap_values,
            data,
            feature_names=self.feature_names,
            interaction_index=interaction_idx,
            show=show,
        )
[docs]
class LIMEExplainer(BaseExplainer):
    """
    LIME (Local Interpretable Model-agnostic Explanations).

    LIME explains individual predictions by approximating the model locally
    with an interpretable model.

    Examples
    --------
    >>> from vitalDSP.ml_models.explainability import LIMEExplainer
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> import numpy as np
    >>>
    >>> # Train a model
    >>> X_train = np.random.randn(1000, 50)
    >>> y_train = np.random.randint(0, 2, 1000)
    >>> model = RandomForestClassifier()
    >>> model.fit(X_train, y_train)
    >>>
    >>> # Create explainer
    >>> explainer = LIMEExplainer(
    ...     model,
    ...     training_data=X_train,
    ...     feature_names=[f'feature_{i}' for i in range(50)],
    ...     class_names=['Normal', 'Abnormal']
    ... )
    >>>
    >>> # Explain a prediction
    >>> X_test = np.random.randn(1, 50)
    >>> explanation = explainer.explain(X_test[0])
    >>> explainer.plot(explanation)
    """

    def __init__(
        self,
        model: Any,
        training_data: np.ndarray,
        mode: str = "classification",
        feature_names: Optional[List[str]] = None,
        class_names: Optional[List[str]] = None,
        discretize_continuous: bool = False,
    ):
        """
        Initialize LIME explainer.

        Parameters
        ----------
        model : object
            Trained model to explain
        training_data : np.ndarray
            Training data for LIME
        mode : str, default='classification'
            'classification' or 'regression'
        feature_names : list of str, optional
            Names of features
        class_names : list of str, optional
            Names of classes (for classification)
        discretize_continuous : bool, default=False
            Whether to discretize continuous features

        Raises
        ------
        ImportError
            If the ``lime`` package is not installed.
        """
        super().__init__(model, feature_names, class_names)
        if not LIME_AVAILABLE:
            raise ImportError("LIME is not installed. Install with: pip install lime")
        self.training_data = training_data
        self.mode = mode.lower()
        self.discretize_continuous = discretize_continuous
        # Create LIME explainer
        self.lime_explainer = lime_tabular.LimeTabularExplainer(
            training_data,
            mode=self.mode,
            feature_names=feature_names,
            class_names=class_names,
            discretize_continuous=discretize_continuous,
        )

    def explain(
        self,
        instance: np.ndarray,
        num_features: int = 10,
        num_samples: int = 5000,
        labels: Optional[Tuple[int, ...]] = None,
    ) -> Any:
        """
        Explain a single prediction.

        Parameters
        ----------
        instance : np.ndarray
            Instance to explain (1D array). A 2D array holding exactly one
            row is flattened to that row.
        num_features : int, default=10
            Number of features to include in explanation
        num_samples : int, default=5000
            Number of samples for local approximation
        labels : tuple of int, optional
            Labels to explain (for classification). If None, LIME's own
            default is used.

        Returns
        -------
        lime.explanation.Explanation
            LIME explanation object

        Raises
        ------
        ValueError
            If mode is 'classification' and the model lacks ``predict_proba``.
        """
        # Get prediction function
        if self.mode == "classification":
            if not hasattr(self.model, "predict_proba"):
                raise ValueError("Model must have predict_proba for classification")
            predict_fn = self.model.predict_proba
        else:
            predict_fn = self.model.predict

        # Accept a single-row 2D array for convenience.
        if isinstance(instance, np.ndarray) and instance.ndim == 2 and len(instance) == 1:
            instance = instance[0]

        # Forward ``labels`` only when the caller supplied it. An explicit
        # ``labels=None`` overrides LIME's default and is expected to fail
        # when LIME iterates the labels in classification mode.
        optional_args = {}
        if labels is not None:
            optional_args["labels"] = labels

        return self.lime_explainer.explain_instance(
            instance,
            predict_fn,
            num_features=num_features,
            num_samples=num_samples,
            **optional_args,
        )

    def plot(
        self,
        explanation: Any,
        label: Optional[int] = None,
        figsize: Tuple[int, int] = (10, 6),
    ):
        """
        Plot LIME explanation.

        Parameters
        ----------
        explanation : lime.explanation.Explanation
            Explanation to plot
        label : int, optional
            Label to visualize (for classification). If None, LIME's default
            label is used.
        figsize : tuple, default=(10, 6)
            Figure size in inches (width, height)
        """
        if not MATPLOTLIB_AVAILABLE:
            warnings.warn("Matplotlib not available, using default plot")
            explanation.show_in_notebook()
            return

        # Use LIME's built-in visualization
        if self.mode == "classification" and label is not None:
            fig = explanation.as_pyplot_figure(label=label)
        else:
            fig = explanation.as_pyplot_figure()

        # Apply the requested size after creation so it works regardless of
        # whether the installed LIME accepts a ``figsize`` argument.
        fig.set_size_inches(*figsize)
        fig.tight_layout()
        plt.show()
[docs]
class GradCAM1D:
    """
    Gradient-weighted Class Activation Mapping (GradCAM) for 1D signals.

    Visualizes which parts of the signal are important for the prediction
    by computing gradients of the output with respect to feature maps.

    Examples
    --------
    >>> from vitalDSP.ml_models.explainability import GradCAM1D
    >>> import tensorflow as tf
    >>> import numpy as np
    >>>
    >>> # Assume we have a trained 1D CNN model
    >>> model = tf.keras.models.load_model('my_cnn_model.h5')
    >>>
    >>> # Create GradCAM explainer
    >>> gradcam = GradCAM1D(model, layer_name='conv1d_3')
    >>>
    >>> # Generate heatmap for a signal
    >>> signal = np.random.randn(1, 1000, 1)
    >>> heatmap = gradcam.compute_heatmap(signal, class_idx=1)
    >>>
    >>> # Visualize
    >>> gradcam.plot_overlay(signal[0], heatmap)
    """

    def __init__(
        self, model: Any, layer_name: Optional[str] = None, backend: str = "tensorflow"
    ):
        """
        Initialize GradCAM.

        Parameters
        ----------
        model : object
            Trained neural network model
        layer_name : str, optional
            Name of convolutional layer to visualize. If None, use last conv layer.
        backend : str, default='tensorflow'
            Deep learning backend ('tensorflow' or 'pytorch')

        Raises
        ------
        ImportError
            If the selected backend is not installed.
        ValueError
            If the backend is unknown or no convolutional layer is found.
        """
        self.model = model
        self.layer_name = layer_name
        self.backend = backend.lower()
        if self.backend == "tensorflow":
            if not TENSORFLOW_AVAILABLE:
                raise ImportError("TensorFlow not installed")
            self._setup_tensorflow()
        elif self.backend == "pytorch":
            if not PYTORCH_AVAILABLE:
                raise ImportError("PyTorch not installed")
            self._setup_pytorch()
        else:
            raise ValueError(f"Unknown backend: {backend}")

    def _setup_tensorflow(self):
        """Build a Keras model exposing the target layer's output and the prediction."""
        # Find last convolutional layer if not specified
        if self.layer_name is None:
            for layer in reversed(self.model.layers):
                if isinstance(layer, keras.layers.Conv1D):
                    self.layer_name = layer.name
                    break
        if self.layer_name is None:
            raise ValueError("No Conv1D layer found in model")
        # Create gradient model
        self.grad_model = keras.Model(
            inputs=self.model.input,
            outputs=[self.model.get_layer(self.layer_name).output, self.model.output],
        )

    def _setup_pytorch(self):
        """Attach a forward hook capturing the target layer's activations and gradients."""
        # Find last convolutional layer if not specified
        if self.layer_name is None:
            for name, module in reversed(list(self.model.named_modules())):
                if isinstance(module, nn.Conv1d):
                    self.layer_name = name
                    break
        if self.layer_name is None:
            raise ValueError("No Conv1d layer found in model")

        self.gradients = None
        self.activations = None

        def save_gradient(grad):
            self.gradients = grad

        def forward_hook(module, inputs, output):
            self.activations = output
            # A tensor hook on the layer output replaces the deprecated
            # module-level ``register_backward_hook``. Skip it when autograd
            # is off (e.g. inference under ``torch.no_grad()``), where
            # registering a hook would raise.
            if output.requires_grad:
                output.register_hook(save_gradient)

        # Get target layer
        target_layer = dict(self.model.named_modules())[self.layer_name]
        self._hook_handles = [target_layer.register_forward_hook(forward_hook)]

    def remove_hooks(self):
        """Detach the hooks this object registered on a PyTorch model (no-op otherwise)."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def compute_heatmap(
        self, signal: np.ndarray, class_idx: Optional[int] = None
    ) -> np.ndarray:
        """
        Compute GradCAM heatmap.

        Parameters
        ----------
        signal : np.ndarray
            Input signal of shape (batch, length, channels), (length, channels)
            or (length,). When a batch is given, the heatmap describes its
            first sample.
        class_idx : int, optional
            Target class index. If None, use predicted class.

        Returns
        -------
        np.ndarray
            Heatmap of shape (length,)

        Raises
        ------
        ValueError
            If ``signal`` has more than three dimensions or none.
        """
        signal = np.asarray(signal)
        if signal.ndim == 1:
            signal = signal[np.newaxis, :, np.newaxis]
        elif signal.ndim == 2:
            signal = signal[np.newaxis, :]
        elif signal.ndim != 3:
            raise ValueError(
                "signal must have shape (length,), (length, channels) or "
                f"(batch, length, channels); got {signal.shape}"
            )

        if self.backend == "tensorflow":
            return self._compute_heatmap_tensorflow(signal, class_idx)
        else:
            return self._compute_heatmap_pytorch(signal, class_idx)

    def _compute_heatmap_tensorflow(
        self, signal: np.ndarray, class_idx: Optional[int] = None
    ) -> np.ndarray:
        """Compute heatmap using TensorFlow."""
        with tf.GradientTape() as tape:
            conv_outputs, predictions = self.grad_model(signal)
            if class_idx is None:
                class_idx = tf.argmax(predictions[0])
            loss = predictions[:, class_idx]

        # Gradient of the class score w.r.t. the feature maps
        grads = tape.gradient(loss, conv_outputs)
        # One importance weight per channel
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1))

        # Weight feature maps by broadcasting: TensorFlow tensors are
        # immutable, so per-channel in-place assignment is not possible.
        weighted = conv_outputs[0] * pooled_grads

        # Create heatmap
        heatmap = tf.reduce_mean(weighted, axis=-1)
        heatmap = tf.maximum(heatmap, 0)  # ReLU
        heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-10)  # Normalize

        # Resize to input length. tf.image.resize treats a rank-3 tensor as
        # (h, w, c), so an explicit (batch, h, w, c) layout is required to
        # stretch along the time axis.
        original_length = signal.shape[1]
        heatmap = tf.image.resize(
            heatmap[tf.newaxis, :, tf.newaxis, tf.newaxis], (original_length, 1)
        )
        return tf.reshape(heatmap, [-1]).numpy()

    def _compute_heatmap_pytorch(
        self, signal: np.ndarray, class_idx: Optional[int] = None
    ) -> np.ndarray:
        """Compute heatmap using PyTorch."""
        # Run on whichever device the model lives on.
        try:
            device = next(self.model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")

        signal_tensor = (
            torch.as_tensor(signal, dtype=torch.float32)
            .permute(0, 2, 1)  # (batch, channels, length)
            .to(device)
        )
        signal_tensor.requires_grad_(True)

        # Clear anything captured by an earlier forward pass.
        self.gradients = None
        self.activations = None

        self.model.eval()
        self.model.zero_grad()
        # Gradients are required even if the caller is inside no_grad().
        with torch.enable_grad():
            output = self.model(signal_tensor)
            if class_idx is None:
                class_idx = int(output[0].argmax().item())
            output[0, class_idx].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                f"Layer '{self.layer_name}' produced no activations/gradients; "
                "check that it takes part in the forward pass"
            )

        # Only the first sample was back-propagated, so only it is explained.
        gradients = self.gradients.detach().cpu()[0]  # (channels, length')
        activations = self.activations.detach().cpu()[0]

        # One importance weight per channel
        pooled_grads = gradients.mean(dim=1)
        # Out-of-place product: ``detach()`` shares storage with the tensor
        # held by the model, so it must not be modified in place.
        weighted = activations * pooled_grads[:, None]

        # Create heatmap
        heatmap = torch.relu(weighted.mean(dim=0))
        heatmap = heatmap / (torch.max(heatmap) + 1e-10)

        # Resize to input length
        original_length = signal.shape[1]
        heatmap = torch.nn.functional.interpolate(
            heatmap.reshape(1, 1, -1),
            size=original_length,
            mode="linear",
            align_corners=False,
        )
        return heatmap.reshape(-1).numpy()

    def plot_overlay(
        self,
        signal: np.ndarray,
        heatmap: np.ndarray,
        alpha: float = 0.4,
        colormap: str = "jet",
        figsize: Tuple[int, int] = (15, 5),
    ):
        """
        Plot signal with GradCAM heatmap overlay.

        Parameters
        ----------
        signal : np.ndarray
            Original signal (1D or 2D with channels); for 2D input only the
            first channel is drawn
        heatmap : np.ndarray
            GradCAM heatmap with at least as many entries as the signal
        alpha : float, default=0.4
            Transparency of heatmap overlay
        colormap : str, default='jet'
            Colormap for heatmap
        figsize : tuple, default=(15, 5)
            Figure size

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        if not MATPLOTLIB_AVAILABLE:
            raise ImportError("Matplotlib is required for visualization")

        # Keep only the first channel of a multi-channel signal
        if signal.ndim > 1:
            signal = signal[:, 0]

        fig, ax = plt.subplots(figsize=figsize)

        # Plot signal
        time = np.arange(len(signal))
        ax.plot(time, signal, "k-", linewidth=1, label="Signal")

        # ``plt.get_cmap`` is used because ``matplotlib.cm.get_cmap`` is no
        # longer available in recent Matplotlib releases.
        cmap = plt.get_cmap(colormap)
        colors = cmap(heatmap)

        # Shade each sample interval; opacity scales with heatmap intensity
        # and is clipped to the range matplotlib accepts.
        for i in range(len(time) - 1):
            span_alpha = float(np.clip(alpha * heatmap[i], 0.0, 1.0))
            ax.axvspan(time[i], time[i + 1], alpha=span_alpha, color=colors[i])

        ax.set_xlabel("Time")
        ax.set_ylabel("Amplitude")
        ax.set_title("GradCAM Visualization")
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
[docs]
class AttentionVisualizer:
    """
    Visualize attention weights from transformer models.

    Attention matrices are interpreted as ``[query, key]``, i.e. rows are
    query positions and columns are key positions.

    Examples
    --------
    >>> from vitalDSP.ml_models.explainability import AttentionVisualizer
    >>> import numpy as np
    >>>
    >>> # Assume we have attention weights from a transformer
    >>> attention_weights = np.random.rand(8, 100, 100)  # (n_heads, seq_len, seq_len)
    >>>
    >>> # Create visualizer
    >>> viz = AttentionVisualizer()
    >>>
    >>> # Plot attention patterns
    >>> viz.plot_attention_map(attention_weights, head_idx=0)
    >>> viz.plot_attention_rollout(attention_weights)
    """

    def __init__(self):
        """Initialize attention visualizer.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        if not MATPLOTLIB_AVAILABLE:
            raise ImportError("Matplotlib is required for visualization")

    @staticmethod
    def _stack_heads(attention_weights: np.ndarray) -> np.ndarray:
        """
        Return the weights as a ``(n_maps, n_queries, n_keys)`` array.

        A single 2D map becomes a stack of one; any extra leading axes
        (e.g. layers, batch) are merged into the first axis.
        """
        attn = np.asarray(attention_weights)
        if attn.ndim < 2:
            raise ValueError(
                f"attention_weights must be at least 2D; got shape {attn.shape}"
            )
        if attn.ndim == 2:
            return attn[np.newaxis]
        return attn.reshape((-1,) + attn.shape[-2:])

    @staticmethod
    def _attention_received(attn: np.ndarray) -> np.ndarray:
        """
        Total attention each key position receives from all queries.

        Summing over queries is used because summing a softmax-normalised
        matrix over keys yields 1 for every row and carries no information.
        """
        return np.sum(attn, axis=0)

    def plot_attention_map(
        self,
        attention_weights: np.ndarray,
        head_idx: int = 0,
        figsize: Tuple[int, int] = (10, 10),
        title: Optional[str] = None,
    ):
        """
        Plot attention map for a specific head.

        Parameters
        ----------
        attention_weights : np.ndarray
            Attention weights of shape (n_heads, seq_len, seq_len), or a
            single (seq_len, seq_len) map
        head_idx : int, default=0
            Index of attention head to visualize (ignored for a 2D input)
        figsize : tuple, default=(10, 10)
            Figure size
        title : str, optional
            Plot title
        """
        if attention_weights.ndim == 2:
            attn = attention_weights
        else:
            attn = attention_weights[head_idx]

        fig, ax = plt.subplots(figsize=figsize)
        im = ax.imshow(attn, cmap="viridis", aspect="auto")
        ax.set_xlabel("Key Position")
        ax.set_ylabel("Query Position")
        if title is None:
            title = f"Attention Map (Head {head_idx})"
        ax.set_title(title)
        plt.colorbar(im, ax=ax, label="Attention Weight")
        plt.tight_layout()
        plt.show()

    def plot_attention_rollout(
        self, attention_weights: np.ndarray, figsize: Tuple[int, int] = (12, 6)
    ):
        """
        Plot the head-averaged attention map and the attention each position receives.

        Parameters
        ----------
        attention_weights : np.ndarray
            Attention weights of shape (n_heads, seq_len, seq_len). A single
            2D map is used as is; extra leading axes (e.g. layers) are
            averaged together with the heads.
        figsize : tuple, default=(12, 6)
            Figure size
        """
        # Average across heads (and any other leading axes)
        attn_rollout = np.mean(self._stack_heads(attention_weights), axis=0)

        # Compute attention flow: attention received per key position
        attention_flow = self._attention_received(attn_rollout)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # Plot attention matrix
        im1 = ax1.imshow(attn_rollout, cmap="viridis", aspect="auto")
        ax1.set_xlabel("Key Position")
        ax1.set_ylabel("Query Position")
        ax1.set_title("Attention Rollout")
        plt.colorbar(im1, ax=ax1, label="Attention")

        # Plot attention flow
        ax2.plot(attention_flow, linewidth=2)
        ax2.fill_between(range(len(attention_flow)), attention_flow, alpha=0.3)
        ax2.set_xlabel("Position")
        ax2.set_ylabel("Total Attention")
        ax2.set_title("Attention Flow")
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def plot_head_comparison(
        self,
        attention_weights: np.ndarray,
        query_idx: int = 0,
        figsize: Tuple[int, int] = (15, 4),
    ):
        """
        Compare attention patterns across different heads.

        Parameters
        ----------
        attention_weights : np.ndarray
            Attention weights of shape (n_heads, seq_len, seq_len); a single
            2D map is treated as one head
        query_idx : int, default=0
            Query position to visualize
        figsize : tuple, default=(15, 4)
            Figure size
        """
        heads = self._stack_heads(attention_weights)
        n_heads = heads.shape[0]
        # The plotted curve runs over keys, so use the key-axis length.
        n_keys = heads.shape[-1]

        fig, axes = plt.subplots(1, n_heads, figsize=figsize)
        # plt.subplots returns a bare Axes when only one panel is requested
        axes = np.atleast_1d(axes)

        for head_idx, ax in enumerate(axes):
            attn = heads[head_idx, query_idx, :]
            ax.plot(attn, linewidth=2)
            ax.fill_between(range(n_keys), attn, alpha=0.3)
            ax.set_title(f"Head {head_idx}")
            ax.set_xlabel("Key Position")
            if head_idx == 0:
                ax.set_ylabel("Attention Weight")
            ax.grid(True, alpha=0.3)

        plt.suptitle(f"Attention Patterns for Query Position {query_idx}")
        plt.tight_layout()
        plt.show()
# Convenience function
[docs]
def explain_prediction(
    model: Any,
    X: np.ndarray,
    method: str = "shap",
    background_data: Optional[np.ndarray] = None,
    feature_names: Optional[List[str]] = None,
    class_names: Optional[List[str]] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Quick explanation of model predictions.

    Parameters
    ----------
    model : object
        Trained model
    X : np.ndarray
        Samples to explain
    method : str, default='shap'
        Explanation method ('shap', 'lime'); case-insensitive
    background_data : np.ndarray, optional
        Background data (required for LIME and for some SHAP explainers)
    feature_names : list of str, optional
        Feature names
    class_names : list of str, optional
        Class names
    **kwargs
        Additional arguments. Explain-time options (``nsamples`` for SHAP;
        ``num_features``, ``num_samples`` and ``labels`` for LIME) are
        forwarded to the explainer's ``explain`` method; everything else is
        passed to the explainer's constructor.

    Returns
    -------
    dict or lime.explanation.Explanation
        For SHAP, the explanation dictionary. For LIME, the explanation
        object when a single instance is given, otherwise a dictionary with
        an 'explanations' list holding one explanation per row of ``X``.

    Raises
    ------
    ValueError
        If ``method`` is unknown, or LIME is requested without
        ``background_data``.

    Examples
    --------
    >>> from vitalDSP.ml_models.explainability import explain_prediction
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> import numpy as np
    >>>
    >>> # Train model
    >>> X_train = np.random.randn(1000, 50)
    >>> y_train = np.random.randint(0, 2, 1000)
    >>> model = RandomForestClassifier()
    >>> model.fit(X_train, y_train)
    >>>
    >>> # Explain test predictions
    >>> X_test = np.random.randn(10, 50)
    >>> explanation = explain_prediction(
    ...     model, X_test,
    ...     method='shap',
    ...     explainer_type='tree'
    ... )
    """
    method = method.lower()

    if method == "shap":
        # Options that belong to explain() rather than the constructor.
        explain_kwargs = {
            key: kwargs.pop(key) for key in ("nsamples",) if key in kwargs
        }
        explainer = SHAPExplainer(
            model, feature_names=feature_names, class_names=class_names, **kwargs
        )
        return explainer.explain(X, background_data=background_data, **explain_kwargs)

    if method == "lime":
        if background_data is None:
            raise ValueError("background_data (training data) required for LIME")
        # Options that belong to explain() rather than the constructor.
        explain_kwargs = {
            key: kwargs.pop(key)
            for key in ("num_features", "num_samples", "labels")
            if key in kwargs
        }
        explainer = LIMEExplainer(
            model,
            training_data=background_data,
            feature_names=feature_names,
            class_names=class_names,
            **kwargs,
        )
        # LIME explains one instance at a time
        if X.ndim == 1 or len(X) == 1:
            instance = X if X.ndim == 1 else X[0]
            return explainer.explain(instance, **explain_kwargs)
        # Explain multiple instances
        return {
            "explanations": [
                explainer.explain(instance, **explain_kwargs) for instance in X
            ]
        }

    raise ValueError(f"Unknown method: {method}")