
Base Model Interface

All models in Samay inherit from the Basemodel class, which defines a unified interface with four core methods: finetune(), forecast(), evaluate(), and save().

Core Methods

finetune()

Adapts a pretrained model to your specific dataset:
def finetune(self, dataset: BaseDataset, **kwargs):
    """Finetune the model on the given dataset.
    
    Args:
        dataset (BaseDataset): Dataset used for finetuning.
        **kwargs: Optional keyword arguments for finetuning (e.g. lr, epoch).
    
    Returns:
        Any: Model or training result produced by the finetune operation.
    """
Common kwargs:
  • lr (float) - Learning rate (default varies by model, typically 1e-4)
  • epoch (int) - Number of training epochs (default 5)
  • freeze_transformer (bool) - Freeze transformer layers (TimesFM only)
Example:
model = ChronosModel(repo="amazon/chronos-t5-small")
train_data = ChronosDataset(path="data.csv", mode="train")

# Finetune with custom learning rate and epochs
model.finetune(train_data, lr=5e-5, epoch=10)

forecast()

Generates predictions from input time series data:
def forecast(self, input, **kwargs):
    """Generate forecast(s) from input data.
    
    Args:
        input (Any): Input data for forecasting (e.g. torch.Tensor, numpy array,
            or model-specific input structure).
        **kwargs: Optional keyword arguments for forecasting.
    
    Returns:
        Any: Forecast outputs. The concrete model defines the exact format
            (for example, mean forecasts, quantiles, or full distribution).
    """
Return formats vary by model:
  • TimesFM: Returns (mean_forecast, quantile_forecast) tuple
    output, quantiles = model.forecast(input_ts)
    # output: (batch, horizon)
    # quantiles: (batch, horizon, num_quantiles)
    
  • Chronos: Use pipeline’s predict_quantiles() method
    quantiles, mean = model.pipeline.predict_quantiles(
        context=input_seq,
        prediction_length=96,
        quantile_levels=[0.1, 0.5, 0.9]
    )
    

evaluate()

Computes metrics on a test dataset:
def evaluate(self, dataset: BaseDataset, **kwargs):
    """Evaluate the model on a dataset.
    
    Args:
        dataset (BaseDataset): Dataset used for evaluation.
        **kwargs: Optional evaluation arguments (e.g. metric_only).
    
    Returns:
        Any: Metrics dict, or metrics together with ground truth, predictions, histories, and quantiles when metric_only is False.
    """
Common kwargs:
  • metric_only (bool) - If True, return only the metrics dict. If False, also return ground truth, predictions, input histories, and quantile forecasts
  • horizon_len (int) - Forecast horizon (Chronos, ChronosBolt)
  • quantile_levels (list) - Quantile levels for probabilistic forecasting
Example:
# Return only metrics
metrics = model.evaluate(
    test_data,
    horizon_len=96,
    quantile_levels=[0.1, 0.5, 0.9],
    metric_only=True
)
print(metrics['mse'], metrics['mae'], metrics['crps'])

# Return metrics + predictions
metrics, trues, preds, histories, quantiles = model.evaluate(
    test_data,
    horizon_len=96,
    quantile_levels=[0.1, 0.5, 0.9],
    metric_only=False
)
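The crps value reported above is conventionally approximated from quantile forecasts via the pinball (quantile) loss, averaged over quantile levels and scaled by 2. A minimal sketch of that approximation (an assumed formula, not necessarily Samay's exact implementation):

```python
def pinball_loss(y, y_hat, q):
    """Pinball loss for one observation at quantile level q."""
    diff = y - y_hat
    return max(q * diff, (q - 1) * diff)

def crps_from_quantiles(y_true, quantile_preds, quantile_levels):
    """Approximate CRPS as twice the mean pinball loss over quantile levels.

    quantile_preds: one list of predictions per quantile level, aligned with y_true.
    """
    total, count = 0.0, 0
    for q, preds in zip(quantile_levels, quantile_preds):
        for y, y_hat in zip(y_true, preds):
            total += pinball_loss(y, y_hat, q)
            count += 1
    return 2 * total / count  # factor of 2 matches the CRPS integral convention
```

With more quantile levels, the approximation converges toward the exact CRPS integral.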

save()

Persists model weights to disk:
def save(self, path):
    """Save the model to disk.
    
    Args:
        path (str): Filesystem path where model artifacts should be saved.
    """

Model Configuration Patterns

Loading from HuggingFace

The recommended approach for production use:
from samay import ChronosModel, TimesfmModel, MomentModel

# Chronos models
model = ChronosModel(repo="amazon/chronos-t5-small")
# Options: chronos-t5-tiny, chronos-t5-mini, chronos-t5-small, 
#          chronos-t5-base, chronos-t5-large

# TimesFM
model = TimesfmModel(repo="google/timesfm-1.0-200m")

# MOMENT
model = MomentModel(repo="AutonLab/MOMENT-1-large")
The repo parameter triggers:
  1. Download from HuggingFace Hub (cached locally)
  2. Automatic device mapping to available GPU
  3. Loading pretrained weights

Custom Configuration

For research or custom model variants:
# Chronos with custom config
chronos_config = {
    "model_type": "seq2seq",  # or "causal"
    "context_length": 512,
    "prediction_length": 64,
    "n_tokens": 4096,
    "n_special_tokens": 2,
    "tokenizer_class": "MeanScaleUniformBins",
    "tokenizer_kwargs": {"low_limit": -15.0, "high_limit": 15.0}
}

model = ChronosModel(config=chronos_config)
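The MeanScaleUniformBins tokenizer named in the config scales each series by its mean absolute value, clips the scaled values to [low_limit, high_limit], and quantizes them into uniform bins offset by the special tokens. A rough sketch of that idea (illustrative only, not the library's implementation):

```python
def mean_scale_tokenize(series, n_tokens=4096, n_special=2,
                        low=-15.0, high=15.0):
    """Map a series to token ids via mean scaling + uniform binning."""
    scale = sum(abs(v) for v in series) / len(series) or 1.0
    n_bins = n_tokens - n_special
    width = (high - low) / n_bins
    ids = []
    for v in series:
        s = min(max(v / scale, low), high)          # scale and clip
        b = min(int((s - low) / width), n_bins - 1)  # uniform bin index
        ids.append(n_special + b)                    # offset past special tokens
    return ids, scale

def detokenize(ids, scale, n_tokens=4096, n_special=2, low=-15.0, high=15.0):
    """Invert tokenization using each bin's center value."""
    n_bins = n_tokens - n_special
    width = (high - low) / n_bins
    return [(low + (i - n_special + 0.5) * width) * scale for i in ids]
```

The round trip loses at most half a bin width per value (times the scale), which is why a few thousand bins suffice in practice.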

# TimesFM with custom config
timesfm_config = {
    "context_len": 512,
    "horizon_len": 128,
    "input_patch_len": 32,
    "output_patch_len": 128,
    "num_layers": 20,
    "model_dims": 1280,
    "quantiles": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
}

model = TimesfmModel(config=timesfm_config)

HuggingFace Integration

Pipeline Architecture

Many models use HuggingFace’s pipeline pattern:
# From src/samay/model.py:410-420
class ChronosModel(Basemodel):
    def __init__(self, config=None, repo=None):
        super().__init__(config=config, repo=repo)
        if repo:
            self.pipeline = ChronosPipeline.from_pretrained(
                repo, device_map=self.device
            )
        else:
            self.pipeline = ChronosPipeline(config=ChronosConfig(**config))
The pipeline encapsulates:
  • Tokenization (for transformer-based models)
  • Model forward pass
  • Output decoding/detokenization
  • Quantile prediction

Device Mapping

Models automatically map to the best available device:
# From src/samay/model.py:67-71
class Basemodel:
    def __init__(self, config=None, repo=None):
        least_used_gpu = get_least_used_gpu()
        if least_used_gpu >= 0:
            self.device = torch.device(f"cuda:{least_used_gpu}")
        else:
            self.device = torch.device("cpu")
Behavior:
  • Detects all available CUDA GPUs
  • Selects GPU with lowest memory usage
  • Falls back to CPU if no GPUs available
  • Automatically moves model and tensors to device
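The selection logic behind get_least_used_gpu amounts to picking the device with the most free memory. A sketch of that logic as a pure function (hypothetical helper; in practice the free-byte readings would come from torch.cuda.mem_get_info):

```python
def least_used(free_bytes_per_gpu):
    """Return the index of the GPU with the most free memory, or -1 if none.

    free_bytes_per_gpu: free-memory readings per device, e.g. collected via
    [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())].
    """
    if not free_bytes_per_gpu:
        return -1  # no CUDA devices: caller falls back to CPU
    return max(range(len(free_bytes_per_gpu)),
               key=lambda i: free_bytes_per_gpu[i])
```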

Fine-tuning Patterns

Full Fine-tuning

Update all model parameters:
model = ChronosModel(repo="amazon/chronos-t5-small")
train_data = ChronosDataset(path="data.csv", mode="train")

# All parameters trainable
model.finetune(train_data, lr=1e-4, epoch=5)

Frozen Transformer Fine-tuning

Only update the output head (TimesFM):
# From src/samay/model.py:174-176
if freeze_transformer:
    for param in FinetunedModel.core_layer.stacked_transformer.parameters():
        param.requires_grad = False

model = TimesfmModel(repo="google/timesfm-1.0-200m")
train_data = TimesfmDataset(path="data.csv", mode="train")

# Only finetune output layers
model.finetune(train_data, freeze_transformer=True, lr=1e-4, epoch=5)
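You can verify the effect of freezing by counting trainable parameters before training. A toy stand-in with torch (not the actual TimesFM module structure):

```python
import torch.nn as nn

# Stand-in: index 0 plays the role of the stacked transformer, index 1 the output head
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 1))

for param in model[0].parameters():
    param.requires_grad = False  # freeze the "transformer"

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
```

Here only the 9 head parameters remain trainable out of 81 total, mirroring how freeze_transformer leaves just the output layers to update.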

Training Loop Example

The typical fine-tuning implementation (from ChronosModel):
# From src/samay/model.py:432-459
finetune_model = self.pipeline.model.model
dataloader = dataset.get_data_loader()
finetune_model.to(self.device)
finetune_model.train()
optimizer = torch.optim.AdamW(finetune_model.parameters(), lr=1e-4)

for epoch in range(5):
    avg_loss = 0.0
    for i, data in enumerate(dataloader):
        input_ids = data["input_ids"].to(self.device)
        attention_mask = data["attention_mask"].to(self.device)
        labels = data["labels"].to(self.device)
        
        optimizer.zero_grad()
        output = finetune_model(
            input_ids, 
            attention_mask=attention_mask, 
            labels=labels
        )
        loss = output.loss
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
    print(f"Epoch {epoch}: avg loss {avg_loss / len(dataloader):.4f}")

Model-Specific Features

Chronos

Key Features:
  • Tokenization-based approach
  • Quantile forecasting
  • Seq2seq or causal architectures
  • Context length up to 512
Pipeline Methods:
quantiles, mean = model.pipeline.predict_quantiles(
    context=input_seq,
    prediction_length=96,
    quantile_levels=[0.1, 0.5, 0.9],
    limit_prediction_length=False
)

TimesFM

Key Features:
  • Patch-based processing
  • Multi-horizon forecasting
  • Quantile output (mean + 9 quantiles)
  • Returns tuple: (mean, full_quantiles)
Forecast Output:
output, quantiles = model.forecast(input_ts)
# output: (batch, horizon) - mean forecast
# quantiles: (batch, horizon, 1+num_quantiles) - mean + quantiles

ChronosBolt

Key Features:
  • Output patch-based prediction
  • Autoregressive forecasting for long horizons
  • Median-based autoregression
Long Horizon Handling:
# From src/samay/model.py:937-960
# Automatically handles horizons longer than max_patches
remaining_length = forecast_seq.shape[-1]
while remaining_length > 0:
    output = self.model(context=inputs, num_output_patches=current_horizon_patch)
    # Use median as context for next iteration
    last_ar_output = quantile_output[:, :, median_idx]

Next Steps

Datasets

Learn how to prepare data for different models

Evaluation

Understand metrics and evaluation patterns
