Class Signature

class TimesfmModel(Basemodel):
    def __init__(self, config=None, repo=None, ckpt=None, **kwargs)
The TimesfmModel class implements Google’s TimesFM, a decoder-only foundation model for time series forecasting with quantile predictions.

Initialization Parameters

config
dict
default:"None"
Model configuration dictionary. Must include TimesFM-specific hyperparameters.
repo
str
default:"None"
Hugging Face model repository ID for loading pre-trained checkpoints.
ckpt
str
default:"None"
Custom checkpoint path (optional).

Configuration Parameters

The config dictionary should contain TimesFM hyperparameters:
horizon_len
int
required
Forecast horizon length.
quantiles
list[float]
required
List of quantile levels for probabilistic forecasting (e.g., [0.1, 0.5, 0.9]).
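Because the full forecast output places the mean in channel 0 followed by the quantile channels in the order given here (see forecast() below), a tiny helper can map a quantile level to its channel index. The helper below is a hypothetical illustration, not part of the library:

```python
# Hypothetical helper: map a quantile level to its channel index in the
# full forecast tensor, which is laid out as [mean, q1, q2, ...].
def quantile_channel(quantiles, level):
    """Return the channel index of `level` in the full forecast output."""
    return 1 + quantiles.index(level)  # channel 0 is the mean forecast

quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
print(quantile_channel(quantiles, 0.5))  # median sits in channel 5
```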

Methods

finetune()

def finetune(dataset: TimesfmDataset, freeze_transformer=True, **kwargs)
Finetune the model on the given dataset.
dataset
TimesfmDataset
required
The dataset for finetuning. Use get_data_loader() to obtain the dataloader.
freeze_transformer
bool
default:"True"
Whether to freeze the transformer layers during finetuning. When True, only the output projection layers are trained.
lr
float
default:"1e-4"
Learning rate for training.
epoch
int
default:"5"
Number of training epochs.
return
ppd.PatchedDecoderFinetuneModel
The finetuned model.

forecast()

def forecast(input, **kwargs)
Generate forecast from input data.
input
torch.Tensor
required
Input time series data.
return
Tuple[torch.Tensor, torch.Tensor]
A tuple containing:
  • Mean forecast of shape (# inputs, # forecast horizon)
  • Full forecast (mean + quantiles) of shape (# inputs, # forecast horizon, 1 + # quantiles)
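The shapes above mean the full forecast can be sliced per channel. A minimal sketch, using a synthetic NumPy array in place of the real model output (in practice the tuple comes from model.forecast(input_ts)):

```python
import numpy as np

# Synthetic stand-in for the full forecast: shape
# (# inputs, # forecast horizon, 1 + # quantiles), mean in channel 0.
n_inputs, horizon, quantiles = 4, 96, [0.1, 0.5, 0.9]
full = np.random.rand(n_inputs, horizon, 1 + len(quantiles))

mean = full[:, :, 0]                            # point forecast
median = full[:, :, 1 + quantiles.index(0.5)]   # 0.5-quantile channel
lower = full[:, :, 1 + quantiles.index(0.1)]    # band edges for an
upper = full[:, :, 1 + quantiles.index(0.9)]    # 80% prediction interval

print(mean.shape, median.shape)  # (4, 96) (4, 96)
```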

evaluate()

def evaluate(dataset: TimesfmDataset, metric_only=False, **kwargs)
Evaluate the model on a dataset.
dataset
TimesfmDataset
required
Dataset for evaluation. Call get_data_loader() to get the dataloader.
metric_only
bool
default:"False"
If True, return only metrics.
return
Dict[str, float] | Tuple
When metric_only=True: a dictionary containing:
  • mse: Mean Squared Error
  • mae: Mean Absolute Error
  • mase: Mean Absolute Scaled Error
  • mape: Mean Absolute Percentage Error
  • rmse: Root Mean Squared Error
  • nrmse: Normalized RMSE
  • smape: Symmetric Mean Absolute Percentage Error
  • msis: Mean Scaled Interval Score
  • nd: Normalized Deviation
  • mwsq: Mean Weighted Scaled Quantile Loss
  • crps: Continuous Ranked Probability Score
When metric_only=False: a tuple of (metrics, trues, preds, histories, quantiles):
  • metrics: Dictionary of metrics (as above)
  • trues: Ground truth values
  • preds: Mean predictions
  • histories: Historical context
  • quantiles: Quantile predictions
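With metric_only=False, the arrays can be post-processed directly. A sketch with synthetic arrays standing in for evaluate()'s outputs (in real use: metrics, trues, preds, histories, quantiles = model.evaluate(dataset)), cross-checking two entries of the metrics dictionary by their standard definitions:

```python
import numpy as np

# Synthetic ground truth and mean predictions.
trues = np.array([[1.0, 2.0], [3.0, 4.0]])
preds = np.array([[1.5, 2.0], [2.0, 4.5]])

# MSE and MAE computed directly from the arrays, matching the
# definitions of the `mse` and `mae` entries in the metrics dict.
mse = float(np.mean((trues - preds) ** 2))
mae = float(np.mean(np.abs(trues - preds)))
print(mse, mae)  # 0.375 0.5
```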

plot()

def plot(dataset: TimesfmDataset, **kwargs)
Plot forecast results.
dataset
TimesfmDataset
required
Dataset for plotting. Use get_data_loader() to obtain the dataloader.
**kwargs
dict
Additional keyword arguments forwarded to the visualization helper.
return
None
This method does not return a value. It displays visualizations.

Usage Example

from samay.model import TimesfmModel
from samay.dataset import TimesfmDataset

# Initialize model with config
config = {
    "horizon_len": 96,
    "quantiles": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
}
model = TimesfmModel(config=config, repo="google/timesfm-1.0-200m")

# Prepare dataset
dataset = TimesfmDataset(...)

# Finetune with frozen transformer
model.finetune(dataset, freeze_transformer=True, lr=1e-4, epoch=5)

# Generate forecasts
input_ts = ...  # Your input time series
mean_forecast, quantile_forecast = model.forecast(input_ts)

# Evaluate
metrics = model.evaluate(dataset, metric_only=True)
print(f"CRPS: {metrics['crps']}")

# Visualize
model.plot(dataset)

Notes

  • TimesFM is a decoder-only transformer model designed for zero-shot forecasting
  • The model outputs both point forecasts (mean) and quantile predictions for uncertainty estimation
  • When finetuning with freeze_transformer=True, only the output layers are updated
  • Quantile predictions enable probabilistic forecasting and uncertainty quantification
  • Data is automatically denormalized during evaluation if the dataset was normalized