
Class Signature

class Chronos_2_Model(Basemodel):
    def __init__(self, config=None, repo=None)
The Chronos_2_Model class implements Chronos 2.0, the successor to Amazon's Chronos time-series forecasting model, designed to handle longer sequences and to improve forecasting performance.

Initialization Parameters

config
dict
default:"None"
Model configuration dictionary whose keys are Chronos2CoreConfig parameters. Used to initialize a new model without pre-trained weights.
repo
str
default:"None"
Hugging Face model repository ID. If provided, loads the pre-trained Chronos 2.0 model. If not provided, initializes a new model using the config.
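The documented rule for choosing between repo and config can be summarized as a small decision function. The sketch below is hypothetical: resolve_init_path is not part of samay, and the ValueError raised when both arguments are None is an assumption, since this page does not say what happens in that case.

```python
from typing import Optional


def resolve_init_path(config: Optional[dict] = None, repo: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented repo/config rule.

    Not part of samay; it only reports which initialization path
    Chronos_2_Model is documented to take for the given arguments.
    """
    if repo is not None:
        # A repository ID selects the pre-trained Chronos 2.0 weights.
        return "pretrained:" + repo
    if config is not None:
        # No repository ID: build a fresh model from the config dictionary.
        return "new-from-config"
    # Assumption: behavior with neither argument is undocumented.
    raise ValueError("provide either repo or config")


print(resolve_init_path(repo="amazon/chronos-2-small"))      # pretrained:amazon/chronos-2-small
print(resolve_init_path(config={"output_patch_size": 8}))    # new-from-config
```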

Properties

max_patches
int
Maximum number of output patches the model can generate, derived from model.chronos_config.max_output_patches.
patch_size
int
Size of each output patch, derived from model.chronos_config.output_patch_size.
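Both properties are read-only views of the inner model's configuration. The sketch below illustrates that pattern with stand-in objects; Chronos2WrapperSketch is hypothetical, and storing the wrapped network as `.model` is an assumption based on the `model.chronos_config` path named above.

```python
from types import SimpleNamespace


class Chronos2WrapperSketch:
    """Hypothetical sketch: properties that read the inner model's chronos_config."""

    def __init__(self, inner_model):
        self.model = inner_model  # assumption: wrapped network stored as .model

    @property
    def max_patches(self) -> int:
        return self.model.chronos_config.max_output_patches

    @property
    def patch_size(self) -> int:
        return self.model.chronos_config.output_patch_size


# Stand-in for a loaded model, using the config values from the usage example.
inner = SimpleNamespace(
    chronos_config=SimpleNamespace(max_output_patches=512, output_patch_size=8)
)
wrapper = Chronos2WrapperSketch(inner)
print(wrapper.max_patches, wrapper.patch_size)    # 512 8
print(wrapper.max_patches * wrapper.patch_size)   # 4096 time steps covered by all patches
```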

Methods

finetune()

def finetune(dataset: Chronos_2_Dataset, **kwargs)
Fine-tune the model on the given dataset.
dataset
Chronos_2_Dataset
required
Dataset for fine-tuning. Its get_data_loader() method provides the dataloader used for training.
epoch
int
default:"5"
Number of training epochs, passed as a keyword argument (for example, epoch=10).
**kwargs
dict
Optional keyword arguments, such as epoch (default 5) and lr, the learning rate (default 1e-4).
return
None
The model is fine-tuned in place.
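The keyword defaults above can be resolved as follows. This is a sketch only: resolve_finetune_kwargs is a hypothetical helper, and it assumes the keyword names are epoch and lr, as they appear on this page; finetune() may read them differently internally.

```python
from typing import Any, Dict


def resolve_finetune_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: apply the documented finetune() defaults.

    Assumes the keyword names are `epoch` and `lr`.
    """
    return {
        "epoch": int(kwargs.get("epoch", 5)),   # documented default: 5 epochs
        "lr": float(kwargs.get("lr", 1e-4)),    # documented default learning rate
    }


print(resolve_finetune_kwargs())           # {'epoch': 5, 'lr': 0.0001}
print(resolve_finetune_kwargs(epoch=10))   # {'epoch': 10, 'lr': 0.0001}
```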

evaluate()

def evaluate(
    dataset: Chronos_2_Dataset,
    metric_only=False,
    **kwargs
)
Evaluate the Chronos 2.0 model on a dataset.
dataset
Chronos_2_Dataset
required
Dataset for evaluation. Its get_data_loader() method provides the dataloader used for inference.
metric_only
bool
default:"False"
If True, return only a dict of metrics. If False, also return the ground-truth, prediction, and history arrays.
return
Dict[str, float] | Tuple
When metric_only=True: dictionary containing:
  • mse: Mean Squared Error
  • mae: Mean Absolute Error
  • mase: Mean Absolute Scaled Error
  • mape: Mean Absolute Percentage Error
  • rmse: Root Mean Squared Error
  • nrmse: Normalized RMSE
  • smape: Symmetric Mean Absolute Percentage Error
  • msis: Mean Scaled Interval Score
  • nd: Normalized Deviation
  • mwsq: Mean Weighted Scaled Quantile Loss
  • crps: Continuous Ranked Probability Score
When metric_only=False: tuple of (metrics, trues, preds, histories):
  • metrics: Dictionary of metrics (as above)
  • trues: Ground truth values, shape (batch_size, n_channels, horizon_len)
  • preds: Mean predictions, shape (batch_size, n_channels, horizon_len)
  • histories: Input context sequences, shape (batch_size, n_channels, context_len)
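For reference, a few of the point metrics can be computed from arrays shaped (batch_size, n_channels, horizon_len) as sketched below. These are textbook definitions written with the standard library; samay's own implementation may pool, scale, or normalize differently, so treat point_metrics as illustrative rather than as the library's code.

```python
import math
from typing import Dict, List

Array3D = List[List[List[float]]]  # (batch_size, n_channels, horizon_len)


def _flatten(x: Array3D) -> List[float]:
    return [v for series in x for channel in series for v in channel]


def point_metrics(trues: Array3D, preds: Array3D) -> Dict[str, float]:
    """Textbook MSE, MAE, RMSE and ND, pooled over all batch, channel and time entries."""
    y, yhat = _flatten(trues), _flatten(preds)
    if len(y) != len(yhat):
        raise ValueError("trues and preds must have the same shape")
    errors = [a - b for a, b in zip(y, yhat)]
    mse = sum(e * e for e in errors) / len(errors)
    mae = sum(abs(e) for e in errors) / len(errors)
    return {
        "mse": mse,
        "mae": mae,
        "rmse": math.sqrt(mse),
        # ND: total absolute error divided by total absolute ground truth
        "nd": sum(abs(e) for e in errors) / sum(abs(a) for a in y),
    }


# Toy data: batch_size=2, n_channels=1, horizon_len=3
trues = [[[1.0, 2.0, 3.0]], [[2.0, 2.0, 2.0]]]
preds = [[[1.0, 3.0, 3.0]], [[2.0, 2.0, 4.0]]]
print(point_metrics(trues, preds))
```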

Usage Example

from samay.model import Chronos_2_Model
from samay.dataset import Chronos_2_Dataset

# Load pre-trained model
model = Chronos_2_Model(repo="amazon/chronos-2-small")

# Or initialize with custom config
config = {
    "max_output_patches": 512,
    "output_patch_size": 8,
    # ... other Chronos2CoreConfig parameters
}
custom_model = Chronos_2_Model(config=config)  # new model without pre-trained weights

# Prepare dataset
dataset = Chronos_2_Dataset(...)

# Finetune with custom epochs
model.finetune(dataset, epoch=10)

# Evaluate
metrics = model.evaluate(dataset, metric_only=True)
print(f"MSE: {metrics['mse']:.4f}")
print(f"CRPS: {metrics['crps']:.4f}")

# Get detailed results
metrics, trues, preds, histories = model.evaluate(dataset, metric_only=False)
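The detailed results make it possible to break an aggregate error down by channel. The sketch below computes per-channel MAE with plain lists; per_channel_mae is a hypothetical helper. If the returned arrays are NumPy arrays (an assumption), convert them with .tolist(), or use np.abs(trues - preds).mean(axis=(0, 2)) directly.

```python
from typing import List

Array3D = List[List[List[float]]]  # (batch_size, n_channels, horizon_len)


def per_channel_mae(trues: Array3D, preds: Array3D) -> List[float]:
    """MAE for each channel, pooled over the batch and horizon dimensions."""
    n_channels = len(trues[0])
    totals = [0.0] * n_channels
    counts = [0] * n_channels
    for true_series, pred_series in zip(trues, preds):
        for c in range(n_channels):
            for a, b in zip(true_series[c], pred_series[c]):
                totals[c] += abs(a - b)
                counts[c] += 1
    return [t / n for t, n in zip(totals, counts)]


# Toy data: batch_size=1, n_channels=2, horizon_len=2
trues = [[[1.0, 2.0], [10.0, 10.0]]]
preds = [[[1.5, 2.5], [9.0, 12.0]]]
print(per_channel_mae(trues, preds))  # [0.5, 1.5]
```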

Notes

  • Chronos 2.0 uses a patch-based architecture for better handling of long sequences
  • The model automatically calculates the number of output patches based on the forecast horizon and patch size
  • Improves on the original Chronos in scalability and forecasting performance
  • Supports both fine-tuning and zero-shot forecasting
  • The horizon length for fine-tuning must be smaller than max_patches * patch_size
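The patch arithmetic in these notes can be sketched as below. Ceiling division for the patch count is an assumption (the page does not state the exact rounding), and both helpers are hypothetical; check_finetune_horizon enforces the constraint exactly as the last note states it.

```python
def num_output_patches(horizon_len: int, patch_size: int) -> int:
    """Patches needed to cover the horizon, assuming ceiling division."""
    return (horizon_len + patch_size - 1) // patch_size


def check_finetune_horizon(horizon_len: int, max_patches: int, patch_size: int) -> None:
    """Enforce the documented constraint horizon_len < max_patches * patch_size."""
    limit = max_patches * patch_size
    if horizon_len >= limit:
        raise ValueError(f"horizon_len={horizon_len} must be smaller than {limit}")


print(num_output_patches(96, 8))    # 12
print(num_output_patches(100, 8))   # 13 (the last patch is only partly used)
check_finetune_horizon(96, max_patches=512, patch_size=8)  # passes silently
```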
