Documentation Index
Fetch the complete documentation index at: https://mintlify.com/MilesONerd/neurenix/llms.txt
Use this file to discover all available pages before exploring further.
Explainable AI
The explainability module provides methods for interpreting machine learning models and explaining their predictions. This includes SHAP values, LIME, feature importance, and various visualization techniques to make AI systems more transparent and trustworthy.
Overview
Explainable AI techniques help answer:
- Why did the model make this prediction?
- Which features are most important?
- How does the model work internally?
- What would change the prediction?
SHAP (SHapley Additive exPlanations)
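SHAP values provide a unified measure of feature importance based on Shapley values from cooperative game theory: each feature's value is its average marginal contribution to the prediction across all possible subsets (coalitions) of features.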
KernelSHAP
Model-agnostic method for any black-box model.
from neurenix.explainable import KernelShap
import neurenix as nx
# Fit the model you want to explain (train_your_model is a placeholder).
model = train_your_model(X_train, y_train)

# KernelSHAP is model-agnostic; X_train serves as its background data.
explainer = KernelShap(model=model, data=X_train, link="identity")

# Explain the first ten test rows; n_samples sets the sampling budget.
batch = X_test[:10]
shap_values = explainer.explain(samples=batch, n_samples=2048)

print("SHAP values shape:", shap_values["values"].shape)
print("Expected value:", shap_values["expected_value"])

# Render the attributions with human-readable labels.
labels = ['feature1', 'feature2', 'feature3']
explainer.plot(shap_values, feature_names=labels)
TreeSHAP
Fast and exact method for tree-based models.
from neurenix.explainable import TreeShap
# TreeSHAP targets tree-based models (decision trees, random forests, gradient boosting).
explainer = TreeShap(
    model=tree_model,
    data=X_train,
    feature_perturbation="interventional",
)
shap_values = explainer.explain(X_test)

# Print each feature's contribution to the first explained row.
print("Feature contributions:")
contributions = shap_values['values']
for col, name in enumerate(feature_names):
    print(f"{name}: {contributions[0, col]:.4f}")
DeepSHAP
Optimized for deep neural networks.
from neurenix.explainable import DeepShap
import neurenix as nx
# A small MLP: 784 inputs -> 256 hidden units (ReLU) -> 10 outputs.
layers = [
    nx.nn.Linear(784, 256),
    nx.nn.ReLU(),
    nx.nn.Linear(256, 10),
]
model = nx.nn.Sequential(*layers)

# The first 100 training rows act as the reference dataset.
reference = X_train[:100]
explainer = DeepShap(model=model, data=reference)

# Attribute the first five test predictions.
shap_values = explainer.explain(X_test[:5])

# Per-feature SHAP values and the base value they are measured against.
values, expected = shap_values["values"], shap_values["expected_value"]
Advanced SHAP Usage
import matplotlib.pyplot as plt
import neurenix as nx
# Summary plot
def plot_shap_summary(shap_values, X, feature_names):
    """Plot global feature importance as a horizontal bar chart.

    Importance is the mean absolute SHAP value of each feature, averaged
    over axis 0 (the explained samples). Bars are ordered so that the most
    important feature is drawn at the top of the chart.

    Args:
        shap_values: Explainer output whose ``"values"`` entry is expected to
            be 2-D ``(n_samples, n_features)``. It may be a framework tensor
            (anything exposing ``.numpy()``) or an array-like.
        X: Unused; kept so the signature mirrors ``plot_shap_dependence``.
        feature_names: Feature names indexed by column position.
    """
    # numpy is not imported at the top of this snippet, so bring it in here.
    import numpy as np

    values = shap_values["values"]
    # Accept framework tensors as well as plain NumPy arrays / lists.
    if hasattr(values, "numpy"):
        values = values.numpy()
    values = np.asarray(values)

    # Calculate mean absolute SHAP values per feature.
    mean_shap = np.abs(values).mean(axis=0)

    # Ascending order: barh places position 0 at the bottom, so the most
    # important feature ends up at the top where readers look first.
    order = np.argsort(mean_shap)

    plt.figure(figsize=(10, 6))
    plt.barh(range(len(order)), mean_shap[order])
    plt.yticks(range(len(order)), [feature_names[i] for i in order])
    plt.xlabel("Mean |SHAP value|")
    plt.title("Feature Importance")
    plt.tight_layout()
    plt.show()
# Dependence plot
def plot_shap_dependence(shap_values, X, feature_idx, feature_names):
    """Scatter one feature's SHAP values against that feature's raw values.

    Args:
        shap_values: Explainer output whose ``"values"`` entry is expected to
            be 2-D ``(n_samples, n_features)``; a tensor exposing
            ``.numpy()`` or a NumPy array.
        X: Input rows that were explained, indexed as ``X[:, feature_idx]``;
            a tensor or a NumPy array.
        feature_idx: Column index of the feature to plot.
        feature_names: Feature names indexed by column position.
    """
    def _to_numpy(arr):
        # Framework tensors expose .numpy(); NumPy arrays pass through as-is.
        return arr.numpy() if hasattr(arr, "numpy") else arr

    values = _to_numpy(shap_values["values"])[:, feature_idx]
    feature_vals = _to_numpy(X)[:, feature_idx]
    name = feature_names[feature_idx]

    plt.figure(figsize=(8, 6))
    plt.scatter(feature_vals, values, alpha=0.5)
    plt.xlabel(name)
    plt.ylabel("SHAP value")
    plt.title(f"SHAP Dependence: {name}")
    plt.tight_layout()
    plt.show()
LIME (Local Interpretable Model-agnostic Explanations)
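LIME explains individual predictions by perturbing the input, querying the model on the perturbed samples, and fitting a simple interpretable surrogate (for example, a sparse linear model) that approximates the model's behavior in the neighborhood of that one input.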
Tabular Data
from neurenix.explainable import LimeTabular
import neurenix as nx
# Column names, in the same order as the columns of X_test.
feature_names = ['age', 'income', 'education', 'hours_per_week']

# Create explainer
explainer = LimeTabular(
    model=model,
    feature_names=feature_names,
    class_names=['<=50K', '>50K'],
    # Derive the index from the name instead of a magic number.
    categorical_features=[feature_names.index('education')],  # Education is categorical
    kernel_width=0.75,
)

# Explain a single prediction. The dataset has only four features, so
# request at most that many rather than a hard-coded 10.
explanation = explainer.explain(
    sample=X_test[0],
    num_features=len(feature_names),
    num_samples=5000,
)

print("Prediction:", explanation["prediction"])
print("\nTop features:")
for feature, weight, value in zip(
    explanation["feature_names"],
    explanation["feature_weights"],
    explanation["feature_values"],
):
    print(f" {feature} = {value:.2f}: {weight:+.4f}")

# Visualize
explainer.plot_explanation(explanation)
Text Data
from neurenix.explainable import LimeText
# Placeholder for your own text classifier.
text_model = train_text_classifier()

# split_expression is a regex passed to the explainer for splitting the text.
explainer = LimeText(
    model=text_model,
    class_names=['negative', 'positive'],
    bow=True,
    split_expression=r'\W+',
)

# Explain one review.
text = "This movie was absolutely fantastic! Great acting and plot."
explanation = explainer.explain(text=text, num_features=10, num_samples=5000)

print("Prediction:", explanation["prediction"])
print("\nImportant words:")
pairs = zip(explanation["words"], explanation["word_weights"])
for token, score in pairs:
    print(f" {token}: {score:+.4f}")
Image Data
from neurenix.explainable import LimeImage
import neurenix as nx
# A pretrained ResNet-50 serves as the image classifier.
image_model = nx.vision.resnet50(pretrained=True)
explainer = LimeImage(model=image_model, class_names=imagenet_classes)

# Explain the prediction for a single image (load_image is a placeholder).
image = load_image('cat.jpg')
explanation = explainer.explain(image=image, num_features=10, num_samples=1000)
print("Predicted class:", explanation["prediction"])

# Highlight the image regions that drove the prediction.
explainer.plot_explanation(explanation)
Feature Importance
Permutation Importance
from neurenix.explainable import PermutationImportance
import neurenix as nx
import matplotlib.pyplot as plt

# Calculate permutation importance (accuracy metric, 10 repeats per feature).
importance_calculator = PermutationImportance(
    model=model,
    metric='accuracy',
    n_repeats=10,
)
importances = importance_calculator.compute(X_test, y_test)

print("Feature Importances:")
for feature, importance, std in zip(
    feature_names,
    importances['mean'],
    importances['std'],
):
    print(f"{feature}: {importance:.4f} (+/- {std:.4f})")

# Order features from most to least important. A plain Python sort works
# for NumPy arrays and framework tensors alike; tensor libraries such as
# PyTorch reject negative-step slicing like argsort()[::-1].
mean_importance = importances['mean']
sorted_idx = sorted(
    range(len(mean_importance)),
    key=lambda i: float(mean_importance[i]),
    reverse=True,
)

# Plot
plt.figure(figsize=(10, 6))
plt.bar(range(len(sorted_idx)), [float(mean_importance[i]) for i in sorted_idx])
plt.xticks(range(len(sorted_idx)), [feature_names[i] for i in sorted_idx], rotation=45)
plt.ylabel('Importance')
plt.title('Permutation Feature Importance')
plt.tight_layout()
plt.show()
Feature Importance from Gradients
from neurenix.explainable import FeatureImportance
importance = FeatureImportance(model)

# Gradient-based importance on the test set.
grad_importance = importance.compute_gradient_importance(X_test, y_test)

# Integrated gradients, using the mean training row as the baseline.
baseline = X_train.mean(dim=0)
ig_importance = importance.compute_integrated_gradients(
    X_test, baseline=baseline, steps=50
)
Partial Dependence Plots
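Partial dependence plots show how one or two features affect the model's predictions on average, marginalizing over the values of all other features.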
from neurenix.explainable import PartialDependence
import neurenix as nx
import matplotlib.pyplot as plt
# Importing Axes3D registers the '3d' projection on older Matplotlib versions.
from mpl_toolkits.mplot3d import Axes3D

pd_calculator = PartialDependence(model)

# Single feature: partial dependence on feature 0 (grid_resolution=50).
pd_result = pd_calculator.compute(X_train, feature_idx=0, grid_resolution=50)

first_name = feature_names[0]
plt.figure(figsize=(8, 6))
plt.plot(pd_result['feature_values'], pd_result['predictions'])
plt.xlabel(first_name)
plt.ylabel('Predicted value')
plt.title(f'Partial Dependence: {first_name}')
plt.grid(True)
plt.show()

# Two-feature interaction between features 0 and 1, drawn as a 3D surface.
pd_2d = pd_calculator.compute_2d(
    X_train, feature_idx1=0, feature_idx2=1, grid_resolution=30
)
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
surface = (pd_2d['feature1_grid'], pd_2d['feature2_grid'], pd_2d['predictions'])
ax.plot_surface(*surface, cmap='viridis')
plt.show()
Counterfactual Explanations
Counterfactual explanations find the smallest changes to an input that would flip the model's prediction to a desired target class.
from neurenix.explainable import Counterfactual
import neurenix as nx
cf_explainer = Counterfactual(
    model=model,
    feature_ranges=feature_ranges,
    categorical_features=[2, 5],
)

# Find counterfactuals for the first test sample
original = X_test[0]
counterfactuals = cf_explainer.generate(
    original=original,
    target_class=1,  # Desired prediction
    n_counterfactuals=5,
    max_iterations=1000,
)

# Optimized counterfactuals can differ from the original by tiny numerical
# noise; treat differences at or below this tolerance as "unchanged" so only
# meaningful edits are reported.
CHANGE_TOL = 1e-6

print("Original prediction:", model(original).argmax())
print("\nCounterfactuals:")
for i, cf in enumerate(counterfactuals):
    print(f"\nCounterfactual {i+1}:")
    print(f" Prediction: {model(cf).argmax()}")
    print(f" Distance: {cf_explainer.distance(original, cf):.4f}")
    print(" Changes:")
    for j, (orig_val, cf_val) in enumerate(zip(original, cf)):
        if abs(float(cf_val) - float(orig_val)) > CHANGE_TOL:
            print(f" {feature_names[j]}: {orig_val:.2f} -> {cf_val:.2f}")
Activation Visualization
Visualize what neural networks learn.
from neurenix.explainable import ActivationVisualization
import neurenix as nx
vis = ActivationVisualization(model)

# Activations of three named layers for the first test input.
activations = vis.get_activations(
    input=X_test[0], layer_names=['conv1', 'conv2', 'fc1']
)
for name, act in activations.items():
    print(f"{name}: {act.shape}")
    vis.plot_activation(act, name)

# Grad-CAM (for CNNs), computed at 'layer4' and overlaid on the image.
grad_cam = vis.grad_cam(
    input=image, target_class=predicted_class, target_layer='layer4'
)
vis.plot_grad_cam(image, grad_cam)

# Saliency map for the first test input.
saliency = vis.saliency_map(input=X_test[0])
vis.plot_saliency(X_test[0], saliency)
Example: Complete Explainability Pipeline
import neurenix as nx
from neurenix.explainable import (
KernelShap,
LimeTabular,
PermutationImportance,
PartialDependence
)
class ExplainablePipeline:
    """Combine several explainers to give local and global views of one model.

    KernelSHAP and LIME explain single predictions; permutation importance
    and partial dependence describe the model globally.

    Args:
        model: Callable model. ``model(sample)`` must return a single-element
            result, because ``.item()`` is called on it.
        X_train: Training data, passed to KernelSHAP and used as the data
            for partial-dependence computations.
        feature_names: Feature names indexed by column position.
    """

    def __init__(self, model, X_train, feature_names):
        self.model = model
        self.X_train = X_train
        self.feature_names = feature_names
        # Initialize explainers
        self.shap_explainer = KernelShap(model, X_train)
        self.lime_explainer = LimeTabular(model, feature_names=feature_names)
        self.perm_importance = PermutationImportance(model)
        self.pd_calculator = PartialDependence(model)

    @staticmethod
    def _top_k_indices(scores, k, key=float):
        """Return up to ``k`` indices of ``scores``, ordered by descending ``key``.

        Uses a plain Python sort so it works for lists, NumPy arrays and
        framework tensors alike. Tensor libraries such as PyTorch reject the
        negative-step slicing in ``argsort()[::-1]``.
        """
        order = sorted(range(len(scores)), key=lambda i: key(scores[i]), reverse=True)
        return order[:k]

    def explain_prediction(self, sample):
        """Comprehensive explanation for a single prediction.

        Args:
            sample: One unbatched input; a leading batch dimension is added
                (``unsqueeze(0)``) before it is passed to KernelSHAP.

        Returns:
            dict with ``'prediction'`` (scalar from ``.item()``), ``'shap'``
            (KernelSHAP output) and ``'lime'`` (LIME output).
        """
        results = {}
        # Model prediction; .item() requires a single-element output.
        prediction = self.model(sample)
        results['prediction'] = prediction.item()
        # SHAP values; the explainer is given a batch of one.
        results['shap'] = self.shap_explainer.explain(sample.unsqueeze(0))
        # LIME explanation
        results['lime'] = self.lime_explainer.explain(sample)
        return results

    def global_explanation(self, X_test, y_test, top_k=5):
        """Global model explanation.

        Args:
            X_test: Evaluation inputs for permutation importance.
            y_test: Evaluation targets for permutation importance.
            top_k: Number of most important features (by mean permutation
                importance) to compute partial dependence for.

        Returns:
            dict with ``'importance'`` (permutation-importance output) and
            ``'partial_dependence'`` (feature name -> partial-dependence
            result, most important feature first).
        """
        results = {}
        # Feature importance
        importance = self.perm_importance.compute(X_test, y_test)
        results['importance'] = importance

        # Partial dependence for the top features (named pd_result to avoid
        # shadowing the conventional pandas alias `pd`).
        pd_plots = {}
        for feat_idx in self._top_k_indices(importance['mean'], top_k):
            pd_result = self.pd_calculator.compute(
                self.X_train,
                feature_idx=feat_idx,
            )
            pd_plots[self.feature_names[feat_idx]] = pd_result
        results['partial_dependence'] = pd_plots
        return results

    def compare_samples(self, sample1, sample2, top_k=3):
        """Print predictions and top-|SHAP| features for two samples.

        Both samples are explained before anything is printed.
        """
        exp1 = self.explain_prediction(sample1)
        exp2 = self.explain_prediction(sample2)
        self._print_summary("Sample 1:", exp1, top_k)
        self._print_summary("\nSample 2:", exp2, top_k)

    def _print_summary(self, header, explanation, top_k):
        """Print one sample's prediction and its ``top_k`` largest-|SHAP| features."""
        print(header)
        print(f" Prediction: {explanation['prediction']:.4f}")
        print(f" Top {top_k} SHAP features:")
        # Row 0 of the batched SHAP output is this sample.
        shap_row = explanation['shap']['values'][0]
        for idx in self._top_k_indices(shap_row, top_k, key=lambda v: abs(float(v))):
            print(f" {self.feature_names[idx]}: {shap_row[idx]:.4f}")
# Usage: wrap a trained model (train_model is a placeholder) in the pipeline.
model = train_model(X_train, y_train)
pipeline = ExplainablePipeline(model, X_train, feature_names)

exp = pipeline.explain_prediction(X_test[0])              # local: one prediction
global_exp = pipeline.global_explanation(X_test, y_test)  # global: whole test set
pipeline.compare_samples(X_test[0], X_test[1])            # first two test rows side by side
Best Practices
- Multiple Methods: Use multiple explanation methods for robust insights
- Local vs Global: Combine local explanations (LIME, SHAP) with global understanding (feature importance, PD plots)
- Validation: Verify explanations match domain knowledge
- Audience: Tailor explanations to the audience (technical vs non-technical)
- Computational Cost: SHAP and LIME can be expensive; cache results when possible
Choosing an Explanation Method
| Method | Use Case | Speed | Accuracy |
|---|---|---|---|
| SHAP (Kernel) | Any model, global importance | Slow | High |
| SHAP (Tree) | Tree models | Fast | Exact |
| SHAP (Deep) | Neural networks | Medium | High |
| LIME | Local explanations, any model | Medium | Good |
| Permutation | Global importance | Medium | Good |
| Partial Dependence | Feature effects | Fast | Good |
| Counterfactuals | What-if analysis | Slow | High |
References
- Lundberg & Lee (2017) - “A Unified Approach to Interpreting Model Predictions”
- Ribeiro et al. (2016) - “Why Should I Trust You?”: Explaining the Predictions of Any Classifier
- Molnar (2019) - “Interpretable Machine Learning”
See Also