
Class Signature

class TinyTimeMixerModel(Basemodel):
    def __init__(self, config=None, repo=None)
The TinyTimeMixerModel class implements IBM’s TinyTimeMixer, a lightweight and efficient time series forecasting model designed for edge deployment and resource-constrained environments.

Initialization Parameters

config
dict
required
Model configuration dictionary. Must include context_len and horizon_len.
repo
str
required
Hugging Face model repository ID. The model automatically selects the appropriate revision based on context_len and horizon_len.

Configuration Parameters

context_len
int
required
Length of the input context window (e.g., 512).
horizon_len
int
required
Forecast horizon length. The model selects the smallest available horizon from [96, 192, 336, 720] that is greater than or equal to this value.
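The configuration rules above can be sketched as a small helper. This is hypothetical code, not part of samay. It reads the selection rule as "smallest available horizon >= horizon_len", following the notes at the end of this page. Raising an error for horizons above 720 is an assumption, because the page does not document that case. The only revision name it relies on is the documented "main" revision for context_len=512 with horizon 96.

```python
# Hypothetical sketch of config validation and horizon selection for a
# TinyTimeMixerModel-style config. Not samay's actual implementation.
AVAILABLE_HORIZONS = (96, 192, 336, 720)


def select_horizon(horizon_len: int) -> int:
    """Return the smallest available horizon that is >= horizon_len."""
    candidates = [h for h in AVAILABLE_HORIZONS if h >= horizon_len]
    if not candidates:
        # Assumption: the documented behavior for horizon_len > 720 is unknown.
        raise ValueError(
            f"horizon_len={horizon_len} exceeds the largest available "
            f"horizon ({AVAILABLE_HORIZONS[-1]})"
        )
    return min(candidates)


def check_config(config: dict) -> tuple[int, int]:
    """Check the required keys and return (context_len, selected horizon)."""
    missing = [k for k in ("context_len", "horizon_len") if k not in config]
    if missing:
        raise KeyError(f"config is missing required keys: {missing}")
    return config["context_len"], select_horizon(config["horizon_len"])


def uses_main_revision(context_len: int, horizon: int) -> bool:
    # Per the notes: context_len=512 with horizon 96 uses the "main" revision.
    return context_len == 512 and horizon == 96
```

Under this reading, a config of {"context_len": 512, "horizon_len": 100} resolves to the 192-step checkpoint.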

Methods

finetune()

def finetune(dataset: TinyTimeMixerDataset, **kwargs)
Finetune the model on the given dataset.
dataset
TinyTimeMixerDataset
required
Dataset for finetuning. The method obtains batches through the dataset's get_data_loader() method.
**kwargs
dict
Optional keyword arguments. Training currently uses the default hyperparameters lr=1e-4 and epochs=5.
return
None
The model is finetuned in-place.
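The in-place contract above can be illustrated with a toy sketch: no return value, the model is mutated, and the defaults are lr=1e-4 and epochs=5. A hypothetical one-weight linear forecaster stands in for TinyTimeMixer, and the loop is plain full-batch gradient descent on MSE. This is not samay's training procedure, which this page does not describe.

```python
# Toy illustration of an in-place finetune() contract. A one-parameter linear
# model stands in for TinyTimeMixer; this is NOT samay's training loop.
from dataclasses import dataclass


@dataclass
class ToyForecaster:
    weight: float = 0.0  # prediction = weight * last observed value

    def predict(self, last_value: float) -> float:
        return self.weight * last_value


def mse(model: ToyForecaster, pairs: list[tuple[float, float]]) -> float:
    """Mean squared error over (last_value, target) pairs."""
    return sum((model.predict(x) - y) ** 2 for x, y in pairs) / len(pairs)


def finetune(model: ToyForecaster, pairs: list[tuple[float, float]],
             lr: float = 1e-4, epochs: int = 5) -> None:
    """Full-batch gradient descent on MSE; mutates `model` and returns None."""
    for _ in range(epochs):
        # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
        grad = sum(2 * (model.predict(x) - y) * x for x, y in pairs) / len(pairs)
        model.weight -= lr * grad
```

As with the documented method, the caller keeps using the same model object after training rather than a returned copy.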

evaluate()

def evaluate(dataset: TinyTimeMixerDataset, metric_only=False, **kwargs)
Evaluate the model on a dataset.
dataset
TinyTimeMixerDataset
required
Dataset for evaluation. The method obtains batches through the dataset's get_data_loader() method.
metric_only
bool
default: False
If True, return only the metrics dictionary.
return
Dict[str, float] | Tuple
When metric_only=True: a dictionary containing:
  • mse: Mean Squared Error
  • mae: Mean Absolute Error
  • mase: Mean Absolute Scaled Error
  • mape: Mean Absolute Percentage Error
  • rmse: Root Mean Squared Error
  • nrmse: Normalized RMSE
  • smape: Symmetric Mean Absolute Percentage Error
  • msis: Mean Scaled Interval Score
  • nd: Normalized Deviation
When metric_only=False: a tuple of (metrics, trues, preds, histories):
  • metrics: Dictionary of metrics (as above)
  • trues: Ground truth values, shape (num_samples, num_ts, horizon_len)
  • preds: Predictions, shape (num_samples, num_ts, horizon_len)
  • histories: Historical context, shape (num_samples, num_ts, context_len)
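For reference, the sketch below computes several of the listed point metrics from the trues, preds and histories arrays. It uses common textbook definitions pooled over all samples and series. It is an assumption-laden illustration, not samay's implementation, and several choices are made here rather than taken from this page: NRMSE is normalized by the mean absolute target, MASE uses a non-seasonal one-step naive scale computed from the histories, values are returned as fractions rather than percentages, and MSIS is omitted because it needs prediction intervals.

```python
# Illustrative metrics using common definitions; samay's own formulas
# (normalizations, MASE seasonality, percentage scaling) may differ.
import math


def _flatten(x):
    """Flatten nested data shaped (num_samples, num_ts, length) into a list."""
    return [v for sample in x for series in sample for v in series]


def basic_metrics(trues, preds, histories):
    """Pooled point-forecast metrics for arrays shaped as listed above."""
    y, yhat = _flatten(trues), _flatten(preds)
    err = [p - t for t, p in zip(y, yhat)]
    n = len(err)
    mse = sum(e * e for e in err) / n
    mae = sum(abs(e) for e in err) / n
    rmse = math.sqrt(mse)
    # MASE scale (assumed non-seasonal): mean absolute one-step change in history
    steps = [abs(s[i] - s[i - 1])
             for sample in histories for s in sample for i in range(1, len(s))]
    scale = sum(steps) / len(steps)
    return {
        "mse": mse,
        "mae": mae,
        "rmse": rmse,
        "nrmse": rmse / (sum(abs(t) for t in y) / n),
        "mape": sum(abs(e) / abs(t) for e, t in zip(err, y)) / n,  # assumes no zero targets
        "smape": sum(2 * abs(e) / (abs(t) + abs(p)) for e, t, p in zip(err, y, yhat)) / n,
        "mase": mae / scale,
        "nd": sum(abs(e) for e in err) / sum(abs(t) for t in y),
    }
```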

plot()

def plot(dataset: TinyTimeMixerDataset, **kwargs)
Plot the forecast results.
dataset
TinyTimeMixerDataset
required
Dataset for plotting. The method obtains batches through the dataset's get_data_loader() method.
**kwargs
dict
Additional keyword arguments forwarded to the underlying plotting routine.
return
None
This method does not return a value. It displays visualizations.
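If you need a custom plot instead of plot(), the arrays returned by evaluate(metric_only=False) are enough. The hypothetical helper below picks one sample and one channel and lines up the history with the horizon on a shared time axis, so that the forecast starts right after the context window. It is not part of samay.

```python
# Hypothetical helper (not part of samay): build (x, y) series for one
# sample/channel from the evaluate() outputs so they can be plotted together.
def forecast_plot_series(histories, trues, preds, sample=0, channel=0):
    """Return x/y pairs for history, ground truth and forecast of one series.

    Shapes follow evaluate(): histories (num_samples, num_ts, context_len),
    trues and preds (num_samples, num_ts, horizon_len).
    """
    hist = histories[sample][channel]
    true = trues[sample][channel]
    pred = preds[sample][channel]
    context_len = len(hist)
    hist_x = list(range(context_len))
    # Forecast steps continue the time axis immediately after the context.
    fut_x = list(range(context_len, context_len + len(true)))
    return {
        "history": (hist_x, list(hist)),
        "truth": (fut_x, list(true)),
        "forecast": (fut_x, list(pred)),
    }
```

Each entry can then be drawn with matplotlib, for example `plt.plot(x, y, label=name)` for every `name, (x, y)` in the returned dictionary.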

Usage Example

from samay.model import TinyTimeMixerModel
from samay.dataset import TinyTimeMixerDataset

# Initialize model
config = {
    "context_len": 512,
    "horizon_len": 96
}
model = TinyTimeMixerModel(
    config=config,
    repo="ibm/tinytimemixer"
)

# Prepare dataset
dataset = TinyTimeMixerDataset(...)

# Finetune
model.finetune(dataset)

# Evaluate
metrics, trues, preds, histories = model.evaluate(
    dataset,
    metric_only=False
)
print(f"MAE: {metrics['mae']}")

# Visualize
model.plot(dataset)

Notes

  • TinyTimeMixer is optimized for efficiency and edge deployment
  • The model automatically selects the appropriate checkpoint revision based on your specified horizon_len
  • Available horizon lengths are 96, 192, 336, and 720; the model uses the smallest of these that is greater than or equal to horizon_len
  • The default context length of 512 with horizon 96 uses the “main” revision
  • Input batches arrive channel-first, (batch, channels, time), and are automatically permuted to the (batch, time, channels) layout that TinyTimeMixer expects
  • The model requires a repository ID; it cannot be initialized without pre-trained weights
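The layout conversion mentioned in the notes amounts to swapping the channel and time axes of every batch. The evaluate() shapes above are channel-first, while IBM's granite-tsfm TinyTimeMixer documents its past_values input as (batch_size, seq_length, num_input_channels). On a PyTorch tensor this is `x.permute(0, 2, 1)`. The pure-Python sketch below, a hypothetical helper operating on nested lists, shows the same index mapping, `out[b][t][c] == batch[b][c][t]`.

```python
# Hypothetical helper: swap the last two axes of a nested-list batch,
# (batch, channels, time) -> (batch, time, channels).
def channels_first_to_time_first(batch):
    # zip(*sample) walks all channel series in lockstep, yielding one
    # tuple of channel values per time step.
    return [[list(step) for step in zip(*sample)] for sample in batch]
```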
