Overview

Autoregressive inference with linear RNNs in PyTorch runs one step at a time, so per-step CPU overhead needs special handling. Following the approach used in Mamba, lrnnx implements CUDA graph-based inference, which reduces this overhead and provides a >10x speedup over a plain Python for loop.
The generation API is located in lrnnx/utils/generation.py and works with all models in the library.
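
For orientation, a linear RNN layer can be written in the generic form below. This is the textbook formulation, not the exact parameterization of any lrnnx model, and the symbols $A$, $B$, $C$ and $h_t$ are notation introduced here. The last line assumes that autoregressive generation feeds each output back as the next input, which would explain why the input and output share the dimension H.

```latex
\begin{aligned}
h_t     &= A\,h_{t-1} + B\,x_t, \\
y_t     &= C\,h_t, \\
x_{t+1} &= y_t .
\end{aligned}
```

Each step depends on the past only through the fixed-size state $h_{t-1}$, so a single step is a small amount of GPU work and the launch overhead around it tends to dominate.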

Quick Start

1. Import the generation API

import torch
from lrnnx.utils.generation import capture_graph, generate
from lrnnx.models.lti import LRU

2. Create and prepare model

# Initialize model in eval mode on CUDA
model = LRU(d_model=64, d_state=64).cuda().eval()

# Model configuration
batch_size = 4
H = 64  # d_model

3. Capture CUDA graph (one-time setup)

# Capture graph once - reuse for all generations
cache = capture_graph(model, batch_size=batch_size, H=H)
Note: graph capture is a one-time operation, but the captured graph is tied to the batch size it was captured with.

4. Generate sequences

# Create seed input (B, H)
x0 = torch.randn(batch_size, H, device="cuda")

# Generate 512 steps with CUDA graph replay
output = generate(
    model,
    x0,
    num_steps=512,
    graph_cache=cache  # Use the captured graph (>10x speedup)
)

# Output shape: (batch_size, num_steps, H)
print(output.shape)  # torch.Size([4, 512, 64])

CUDA Graph Optimization

Why CUDA Graphs?

Standard autoregressive generation in PyTorch pays CPU overhead at every step, because each step launches its kernels one by one from Python. CUDA graphs remove most of this overhead by:
  1. Recording the kernels of a single step once, during capture_graph()
  2. Replaying the recorded graph at each timestep, which costs one graph launch instead of one launch per kernel
This is the source of the >10x inference speedup.
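
The size of the speedup depends on the model, batch size and GPU, so it is worth measuring on your own hardware. The sketch below is a small timing helper and not part of lrnnx: `mean_ms` and its `sync` parameter are hypothetical names introduced here. Because GPU work is asynchronous, timing CUDA code would require passing `torch.cuda.synchronize` as `sync` so that the clock stops only after the kernels have finished.

```python
import time


def mean_ms(fn, repeats=5, warmup=1, sync=None):
    """Return the mean wall-clock time of fn() in milliseconds.

    sync is an optional zero-argument callable invoked before each clock
    reading; for CUDA code this would be torch.cuda.synchronize.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):  # untimed runs to exclude one-time setup costs
        fn()
    total = 0.0
    for _ in range(repeats):
        if sync is not None:
            sync()  # make sure earlier work is not billed to this run
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the work launched by fn() to finish
        total += time.perf_counter() - start
    return 1000.0 * total / repeats
```

With a model, seed and cache already set up, timing `lambda: generate(model, x0, num_steps=512)` and the same call with `graph_cache=cache` gives the two numbers to compare.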

How It Works

import torch
from lrnnx.utils.generation import capture_graph, generate
from lrnnx.models.lti import LRU

model = LRU(d_model=64, d_state=64).cuda().eval()

# One-time graph capture
cache = capture_graph(model, batch_size=4, H=64)

# Seed input (B, H)
x0 = torch.randn(4, 64, device="cuda")

# Fast generation with graph replay
output = generate(model, x0, num_steps=512, graph_cache=cache)

Complete Example

import torch
from lrnnx.models.lti import LRU
from lrnnx.utils.generation import capture_graph, generate

# Configuration
batch_size = 4
d_model = 64
d_state = 64
num_steps = 512

# Initialize model
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.eval()  # Important: set to eval mode

# Capture CUDA graph (do this once)
print("Capturing CUDA graph...")
cache = capture_graph(model, batch_size=batch_size, H=d_model)

# Create seed input
x0 = torch.randn(batch_size, d_model, device="cuda")

# Generate with CUDA graph
print("Generating...")
with torch.inference_mode():
    output = generate(model, x0, num_steps=num_steps, graph_cache=cache)

print(f"Generated output shape: {output.shape}")
# Output: Generated output shape: torch.Size([4, 512, 64])

Event-Based Inference

Some models support event-driven timesteps during generation:
import torch
from lrnnx.models.lti import S5
from lrnnx.utils.generation import capture_graph, generate

model = S5(d_model=64, d_state=64).cuda().eval()

# Capture with event mode enabled
cache = capture_graph(
    model,
    batch_size=4,
    H=64,
    event_mode=True  # Enable event-driven timesteps
)

x0 = torch.randn(4, 64, device="cuda")

# Constant integration timestep, shape (batch, 1); reused at every step
integration_timesteps = torch.full((4, 1), 0.1, device="cuda")

output = generate(
    model,
    x0,
    num_steps=512,
    graph_cache=cache,
    integration_timesteps=integration_timesteps
)
When using integration_timesteps, you must capture the graph with event_mode=True.

Benchmarking Inference

The library includes built-in benchmarks to measure inference performance:
from benchmarks.benchmark_inference import benchmark_sequence_length
from lrnnx.models.lti import LRU

def model_fn():
    return LRU(d_model=128, d_state=64).cuda().eval()

# Benchmark CUDA-graph inference across sequence lengths
results = benchmark_sequence_length(
    model_fn,
    seq_lengths=[64, 128, 256, 512, 1024, 2048],
    batch_size=32,
    repeats=5
)

for seq_len, times in results.items():
    avg_time = sum(times) / len(times)
    print(f"Seq len {seq_len}: {avg_time:.2f} ms")
See benchmarks/benchmark_inference.py for the complete set of benchmarking utilities, including:
  • benchmark_sequence_length() - Vary generation length
  • benchmark_model_dimension() - Vary model size
  • benchmark_batch_size() - Vary batch size
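
Total time grows with generation length, so per-step latency is often the easier number to compare across runs. The following sketch post-processes a results dictionary shaped like the one in the example above (sequence length mapped to a list of timings in milliseconds). `per_step_latency` is a hypothetical helper, not part of the benchmark module.

```python
from statistics import mean


def per_step_latency(results):
    """Map each sequence length to (mean total ms, mean ms per step)."""
    summary = {}
    for seq_len, times in sorted(results.items()):
        if not times:
            continue  # skip lengths with no recorded runs
        avg = mean(times)
        summary[seq_len] = (avg, avg / seq_len)
    return summary


# Assumed usage, with `results` returned by benchmark_sequence_length():
# for seq_len, (total_ms, step_ms) in per_step_latency(results).items():
#     print(f"Seq len {seq_len}: {total_ms:.2f} ms, {step_ms:.4f} ms/step")
```

If the per-step figure stays roughly flat as the sequence length grows, generation cost is scaling linearly with length.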

API Reference

capture_graph()

Captures a CUDA graph for the model’s single-step recurrence.
Parameters:
  • model (LTI_LRNN | LTV_LRNN) - Model on CUDA in eval mode
  • batch_size (int) - Batch size to capture for
  • H (int) - Model input/output dimension (d_model)
  • max_seqlen (int, optional) - Maximum sequence length, default: 1
  • event_mode (bool, optional) - Enable event-driven timesteps, default: False
  • device (torch.device, optional) - CUDA device, inferred from model if None
  • n_warmups (int, optional) - Warmup iterations before capture, default: 3
Returns:
  • CUDAGraphStepCache - Opaque cache object to pass to generate()
Example:
cache = capture_graph(model, batch_size=4, H=64)

generate()

Autoregressive generation with optional CUDA graph acceleration.
Parameters:
  • model (LTI_LRNN | LTV_LRNN) - Model on CUDA in eval mode
  • x (torch.Tensor) - Seed input, shape (batch, H)
  • num_steps (int) - Number of autoregressive steps
  • graph_cache (CUDAGraphStepCache, optional) - Pre-captured graph from capture_graph(), default: None
  • integration_timesteps (torch.Tensor, optional) - Integration timestep, shape (batch, 1), for event-driven models; requires a graph captured with event_mode=True, default: None
Returns:
  • torch.Tensor - Generated sequence, shape (batch, num_steps, H)
Example:
output = generate(model, x0, num_steps=512, graph_cache=cache)

Performance Tips

1. Always use CUDA graphs

Capture once, reuse for all generations with the same batch size:
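cache = capture_graph(model, batch_size=4, H=64)

# Two independent seeds with the captured shape (4, 64)
x1 = torch.randn(4, 64, device="cuda")
x2 = torch.randn(4, 64, device="cuda")

# Reuse cache for multiple generations
out1 = generate(model, x1, num_steps=100, graph_cache=cache)
out2 = generate(model, x2, num_steps=200, graph_cache=cache)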

2. Recapture for different batch sizes

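A CUDA graph is captured for fixed tensor shapes, so create a separate cache for each batch size:
cache_b4 = capture_graph(model, batch_size=4, H=64)
cache_b8 = capture_graph(model, batch_size=8, H=64)

# Each seed must match the batch size of the cache it is used with
x4 = torch.randn(4, 64, device="cuda")
x8 = torch.randn(8, 64, device="cuda")

out_b4 = generate(model, x4, num_steps=100, graph_cache=cache_b4)
out_b8 = generate(model, x8, num_steps=100, graph_cache=cache_b8)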
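
When batch sizes vary at runtime, one option is to capture lazily and keep one cache per batch size. The sketch below is an illustration under assumptions, not an lrnnx API: `GraphCachePool` is a hypothetical name, and the capture function is injected so the class itself does not depend on the library.

```python
class GraphCachePool:
    """Lazily capture and reuse one graph cache per batch size."""

    def __init__(self, capture_fn):
        # capture_fn(batch_size) must return a cache for that batch size,
        # e.g. lambda b: capture_graph(model, batch_size=b, H=64)
        self._capture_fn = capture_fn
        self._caches = {}

    def get(self, batch_size):
        """Return the cache for batch_size, capturing on first request."""
        if batch_size not in self._caches:
            self._caches[batch_size] = self._capture_fn(batch_size)
        return self._caches[batch_size]

    def __len__(self):
        # Number of distinct batch sizes captured so far
        return len(self._caches)
```

A call such as `generate(model, x8, num_steps=100, graph_cache=pool.get(x8.shape[0]))` would then capture on the first request for each batch size and reuse the cache afterwards. Each captured graph keeps its own GPU buffers alive, so it may be worth bounding the number of distinct batch sizes.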

3. Use inference_mode for best performance

Wrapping generation in torch.inference_mode() disables autograd tracking:
with torch.inference_mode():
    output = generate(model, x0, num_steps=512, graph_cache=cache)

Troubleshooting

Batch Size Mismatch

If you get an error about batch size mismatch:
ValueError: Batch size 8 != captured 4. Re-capture with capture_graph(model, batch_size=8).
Solution: Recapture the graph with the correct batch size:
cache = capture_graph(model, batch_size=8, H=64)
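
When re-capturing is undesirable, for example for the last partial batch of a dataset, an alternative is to split the work into chunks of exactly the captured batch size and pad the final chunk. The sketch below computes only the index plan. `plan_chunks` is a hypothetical helper, and padding with filler rows assumes that batch rows do not interact inside the model, which should be confirmed for the model in use.

```python
def plan_chunks(total, captured_batch):
    """Split `total` rows into chunks of exactly `captured_batch` rows.

    Returns a list of (start, stop, pad) tuples: rows[start:stop] are the
    real rows of the chunk and `pad` is the number of filler rows needed
    to reach the captured batch size.
    """
    if total < 0 or captured_batch < 1:
        raise ValueError("total must be >= 0 and captured_batch >= 1")
    plan = []
    for start in range(0, total, captured_batch):
        stop = min(start + captured_batch, total)
        plan.append((start, stop, captured_batch - (stop - start)))
    return plan
```

For each tuple, the seed rows `x[start:stop]` would be extended with `pad` zero rows before calling generate(), and the last `pad` rows of the output discarded. In event mode the integration timesteps would need the same padding.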

Memory Issues

If graph capture fails with an out-of-memory error, release unused cached GPU memory before capturing:
import gc
import torch

# Free memory before capture
gc.collect()
torch.cuda.empty_cache()

cache = capture_graph(model, batch_size=4, H=64)

Next Steps

Training Guide

Learn how to train lrnnx models

Custom Kernels

Understand the high-performance CUDA kernels
