Overview

FFT-based (Fast Fourier Transform) convolution operations for LRNN models. These functions implement the convolution-mode forward pass of state space models and are organized to keep its computational cost low. Reference: arXiv:2409.03377

Functions

opt_ssm_forward

opt_ssm_forward(x, K, B_, C) -> Tensor
Optimized FFT convolution with automatic strategy selection. The function chooses among three computation strategies based on tensor dimensions to minimize computational cost. In the conditions below, B is the batch size, N is the state dimension, and H_in and H_out are the input and output feature dimensions.
Strategy Selection:
  1. Strategy 1: When (1/H_in + 1/H_out) > (1/B + 1/N) and H_in * H_out <= N
    • Precompute full kernel: kernel = einsum("on,nl,ni->loi", C, K, B_)
    • Apply convolution: fft_conv("bli,loi->blo", x, kernel)
  2. Strategy 2: When (1/H_in + 1/H_out) <= (1/B + 1/N) and N <= H_in
    • Project input: x_proj = einsum("blh,nh->bln", x, B_)
    • Convolve projected input: fft_conv("bln,ln->bln", x_proj, K.T)
    • Apply output projection: einsum("bln,hn->blh", x_conv, C)
  3. Fallback: When neither of the two conditions above holds
    • Direct computation: fft_conv("blh,nl,nh,on->blo", x, K, B_, C)
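The selection rule above can be sketched as a small pure-Python function. This is a minimal illustration that mirrors the documented conditions only; `choose_strategy` and its argument names are hypothetical and are not part of the library.

```python
def choose_strategy(batch: int, h_in: int, h_out: int, n_state: int) -> str:
    """Return the branch that the documented selection rule would pick.

    Hypothetical helper: it restates the listed conditions and is not
    taken from the library source.
    """
    lhs = 1.0 / h_in + 1.0 / h_out
    rhs = 1.0 / batch + 1.0 / n_state
    if lhs > rhs and h_in * h_out <= n_state:
        return "strategy_1"  # precompute the full (L, H_out, H_in) kernel
    if lhs <= rhs and n_state <= h_in:
        return "strategy_2"  # project to state space, convolve, project out
    return "fallback"        # single fused contraction


# Narrow features, larger state: the full kernel is cheap to precompute.
print(choose_strategy(batch=32, h_in=4, h_out=4, n_state=64))
# Wide features, small state: convolving in state space is cheaper.
print(choose_strategy(batch=8, h_in=256, h_out=256, n_state=16))
# Neither condition holds for these sizes.
print(choose_strategy(batch=32, h_in=64, h_out=64, n_state=128))
```

The three calls land in strategy 1, strategy 2, and the fallback respectively, which can be checked by hand against the inequalities.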
x
torch.Tensor
required
Input tensor, shape (B, L, H) where:
  • B is the batch size
  • L is the sequence length
  • H is the input dimension
K
torch.Tensor
required
Kernel tensor, shape (L, H, H) or (L, N) depending on the model configuration.
B_
torch.Tensor
required
Normalized input projection matrix, shape (N, H) where N is the state dimension.
C
torch.Tensor
required
Output projection matrix, shape (H, N).
output
torch.Tensor
Output tensor, shape (B, L, H), representing the convolved sequence.
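Strategies 1 and 2 differ only in the order of contraction, so they should produce the same output. The sketch below checks this with NumPy and a direct (non-FFT) causal convolution, chosen to keep the arithmetic easy to follow. It assumes K is laid out as (N, L), matching the "nl" index in the einsum strings above; `causal_conv` is a hypothetical helper, not a library function.

```python
import numpy as np

rng = np.random.default_rng(0)
B, L, H, N = 2, 16, 3, 5            # small sizes so the direct loops stay cheap
x = rng.standard_normal((B, L, H))
K = rng.standard_normal((N, L))     # assumed layout: indexed "nl"
B_ = rng.standard_normal((N, H))
C = rng.standard_normal((H, N))


def causal_conv(x, kernel):
    """Direct causal convolution: y[b,l,o] = sum over s<=l, i of kernel[s,o,i] * x[b,l-s,i]."""
    batch, length, _ = x.shape
    y = np.zeros((batch, length, kernel.shape[1]))
    for l in range(length):
        for s in range(l + 1):
            y[:, l, :] += x[:, l - s, :] @ kernel[s].T
    return y


# Strategy 1: build the full (L, H_out, H_in) kernel, then convolve once.
kernel = np.einsum("on,nl,ni->loi", C, K, B_)
y1 = causal_conv(x, kernel)

# Strategy 2: project into state space, convolve each state channel, project out.
x_proj = np.einsum("blh,nh->bln", x, B_)
x_conv = np.zeros_like(x_proj)
for l in range(L):
    for s in range(l + 1):
        x_conv[:, l, :] += K[:, s] * x_proj[:, l - s, :]
y2 = np.einsum("bln,hn->blh", x_conv, C)

print(y1.shape, np.allclose(y1, y2))
```

Strategy 1 materializes an (L, H_out, H_in) kernel, while strategy 2 convolves only N channels. That trade-off is why the choice depends on the relative sizes of H_in, H_out, N, and B.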

fft_conv

fft_conv(equation, input, *args) -> Tensor
FFT-based convolution operation with flexible einsum equations. This is a lower-level function used by opt_ssm_forward and supports multiple argument patterns.
equation
str
required
Einsum equation string specifying the contraction pattern (e.g., "bli,loi->blo").
input
torch.Tensor
required
Input tensor, shape (B, L, H) or (B, L, N).
*args
torch.Tensor
required
Variable arguments depending on the convolution pattern:
  • Single argument: Kernel tensor of shape (L, H, H)
  • Multiple arguments: Separate K, B_norm, and C tensors
output
torch.Tensor
Convolved output tensor, shape (B, L, H) or (B, L, N) depending on the input configuration.
Implementation Details:
  • Zero-pads to length 2*L before the FFT so that the circular convolution it computes does not wrap around
  • Converts tensors to complex float (cfloat) for FFT operations
  • Returns real part after inverse FFT, truncated to original sequence length L
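The padding step listed above can be illustrated on a single channel. The sketch uses NumPy's FFT as a stand-in for the PyTorch calls and casts to complex before transforming, mirroring the conversion described above (double precision here, for a tight comparison). `fft_causal_conv` is a hypothetical name and the sketch is not the library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
L = 32
u = rng.standard_normal(L)   # one input channel
k = rng.standard_normal(L)   # one kernel channel


def fft_causal_conv(u, k):
    """Causal convolution via FFT, zero-padded to 2*L to avoid wrap-around."""
    length = u.shape[-1]
    n_fft = 2 * length
    u_f = np.fft.fft(u.astype(np.complex128), n=n_fft)
    k_f = np.fft.fft(k.astype(np.complex128), n=n_fft)
    y = np.fft.ifft(u_f * k_f, n=n_fft)
    return y.real[..., :length]          # real part, truncated to L


y_fft = fft_causal_conv(u, k)
y_direct = np.convolve(u, k)[:L]         # reference: first L terms of the full convolution

# Without padding, the FFT product is a circular convolution of length L,
# so late inputs wrap around and contaminate early outputs.
y_circular = np.fft.ifft(np.fft.fft(u) * np.fft.fft(k)).real

print(np.allclose(y_fft, y_direct))
print(np.allclose(y_circular, y_direct))
```

The padded version matches the direct convolution, while the unpadded one does not, which is the artifact the 2*L padding is there to prevent.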

FFTConvS4 Module

Class

FFTConvS4(d_model, l_max=None, channels=1, swap_channels=False, 
          transposed=True, dropout=0.0, tie_dropout=False, 
          drop_kernel=0.0, kernel_type=None, param_config=None, 
          kernel=None, **kernel_args)
PyTorch module implementing FFT convolution around a learnable convolution kernel. This is the main building block for S4-style models.
d_model
int
required
Model dimension (in CNN terminology, the number of “channels”).
l_max
int
Maximum kernel length. Use None for a global kernel that adapts to input length.
channels
int
default:"1"
Number of “heads”; the SSM maps 1-dimensional input to C-dimensional output.
swap_channels
bool
default:"False"
Whether to swap channel ordering in the computation.
transposed
bool
default:"True"
Backbone axis ordering. If True, expects input shape (B, D, L). If False, expects (B, L, D).
dropout
float
default:"0.0"
Dropout probability applied to the output.
tie_dropout
bool
default:"False"
If True, ties the dropout mask across the sequence length dimension.
drop_kernel
float
default:"0.0"
Kernel dropout probability, applied to the convolution kernel.
kernel_type
str
Kernel algorithm specification:
  • "s4" - DPLR (Diagonal Plus Low-Rank) parameterization
  • "s4d" - Diagonal parameterization
Required when param_config is provided.
param_config
dict
Dictionary containing references to SSM parameters (A, B, C, dt, P, etc.). Used with kernel_type to configure the kernel.
kernel
str
Alternative kernel specification. Either this or param_config must be provided.
**kernel_args
dict
Additional keyword arguments forwarded to the kernel class constructor.
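To illustrate what tie_dropout means, the sketch below builds a dropout mask for a (B, D, L) tensor in NumPy. With tying, one keep/drop decision is drawn per (batch, channel) pair and broadcast along L. The helper `dropout_mask` is hypothetical, and the 1/(1-p) rescaling is the usual inverted-dropout convention, assumed here rather than taken from the library.

```python
import numpy as np


def dropout_mask(shape, p, tie, rng):
    """Hypothetical helper: inverted-dropout mask for a (B, D, L) tensor.

    With tie=True one decision is drawn per (batch, channel) and broadcast
    along L; otherwise every element gets its own decision.
    """
    batch, d_model, length = shape
    mask_shape = (batch, d_model, 1) if tie else (batch, d_model, length)
    keep = rng.random(mask_shape) >= p
    mask = keep / (1.0 - p)              # rescale so the expected value is unchanged
    return np.broadcast_to(mask, shape)


rng = np.random.default_rng(0)
x = np.ones((2, 4, 8))
tied = x * dropout_mask(x.shape, p=0.5, tie=True, rng=rng)
untied = x * dropout_mask(x.shape, p=0.5, tie=False, rng=rng)

# In the tied case every (batch, channel) row is constant along L.
print((tied == tied[..., :1]).all())
```

Tying drops whole channels for the entire sequence instead of isolated time steps.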

Methods

forward

forward(x, state=None, rate=1.0, **kwargs) -> tuple[Tensor, Tensor | None]
Forward pass through the FFTConvS4 module.
x
torch.Tensor
required
Input tensor. Shape depends on transposed parameter:
  • If transposed=True: (B, D, L)
  • If transposed=False: (B, L, D)
state
torch.Tensor
Recurrent state from previous time step. Used for stateful/recurrent processing.
rate
float
default:"1.0"
Rate parameter for kernel computation, useful for temporal downsampling.
**kwargs
dict
Additional keyword arguments. These are accepted so that callers can pass options such as return_output or a transformer source mask without raising an error.
y
torch.Tensor
Convolution output, shape (B, C, H, L) where C is the number of channels.
next_state
torch.Tensor | None
Updated state for recurrent mode. Returns None if state was not provided.

step

step(x, state) -> tuple[Tensor, Tensor]
Advance the model by one time step in recurrent mode. Intended for use during validation or autoregressive generation.
x
torch.Tensor
required
Input tensor at current time step, shape (B, H).
state
torch.Tensor
required
Recurrent state, shape (B, H, N) where N is the state dimension.
y
torch.Tensor
Output at current time step, shape (B, C, H).
next_state
torch.Tensor
Updated state for next time step, shape (B, H, N).
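The relationship between the convolutional forward pass and step-by-step recurrence can be seen on a toy model. The sketch below uses a real, diagonal, single-channel SSM in NumPy. This is an assumption made for illustration and much simpler than the S4 or S4D kernels. It shows that convolving with the materialized kernel gives the same outputs as carrying a state forward one step at a time.

```python
import numpy as np

rng = np.random.default_rng(0)
L, N = 20, 4
A = rng.uniform(0.1, 0.9, N)     # diagonal state matrix (stable: |A| < 1)
Bv = rng.standard_normal(N)      # input projection
Cv = rng.standard_normal(N)      # output projection
u = rng.standard_normal(L)       # scalar input sequence

# Convolution mode: materialize the kernel K[l] = sum_n C_n * A_n**l * B_n.
powers = A[None, :] ** np.arange(L)[:, None]      # (L, N)
K = (powers * Bv * Cv).sum(axis=1)                # (L,)
y_conv = np.convolve(u, K)[:L]

# Recurrent mode: update the state once per time step.
state = np.zeros(N)
y_rec = np.empty(L)
for t in range(L):
    state = A * state + Bv * u[t]
    y_rec[t] = Cv @ state

print(np.allclose(y_conv, y_rec))
```

Convolution mode processes the whole sequence at once, which suits training. Recurrent mode needs only the current input and the state, which suits autoregressive generation.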

setup_step

setup_step(**kwargs)
Prepare the module for step-by-step (recurrent) inference. Must be called before using the step method.

default_state

default_state(*batch_shape, device=None) -> Tensor
Create a default initial state for recurrent processing.
*batch_shape
int
Batch dimensions for the state tensor.
device
torch.device
Device on which to create the state tensor.
state
torch.Tensor
Initialized state tensor with appropriate shape and device.

Properties

d_output
int
Output dimension, computed as d_model * channels.

Example Usage

import torch
from lrnnx.core.convolution import opt_ssm_forward, FFTConvS4

# Low-level usage with opt_ssm_forward
B, L, H, N = 32, 1024, 64, 128
x = torch.randn(B, L, H)
K = torch.randn(L, N)
B_ = torch.randn(N, H)
C = torch.randn(H, N)

output = opt_ssm_forward(x, K, B_, C)  # (B, L, H)

# High-level usage with FFTConvS4 module
conv_layer = FFTConvS4(
    d_model=64,
    l_max=1024,
    channels=4,
    kernel="s4d",
    dropout=0.1
)

x = torch.randn(32, 64, 1024)  # (B, D, L) with transposed=True
y, _ = conv_layer(x)  # (B, 4, 64, 1024)

Performance Considerations

  • opt_ssm_forward selects its computation strategy from the tensor dimensions (batch size, state dimension, and input and output feature dimensions)
  • FFT operations are zero-padded to 2*L so that circular convolution does not wrap around
  • Kernel dropout can be used for regularization without recomputing the FFT
  • fft_conv is decorated with @torch.compiler.disable, which excludes it from torch.compile and avoids compilation issues with FFT operations
