Skip to main content

Documentation Index

Fetch the complete documentation index at: https://mintlify.com/MilesONerd/neurenix/llms.txt

Use this file to discover all available pages before exploring further.

Explainable AI

The explainability module provides methods for interpreting machine learning models and explaining their predictions. This includes SHAP values, LIME, feature importance, and various visualization techniques to make AI systems more transparent and trustworthy.

Overview

Explainable AI techniques help answer:
  • Why did the model make this prediction?
  • Which features are most important?
  • How does the model work internally?
  • What would change the prediction?

SHAP (SHapley Additive exPlanations)

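SHAP values provide a unified measure of feature importance grounded in cooperative game theory. Each feature's SHAP value is its Shapley value: its average marginal contribution to the prediction over all subsets of the other features.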

KernelSHAP

Model-agnostic method that works with any black-box model, at the cost of many model evaluations per explained sample.
import neurenix as nx
from neurenix.explainable import KernelShap

# Any trained model works here: KernelSHAP only needs to call it.
model = train_your_model(X_train, y_train)

# The training data is passed as the background dataset that the explainer
# compares each explained sample against.
kernel_explainer = KernelShap(
    model=model,
    data=X_train,
    link="identity",
)

# Attribute the first ten test predictions.
# n_samples presumably sets how many perturbed evaluations KernelSHAP runs per
# sample (more = slower but less noisy) -- confirm against the API reference.
first_ten = X_test[:10]
kernel_result = kernel_explainer.explain(samples=first_ten, n_samples=2048)

print("SHAP values shape:", kernel_result["values"].shape)
print("Expected value:", kernel_result["expected_value"])

# Visualize, labelling each feature column.
display_names = ['feature1', 'feature2', 'feature3']
kernel_explainer.plot(kernel_result, feature_names=display_names)

TreeSHAP

Fast, exact method for tree-based models: it exploits the tree structure instead of sampling feature subsets.
from neurenix.explainable import TreeShap

# `tree_model` is assumed to be a tree-based model (decision tree, random
# forest, gradient boosting); `feature_names` lists one name per column.
tree_explainer = TreeShap(
    model=tree_model,
    data=X_train,
    feature_perturbation="interventional",
)

tree_result = tree_explainer.explain(X_test)
contributions = tree_result['values']

# Report each feature's contribution for the first explained sample (row 0).
print("Feature contributions:")
for col, name in enumerate(feature_names):
    print(f"{name}: {contributions[0, col]:.4f}")

DeepSHAP

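Optimized for deep neural networks: it approximates SHAP values by propagating attributions backward through the network, relative to a reference dataset.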
import neurenix as nx
from neurenix.explainable import DeepShap

# A small MLP: 784 inputs -> 256 hidden units (ReLU) -> 10 outputs.
model = nx.nn.Sequential(
    nx.nn.Linear(784, 256),
    nx.nn.ReLU(),
    nx.nn.Linear(256, 10),
)

# The first 100 training rows serve as the reference dataset.
reference_batch = X_train[:100]
deep_explainer = DeepShap(model=model, data=reference_batch)

# Explain five test samples.
deep_result = deep_explainer.explain(X_test[:5])

values = deep_result["values"]            # SHAP value per feature
expected = deep_result["expected_value"]  # base value

Advanced SHAP Usage

import matplotlib.pyplot as plt
import neurenix as nx

# Summary plot
def plot_shap_summary(shap_values, X, feature_names):
    """Plot global feature importance as the mean absolute SHAP value per feature.

    Args:
        shap_values: Explainer result dict. ``shap_values["values"]`` must
            provide ``.numpy()``; the result is assumed to be shaped
            (samples, features), since importance is averaged over axis 0.
        X: Unused. Kept so the signature mirrors ``plot_shap_dependence``.
        feature_names: One display name per feature column.
    """
    # This snippet's imports only bring in plt and nx, so numpy is imported here.
    import numpy as np

    values = shap_values["values"].numpy()

    # Average magnitude of each feature's contribution across all samples.
    mean_shap = np.abs(values).mean(axis=0)

    # Feature indices ordered from most to least important.
    indices = np.argsort(mean_shap)[::-1]

    plt.figure(figsize=(10, 6))
    plt.barh(range(len(indices)), mean_shap[indices])
    plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
    # barh puts position 0 at the bottom; invert so the most important
    # feature is drawn at the top of the chart.
    plt.gca().invert_yaxis()
    plt.xlabel("Mean |SHAP value|")
    plt.title("Feature Importance")
    plt.tight_layout()
    plt.show()

# Dependence plot
def plot_shap_dependence(shap_values, X, feature_idx, feature_names):
    """Scatter one feature's raw values against its SHAP values.

    Args:
        shap_values: Explainer result dict whose ``"values"`` entry provides
            ``.numpy()`` and is indexed as (samples, features).
        X: The explained samples; column ``feature_idx`` is plotted on the x-axis.
        feature_idx: Column index of the feature to plot.
        feature_names: Display names, used for the axis label and title.
    """
    shap_column = shap_values["values"].numpy()[:, feature_idx]
    raw_column = X[:, feature_idx].numpy()

    plt.figure(figsize=(8, 6))
    plt.scatter(raw_column, shap_column, alpha=0.5)
    label = feature_names[feature_idx]
    plt.xlabel(label)
    plt.ylabel("SHAP value")
    plt.title(f"SHAP Dependence: {label}")
    plt.tight_layout()
    plt.show()

LIME (Local Interpretable Model-agnostic Explanations)

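LIME explains an individual prediction by fitting a simple, interpretable surrogate model (for example, a sparse linear model) to the black-box model's outputs on perturbed samples around the instance being explained.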

Tabular Data

import neurenix as nx
from neurenix.explainable import LimeTabular

tabular_explainer = LimeTabular(
    model=model,
    feature_names=['age', 'income', 'education', 'hours_per_week'],
    class_names=['<=50K', '>50K'],
    categorical_features=[2],  # column 2 ("education") is categorical
    kernel_width=0.75,
)

# Explain the first test row, reporting up to 10 features.
# num_samples presumably sets how many perturbed rows LIME draws -- confirm.
explanation = tabular_explainer.explain(
    sample=X_test[0],
    num_features=10,
    num_samples=5000,
)

print("Prediction:", explanation["prediction"])
print("\nTop features:")
ranked = zip(
    explanation["feature_names"],
    explanation["feature_weights"],
    explanation["feature_values"],
)
for name, weight, value in ranked:
    print(f"  {name} = {value:.2f}: {weight:+.4f}")

# Visualize
tabular_explainer.plot_explanation(explanation)

Text Data

from neurenix.explainable import LimeText

# Any trained text classifier.
text_model = train_text_classifier()

text_explainer = LimeText(
    model=text_model,
    class_names=['negative', 'positive'],
    bow=True,
    # Presumably the tokenizer pattern: split on runs of non-word characters.
    split_expression=r'\W+',
)

review = "This movie was absolutely fantastic! Great acting and plot."
explanation = text_explainer.explain(text=review, num_features=10, num_samples=5000)

print("Prediction:", explanation["prediction"])
print("\nImportant words:")
word_pairs = zip(explanation["words"], explanation["word_weights"])
for token, weight in word_pairs:
    print(f"  {token}: {weight:+.4f}")

Image Data

import neurenix as nx
from neurenix.explainable import LimeImage

# Pretrained ResNet-50; `imagenet_classes` is assumed to hold the matching
# class labels.
image_model = nx.vision.resnet50(pretrained=True)
image_explainer = LimeImage(model=image_model, class_names=imagenet_classes)

image = load_image('cat.jpg')
explanation = image_explainer.explain(image=image, num_features=10, num_samples=1000)

print("Predicted class:", explanation["prediction"])

# Highlight the image regions that most influenced the prediction.
image_explainer.plot_explanation(explanation)

Feature Importance

Permutation Importance

import matplotlib.pyplot as plt
import neurenix as nx
from neurenix.explainable import PermutationImportance

# Repeat each column shuffle 10 times and score with accuracy.
perm = PermutationImportance(model=model, metric='accuracy', n_repeats=10)
importances = perm.compute(X_test, y_test)
means, stds = importances['mean'], importances['std']

print("Feature Importances:")
for name, score, spread in zip(feature_names, means, stds):
    print(f"{name}: {score:.4f} (+/- {spread:.4f})")

# Bar chart, most important feature first.
order = means.argsort()[::-1]
ticks = range(len(order))

plt.figure(figsize=(10, 6))
plt.bar(ticks, means[order])
plt.xticks(ticks, [feature_names[i] for i in order], rotation=45)
plt.ylabel('Importance')
plt.title('Permutation Feature Importance')
plt.tight_layout()
plt.show()

Feature Importance from Gradients

from neurenix.explainable import FeatureImportance

fi = FeatureImportance(model)

# Gradient-based importance on the test set.
grad_importance = fi.compute_gradient_importance(X_test, y_test)

# Integrated gradients, using the per-feature mean of the training data as
# the baseline. `steps` is presumably the number of interpolation steps
# between baseline and input -- confirm against the API reference.
mean_baseline = X_train.mean(dim=0)
ig_importance = fi.compute_integrated_gradients(
    X_test,
    baseline=mean_baseline,
    steps=50,
)

Partial Dependence Plots

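Partial dependence plots show how one or two features affect the model's predictions on average, with the effect of all other features averaged out over the dataset.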
import matplotlib.pyplot as plt
# Unused name; the import only registers the '3d' projection on
# Matplotlib < 3.2 (newer versions register it automatically).
from mpl_toolkits.mplot3d import Axes3D
import neurenix as nx
from neurenix.explainable import PartialDependence

pd_calculator = PartialDependence(model)

# One-way partial dependence of the prediction on feature 0 (50 grid points).
pd_result = pd_calculator.compute(X_train, feature_idx=0, grid_resolution=50)

plt.figure(figsize=(8, 6))
plt.plot(pd_result['feature_values'], pd_result['predictions'])
plt.xlabel(feature_names[0])
plt.ylabel('Predicted value')
plt.title(f'Partial Dependence: {feature_names[0]}')
plt.grid(True)
plt.show()

# Two-way partial dependence: interaction of features 0 and 1 on a 30x30 grid.
pd_2d = pd_calculator.compute_2d(
    X_train,
    feature_idx1=0,
    feature_idx2=1,
    grid_resolution=30,
)

# NOTE(review): plot_surface needs 2-D coordinate arrays; this assumes
# compute_2d returns meshgrid-shaped grids -- confirm, else wrap in np.meshgrid.
fig = plt.figure(figsize=(10, 8))
surface_ax = fig.add_subplot(111, projection='3d')
surface_ax.plot_surface(
    pd_2d['feature1_grid'],
    pd_2d['feature2_grid'],
    pd_2d['predictions'],
    cmap='viridis',
)
plt.show()

Counterfactual Explanations

Counterfactual explanations find the smallest changes to an input that would change the model's prediction to a desired target class.
import math

import neurenix as nx
from neurenix.explainable import Counterfactual

# feature_ranges and feature_names are assumed to be defined for your dataset;
# columns 2 and 5 are treated as categorical.
cf_explainer = Counterfactual(
    model=model,
    feature_ranges=feature_ranges,
    categorical_features=[2, 5],
)

# Find counterfactuals for the first test row.
original = X_test[0]
counterfactuals = cf_explainer.generate(
    original=original,
    target_class=1,  # Desired prediction
    n_counterfactuals=5,
    max_iterations=1000,
)

# Compare features with a tolerance rather than exact `!=`. Float round-off
# from the search would otherwise be listed as a "change", and would print
# as an identical "x.xx -> x.xx" line at two decimals.
CHANGE_TOL = 1e-6

print("Original prediction:", model(original).argmax())
print("\nCounterfactuals:")
for i, cf in enumerate(counterfactuals, start=1):
    print(f"\nCounterfactual {i}:")
    print(f"  Prediction: {model(cf).argmax()}")
    print(f"  Distance: {cf_explainer.distance(original, cf):.4f}")
    print("  Changes:")

    for j, (orig_val, cf_val) in enumerate(zip(original, cf)):
        before, after = float(orig_val), float(cf_val)
        if not math.isclose(before, after, rel_tol=CHANGE_TOL, abs_tol=CHANGE_TOL):
            print(f"    {feature_names[j]}: {before:.2f} -> {after:.2f}")

Activation Visualization

Visualize what neural networks learn.
import neurenix as nx
from neurenix.explainable import ActivationVisualization

vis = ActivationVisualization(model)
first_sample = X_test[0]

# Activations of the named layers; the names are presumably layer names
# inside `model` -- adjust them to your architecture.
activations = vis.get_activations(
    input=first_sample,
    layer_names=['conv1', 'conv2', 'fc1'],
)
for name, act in activations.items():
    print(f"{name}: {act.shape}")
    vis.plot_activation(act, name)

# Grad-CAM (for CNNs). `image` and `predicted_class` are assumed to be
# defined already (e.g. an input image and its predicted label).
grad_cam = vis.grad_cam(
    input=image,
    target_class=predicted_class,
    target_layer='layer4',
)
vis.plot_grad_cam(image, grad_cam)  # overlay on the original image

# Saliency map for the same sample whose activations were shown above.
saliency = vis.saliency_map(input=first_sample)
vis.plot_saliency(first_sample, saliency)

Example: Complete Explainability Pipeline

import neurenix as nx
from neurenix.explainable import (
    KernelShap,
    LimeTabular,
    PermutationImportance,
    PartialDependence
)

class ExplainablePipeline:
    """Bundle local and global explainers for a single model.

    Local explanations use KernelSHAP and LIME. Global explanations use
    permutation importance plus partial dependence for the top features.

    Args:
        model: Callable model. ``model(sample)`` must return a single-element
            output, because ``.item()`` is called on it.
        X_train: Training data, passed to KernelSHAP and used for partial
            dependence.
        feature_names: Display names, one per feature column.
    """

    def __init__(self, model, X_train, feature_names):
        self.model = model
        self.X_train = X_train
        self.feature_names = feature_names

        # Initialize explainers
        self.shap_explainer = KernelShap(model, X_train)
        self.lime_explainer = LimeTabular(
            model,
            feature_names=feature_names
        )
        self.perm_importance = PermutationImportance(model)
        self.pd_calculator = PartialDependence(model)

    @staticmethod
    def _top_k_indices(scores, k, magnitude=False):
        """Return indices of the ``k`` largest scores, largest first.

        Works on any indexable sequence of numeric scalars (list, NumPy array,
        tensor). It avoids ``.abs()`` and negative-step slicing (``[::-1]``),
        which not every array type supports. With ``magnitude=True``, scores
        are ranked by absolute value.
        """
        def rank_key(i):
            value = float(scores[i])
            return abs(value) if magnitude else value

        return sorted(range(len(scores)), key=rank_key, reverse=True)[:k]

    def _print_sample_summary(self, title, explanation):
        """Print one explanation's prediction and its 3 largest-|SHAP| features."""
        print(title)
        print(f"  Prediction: {explanation['prediction']:.4f}")
        print("  Top 3 SHAP features:")
        shap_row = explanation['shap']['values'][0]
        for idx in self._top_k_indices(shap_row, 3, magnitude=True):
            print(f"    {self.feature_names[idx]}: {shap_row[idx]:.4f}")

    def explain_prediction(self, sample):
        """Explain a single prediction.

        Returns:
            dict with keys ``'prediction'`` (a Python scalar), ``'shap'`` (the
            KernelSHAP result for a batch of one) and ``'lime'`` (the LIME
            explanation).
        """
        results = {}

        # NOTE(review): the model gets the unbatched sample, while SHAP gets
        # sample.unsqueeze(0). Confirm the model accepts both shapes.
        prediction = self.model(sample)
        results['prediction'] = prediction.item()

        results['shap'] = self.shap_explainer.explain(sample.unsqueeze(0))
        results['lime'] = self.lime_explainer.explain(sample)

        return results

    def global_explanation(self, X_test, y_test):
        """Explain the model globally.

        Returns:
            dict with ``'importance'`` (permutation importance result) and
            ``'partial_dependence'``. The latter maps the names of the top 5
            features by mean importance, most important first, to their
            partial dependence results.
        """
        results = {}

        importance = self.perm_importance.compute(X_test, y_test)
        results['importance'] = importance

        top_features = self._top_k_indices(importance['mean'], 5)
        results['partial_dependence'] = {
            self.feature_names[feat_idx]: self.pd_calculator.compute(
                self.X_train,
                feature_idx=feat_idx
            )
            for feat_idx in top_features
        }

        return results

    def compare_samples(self, sample1, sample2):
        """Print the two samples' explanations side by side.

        Returns:
            tuple ``(exp1, exp2)`` of the ``explain_prediction`` results, so
            callers can reuse them without recomputing.
        """
        exp1 = self.explain_prediction(sample1)
        exp2 = self.explain_prediction(sample2)

        self._print_sample_summary("Sample 1:", exp1)
        self._print_sample_summary("\nSample 2:", exp2)

        return exp1, exp2

# Usage: build the pipeline once, then request local, global and
# side-by-side explanations from it.
model = train_model(X_train, y_train)
pipeline = ExplainablePipeline(model, X_train, feature_names)

first_row = X_test[0]

exp = pipeline.explain_prediction(first_row)              # one prediction
global_exp = pipeline.global_explanation(X_test, y_test)  # whole test set
pipeline.compare_samples(first_row, X_test[1])            # two rows compared

Best Practices

  1. Multiple Methods: Use multiple explanation methods for robust insights
  2. Local vs Global: Combine local explanations (LIME, SHAP) with global understanding (feature importance, PD plots)
  3. Validation: Verify explanations match domain knowledge
  4. Audience: Tailor explanations to the audience (technical vs non-technical)
  5. Computational Cost: SHAP and LIME can be expensive; cache results when possible

Choosing an Explanation Method

| Method | Use Case | Speed | Accuracy |
|---|---|---|---|
| SHAP (Kernel) | Any model, global importance | Slow | High |
| SHAP (Tree) | Tree models | Fast | Exact |
| SHAP (Deep) | Neural networks | Medium | High |
| LIME | Local explanations, any model | Medium | Good |
| Permutation | Global importance | Medium | Good |
| Partial Dependence | Feature effects | Fast | Good |
| Counterfactuals | What-if analysis | Slow | High |

References

  • Lundberg & Lee (2017) - “A Unified Approach to Interpreting Model Predictions”
  • Ribeiro et al. (2016) - “‘Why Should I Trust You?’: Explaining the Predictions of Any Classifier”
  • Molnar (2019) - “Interpretable Machine Learning”

See Also

Build docs developers (and LLMs) love