The quantization utilities enable running Samay models with a reduced memory footprint using 8-bit (int8) or 4-bit (nf4) quantization.
## quantize_linear_layers

Quantize linear layers in a PyTorch model using bitsandbytes.
```python
from samay.quantization import quantize_linear_layers

quantized_model = quantize_linear_layers(
    module,
    threshold=6.0,
    quantization_type="int8",
)
```
**Parameters**

- `module`: PyTorch module to quantize. The function recursively quantizes all `nn.Linear` layers with at least 128 input features.
- `threshold`: Threshold parameter for int8 quantization. Controls the outlier threshold for mixed-precision decomposition.
- `quantization_type`: Type of quantization to apply:
    - `"int8"`: 8-bit linear quantization using `bnb.nn.Linear8bitLt`
    - `"nf4"`: 4-bit NormalFloat quantization using `bnb.nn.Linear4bit`

**Returns**

The input module with its eligible linear layers replaced by quantized versions.
### Example
```python
import torch

from samay.model import LPTMModel
from samay.quantization import quantize_linear_layers

# Load model
config = {
    "task_name": "forecasting",
    "forecast_horizon": 192,
    "freeze_encoder": True,
}
model = LPTMModel(config)
model.model = model.model.to("cuda")

# Check memory before quantization
before_size = sum(p.numel() * p.element_size() for p in model.model.parameters())
print(f"Before quantization: {before_size / 1e6:.2f} MB")

# Quantize with int8
model.model = quantize_linear_layers(
    model.model,
    threshold=6.0,
    quantization_type="int8",
)

# Check memory after quantization
after_size = sum(p.numel() * p.element_size() for p in model.model.parameters())
print(f"After quantization: {after_size / 1e6:.2f} MB")
print(f"Memory reduction: {(1 - after_size / before_size) * 100:.1f}%")
```
## Quantization Types

### INT8 Quantization

Uses `bnb.nn.Linear8bitLt` for 8-bit integer quantization:
- Reduces memory by approximately 4x compared to float32
- Minimal accuracy loss for most tasks
- Fast inference on CUDA devices
- Uses mixed-precision decomposition for outliers
```python
model.model = quantize_linear_layers(
    model.model,
    quantization_type="int8",
    threshold=6.0,
)
```
### NF4 Quantization

Uses `bnb.nn.Linear4bit` for 4-bit NormalFloat quantization:
- Reduces memory by approximately 8x compared to float32
- Slightly higher accuracy loss than int8
- Optimized for normal distributions
- Computes in float16
```python
model.model = quantize_linear_layers(
    model.model,
    quantization_type="nf4",
)
```
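The approximate 4x (int8) and 8x (nf4) savings quoted above follow directly from the per-parameter bit widths. A quick back-of-envelope sketch (it ignores the small per-block scaling metadata that real quantized layers also store, so actual savings are slightly lower):

```python
def approx_linear_bytes(in_features: int, out_features: int, dtype_bits: int, bias: bool = True) -> int:
    """Rough weight + bias storage for one linear layer at a given bit width.

    Ignores quantization metadata (scales, zero points), which adds a small
    per-block overhead in practice.
    """
    params = in_features * out_features + (out_features if bias else 0)
    return params * dtype_bits // 8


# A 1024x1024 linear layer at each precision
for bits, label in [(32, "float32"), (8, "int8"), (4, "nf4")]:
    mb = approx_linear_bytes(1024, 1024, bits) / 1e6
    print(f"{label}: {mb:.2f} MB")
```

This prints roughly 4.20 MB for float32, 1.05 MB for int8, and 0.52 MB for nf4, matching the 4x and 8x figures.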
## Requirements

Quantization requires the bitsandbytes library (`pip install bitsandbytes`).

Note that bitsandbytes requires CUDA for GPU acceleration; CPU-only quantization is not supported.
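Since both bitsandbytes and a CUDA device must be present, it can be convenient to check for them up front and fall back to full precision otherwise. A minimal guard sketch (the `quantization_available` helper is illustrative, not part of samay):

```python
import importlib.util


def quantization_available() -> bool:
    """Best-effort check that quantized inference can run on this machine."""
    # bitsandbytes must be importable...
    if importlib.util.find_spec("bitsandbytes") is None:
        return False
    # ...and torch must see a CUDA device, since the bitsandbytes kernels are CUDA-only.
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


print(quantization_available())
```

Callers can then branch on the result, quantizing only when the environment supports it.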
## Complete Example with Evaluation
```python
import torch
import torch.nn as nn

import bitsandbytes as bnb
from samay.dataset import LPTMDataset
from samay.model import LPTMModel
from samay.quantization import quantize_linear_layers

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model
config = {
    "task_name": "forecasting",
    "forecast_horizon": 192,
    "head_dropout": 0,
    "freeze_encoder": True,
}
model = LPTMModel(config)
model.model = model.model.to(device)

# Quantize
print("Quantizing model...")
model.model = quantize_linear_layers(
    model.model,
    threshold=6.0,
    quantization_type="int8",
)

# Verify quantization. Linear8bitLt subclasses nn.Linear, so exclude it
# when counting the layers that remain in full precision.
n_int8 = sum(1 for m in model.model.modules() if isinstance(m, bnb.nn.Linear8bitLt))
n_linear = sum(
    1
    for m in model.model.modules()
    if isinstance(m, nn.Linear) and not isinstance(m, bnb.nn.Linear8bitLt)
)
print(f"Quantized layers (int8): {n_int8}")
print(f"Remaining linear layers: {n_linear}")

# Load dataset
dataset = LPTMDataset(
    name="ett",
    datetime_col="date",
    path="data/ETTh1.csv",
    mode="val",
    horizon=192,
)

# Evaluate with memory tracking
if device.type == "cuda":
    torch.cuda.reset_peak_memory_stats()

avg_loss, trues, preds, histories = model.evaluate(
    dataset,
    task_name="forecasting",
)

if device.type == "cuda":
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"Peak GPU memory: {peak_memory:.2f} MB")

print(f"Average loss: {avg_loss:.4f}")
```
## Layer Selection

The `quantize_linear_layers` function only quantizes `nn.Linear` layers with at least 128 input features. This is because:
- Smaller layers have minimal memory impact
- Quantization overhead may outweigh benefits for small layers
- Maintaining precision in small layers preserves model quality
To customize this threshold, modify the condition in the source code:

```python
if isinstance(child, nn.Linear) and child.in_features >= 128:
    # Quantize this layer
```
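To preview which layers the 128-feature rule would catch before quantizing, you can walk the module tree with the same condition. A small helper sketch (`quantizable_layers` is illustrative, not part of samay):

```python
import torch.nn as nn


def quantizable_layers(module: nn.Module, min_in_features: int = 128) -> list:
    """Return dotted names of nn.Linear submodules meeting the size threshold."""
    return [
        name
        for name, child in module.named_modules()
        if isinstance(child, nn.Linear) and child.in_features >= min_in_features
    ]


# Only the first layer has >= 128 input features.
model = nn.Sequential(nn.Linear(256, 64), nn.ReLU(), nn.Linear(64, 8))
print(quantizable_layers(model))  # ['0']
```

Running this against your actual model before quantizing makes it easy to confirm how much of the network the threshold actually covers.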
## Best Practices

### When to Use Quantization
- Large models: Models with many parameters benefit most from quantization
- Inference only: Quantization is designed for evaluation, not training
- Memory-constrained environments: When GPU memory is limited
- Batch inference: Quantization reduces memory per sample
### Quantization Type Selection
- INT8: Best balance of speed, memory, and accuracy
- NF4: Maximum memory savings when accuracy loss is acceptable
### Performance Considerations

- Quantization adds some computational overhead
- First inference may be slower due to lazy initialization
- Subsequent inferences are typically faster than float32
- Memory savings enable larger batch sizes, improving throughput