The quantization utilities enable running Samay models with a reduced memory footprint using 8-bit (int8) or 4-bit (nf4) quantization.

quantize_linear_layers

Quantize linear layers in a PyTorch model using bitsandbytes.
from samay.quantization import quantize_linear_layers

quantized_model = quantize_linear_layers(
    module,
    threshold=6.0,
    quantization_type="int8"
)
Parameters

  • module (torch.nn.Module, required): PyTorch module to quantize. The function recursively quantizes all nn.Linear layers with at least 128 input features.
  • threshold (float, default: 6.0): Outlier threshold for int8 quantization's mixed-precision decomposition; values whose magnitude exceeds it are handled in higher precision instead of int8.
  • quantization_type (str, default: "int8"): Type of quantization to apply:
      • "int8": 8-bit linear quantization using bnb.nn.Linear8bitLt
      • "nf4": 4-bit NormalFloat quantization using bnb.nn.Linear4bit

Returns

  • module (torch.nn.Module): The input module with linear layers replaced by quantized versions.

Example

import torch
from samay.model import LPTMModel
from samay.quantization import quantize_linear_layers

# Load model
config = {
    "task_name": "forecasting",
    "forecast_horizon": 192,
    "freeze_encoder": True,
}

model = LPTMModel(config)
model.model = model.model.to("cuda")

# Check memory before quantization
before_size = sum(p.numel() * p.element_size() for p in model.model.parameters())
print(f"Before quantization: {before_size / 1e6:.2f} MB")

# Quantize with int8
model.model = quantize_linear_layers(
    model.model,
    threshold=6.0,
    quantization_type="int8"
)

# Check memory after quantization
after_size = sum(p.numel() * p.element_size() for p in model.model.parameters())
print(f"After quantization: {after_size / 1e6:.2f} MB")
print(f"Memory reduction: {(1 - after_size/before_size) * 100:.1f}%")

Quantization Types

INT8 Quantization

Uses bnb.nn.Linear8bitLt for 8-bit integer quantization:
  • Reduces memory by approximately 4x compared to float32
  • Minimal accuracy loss for most tasks
  • Fast inference on CUDA devices
  • Uses mixed-precision decomposition for outliers
model.model = quantize_linear_layers(
    model.model,
    quantization_type="int8",
    threshold=6.0
)
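The threshold parameter marks which values are treated as outliers and kept in higher precision rather than quantized to int8. A toy illustration of that selection rule (not bitsandbytes' actual kernel logic):

```python
import torch

# Values whose magnitude exceeds the threshold are flagged as outliers
# and, in mixed-precision decomposition, stay in higher precision.
x = torch.tensor([0.5, -2.0, 7.5, 1.0])
threshold = 6.0
outlier_mask = x.abs() > threshold
print(outlier_mask.tolist())  # only the 7.5 entry is flagged
```

Raising the threshold quantizes more values (smaller memory, more risk of accuracy loss on outlier-heavy activations); lowering it keeps more values in higher precision.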

NF4 Quantization

Uses bnb.nn.Linear4bit for 4-bit NormalFloat quantization:
  • Reduces memory by approximately 8x compared to float32
  • Slightly higher accuracy loss than int8
  • Optimized for normal distributions
  • Computes in float16
model.model = quantize_linear_layers(
    model.model,
    quantization_type="nf4"
)
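The approximate savings quoted above follow directly from the bit widths. A back-of-the-envelope helper (illustrative only; real savings are smaller because non-linear layers and small linear layers remain in full precision):

```python
def est_weight_memory_mb(n_params: int, bits: int) -> float:
    # Memory needed to store n_params weights at the given bit width
    return n_params * bits / 8 / 1e6

n = 100_000_000  # 100M parameters, all assumed quantizable
print(est_weight_memory_mb(n, 32))  # float32 baseline: 400.0 MB
print(est_weight_memory_mb(n, 8))   # int8: 100.0 MB (4x smaller)
print(est_weight_memory_mb(n, 4))   # nf4: 50.0 MB (8x smaller)
```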

Requirements

Quantization requires the bitsandbytes library:
pip install bitsandbytes
bitsandbytes requires CUDA for GPU acceleration. CPU-only quantization is not supported.
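Since both the library and a CUDA device are hard requirements, it can help to verify them up front rather than fail mid-quantization. A minimal guard (an assumption about how you might structure the check, not part of the Samay API):

```python
import importlib.util

import torch

# Check both prerequisites before calling quantize_linear_layers
has_bnb = importlib.util.find_spec("bitsandbytes") is not None
has_cuda = torch.cuda.is_available()
can_quantize = has_bnb and has_cuda
print(f"bitsandbytes installed: {has_bnb}, CUDA available: {has_cuda}")
```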

Complete Example with Evaluation

import torch
from samay.model import LPTMModel
from samay.dataset import LPTMDataset
from samay.quantization import quantize_linear_layers

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model
config = {
    "task_name": "forecasting",
    "forecast_horizon": 192,
    "head_dropout": 0,
    "freeze_encoder": True,
}

model = LPTMModel(config)
model.model = model.model.to(device)

# Quantize
print("Quantizing model...")
model.model = quantize_linear_layers(
    model.model,
    threshold=6.0,
    quantization_type="int8"
)

# Verify quantization
import bitsandbytes as bnb
import torch.nn as nn

n_int8 = sum(1 for m in model.model.modules() if isinstance(m, bnb.nn.Linear8bitLt))
n_linear = sum(1 for m in model.model.modules() if isinstance(m, nn.Linear))

print(f"Quantized layers (int8): {n_int8}")
print(f"Remaining linear layers: {n_linear}")

# Load dataset
dataset = LPTMDataset(
    name="ett",
    datetime_col="date",
    path="data/ETTh1.csv",
    mode="val",
    horizon=192,
)

# Evaluate with memory tracking
if device.type == "cuda":
    torch.cuda.reset_peak_memory_stats()

avg_loss, trues, preds, histories = model.evaluate(
    dataset,
    task_name="forecasting"
)

if device.type == "cuda":
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"Peak GPU memory: {peak_memory:.2f} MB")

print(f"Average loss: {avg_loss:.4f}")

Layer Selection

The quantize_linear_layers function only quantizes nn.Linear layers with at least 128 input features. This is because:
  1. Smaller layers have minimal memory impact
  2. Quantization overhead may outweigh benefits for small layers
  3. Maintaining precision in small layers preserves model quality
To customize this threshold, modify the condition in the source code:
if isinstance(child, nn.Linear) and child.in_features >= 128:
    # Quantize this layer
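Before editing the source, it can be useful to see how many layers the default threshold would actually touch. A small helper that mirrors the selection rule (count_quantizable is a hypothetical name, not part of samay.quantization):

```python
import torch.nn as nn

def count_quantizable(module: nn.Module, min_in_features: int = 128) -> int:
    # Mirrors the selection rule: only nn.Linear layers with enough
    # input features would be replaced by quantized equivalents
    return sum(
        1 for m in module.modules()
        if isinstance(m, nn.Linear) and m.in_features >= min_in_features
    )

toy = nn.Sequential(nn.Linear(256, 64), nn.Linear(64, 32))
print(count_quantizable(toy))  # only the 256-input-feature layer qualifies
```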

Best Practices

When to Use Quantization

  • Large models: Models with many parameters benefit most from quantization
  • Inference only: Quantization is designed for evaluation, not training
  • Memory-constrained environments: When GPU memory is limited
  • Batch inference: Quantization reduces memory per sample

Quantization Type Selection

  • INT8: Best balance of speed, memory, and accuracy
  • NF4: Maximum memory savings when accuracy loss is acceptable

Performance Considerations

  • Quantization adds some computational overhead
  • First inference may be slower due to lazy initialization
  • Subsequent inferences are typically faster than float32
  • Memory savings enable larger batch sizes, improving throughput
