Overview

Time-series imputation reconstructs missing or corrupted values in temporal data. Samay models such as MOMENT take a reconstruction-based approach: they fill gaps by learning patterns from the surrounding observed context.
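The core idea can be illustrated with a toy stand-in: mask some observed points, reconstruct them from context, and score only on the masked positions. The sketch below uses `np.interp` in place of the model purely for illustration; it is not the Samay API, which the workflow below shows.

```python
import numpy as np

# Toy illustration of masked reconstruction (np.interp stands in for the
# model; the real workflow uses MomentModel as shown below).
rng = np.random.default_rng(0)
t = np.arange(100)
series = np.sin(t / 10.0)

# Randomly drop ~20% of points (mask: True = observed, False = missing)
mask = rng.random(100) > 0.2
corrupted = np.where(mask, series, np.nan)

# "Reconstruct" missing values from the observed context
reconstructed = np.interp(t, t[mask], series[mask])

# Score only on the masked positions
mse = np.mean((series[~mask] - reconstructed[~mask]) ** 2)
print(f"MSE on masked points: {mse:.5f}")
```

The same three ingredients (a mask, a reconstruction, and masked-only scoring) recur throughout the real workflow below.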

Models Supporting Imputation

Model     Zero-Shot    Fine-Tuning    Approach
MOMENT    ✓            ✓              Masked reconstruction

Step-by-Step Workflow

Step 1: Load model for imputation

Initialize MOMENT with reconstruction task:
from samay.model import MomentModel

repo = "AutonLab/MOMENT-1-large"
config = {
    "task_name": "reconstruction",
}
mmt = MomentModel(config=config, repo=repo)

Step 2: Prepare imputation dataset

Load data with missing values:
from samay.dataset import MomentDataset

test_dataset = MomentDataset(
    name="ett",
    datetime_col="date",
    path="data/ETTh1.csv",
    mode="test",
    task_name="imputation",
)
The dataset automatically creates masks for missing values. You can also specify custom masking ratios.

Step 3: Run zero-shot imputation

Impute missing values without training:
# Evaluate and get reconstructed values
trues, preds, masks = mmt.evaluate(
    test_dataset, task_name="imputation"
)

print(trues.shape, preds.shape, masks.shape)
# (batch, channels, timesteps) e.g., (100, 7, 512)

Step 4: Visualize imputation

Compare original and imputed values:
mmt.plot(test_dataset, task_name="imputation")
Or manually plot:
import matplotlib.pyplot as plt
import numpy as np

idx = np.random.randint(trues.shape[0])
channel_idx = np.random.randint(trues.shape[1])

fig, axs = plt.subplots(2, 1, figsize=(10, 5))

# Plot time-series
axs[0].set_title(f"Channel={channel_idx}")
axs[0].plot(
    trues[idx, channel_idx, :].squeeze(),
    label='Ground Truth',
    c='darkblue'
)
axs[0].plot(
    preds[idx, channel_idx, :].squeeze(),
    label='Imputed',
    c='red'
)
axs[0].legend(fontsize=16)

# Show mask (0 = missing, 1 = observed)
axs[1].imshow(
    np.tile(masks[np.newaxis, idx, channel_idx], reps=(8, 1)),
    cmap='binary'
)
plt.show()

Step 5: Evaluate imputation quality

Calculate reconstruction metrics:
# Only evaluate on missing values (where mask == 0)
mse = np.mean((trues[masks==0] - preds[masks==0])**2)
mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))

print(f'MSE: {mse:.4f}, MAE: {mae:.4f}')

Real Example: ETTh1 Imputation

Complete workflow from moment_imputation.ipynb:
from samay.model import MomentModel
from samay.dataset import MomentDataset
import numpy as np

# Initialize model
repo = "AutonLab/MOMENT-1-large"
config = {"task_name": "reconstruction"}
mmt = MomentModel(config=config, repo=repo)

# Load dataset
test_dataset = MomentDataset(
    name="ett",
    datetime_col="date",
    path="data/ETTh1.csv",
    mode="test",
    task_name="imputation",
)

# Zero-shot imputation
trues, preds, masks = mmt.evaluate(test_dataset, task_name="imputation")

# Calculate metrics on missing values only
mse = np.mean((trues[masks==0] - preds[masks==0])**2)
mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))
print(f'MSE: {mse}, MAE: {mae}')

# Visualize
mmt.plot(test_dataset, task_name="imputation")
Output: Reconstructs missing values in the ETTh1 dataset with low MSE/MAE.

Fine-Tuning for Better Imputation

Improve imputation quality on domain-specific data:
# Prepare training dataset
train_dataset = MomentDataset(
    name="ett",
    datetime_col="date",
    path="data/ETTh1.csv",
    mode="train",
    task_name="imputation",
)

# Fine-tune
finetuned_model = mmt.finetune(
    train_dataset,
    task_name="imputation",
    epoch=5
)
# Epoch 0: Train loss: 0.262
# Epoch 1: Train loss: 0.259
# Epoch 2: Train loss: 0.256
# Epoch 3: Train loss: 0.253
# Epoch 4: Train loss: 0.249

# Evaluate fine-tuned model
trues, preds, masks = finetuned_model.evaluate(
    test_dataset, task_name="imputation"
)

mse = np.mean((trues[masks==0] - preds[masks==0])**2)
mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))
print(f'Fine-tuned MSE: {mse}, MAE: {mae}')
# Typically 10-20% better than zero-shot

Advanced Techniques

Custom Masking Strategy

Control which values are masked:
# Random masking with custom ratio
import torch

def custom_mask(data, mask_ratio=0.3):
    """Randomly mask a `mask_ratio` fraction of values (default 30%)."""
    mask = torch.rand(data.shape) > mask_ratio  # True = observed
    masked_data = data.clone()
    masked_data[~mask] = 0  # or use a sentinel value like -1
    return masked_data, mask

# Apply custom masking
masked_data, mask = custom_mask(original_data)

Handling Irregular Missingness

For real-world data with irregular gaps:
import pandas as pd
import numpy as np

# Load data with NaN values
df = pd.read_csv("data_with_missing.csv")

# Create mask: 1 = observed, 0 = missing
mask = (~df.isnull()).astype(int).values

# Fill NaN with 0 for model input
df_filled = df.fillna(0)

# Create dataset
test_dataset = MomentDataset(
    name="custom",
    path=None,  # Pass data directly
    data=df_filled.values,
    masks=mask,
    mode="test",
    task_name="imputation",
)

# Impute
trues, preds, masks = mmt.evaluate(test_dataset, task_name="imputation")

# Reconstruct DataFrame with imputed values
# (preds is windowed as (batch, channels, timesteps); reshape it back to
# the DataFrame's (rows, columns) layout before this assignment)
df_imputed = df.copy()
df_imputed[mask == 0] = preds[mask == 0]

Iterative Imputation

Refine imputation by iterating:
def iterative_imputation(model, data, mask, iterations=3):
    """Iteratively impute missing values"""
    imputed_data = data.copy()
    
    for i in range(iterations):
        # Create dataset with current imputed values
        dataset = MomentDataset(
            name="iter",
            data=imputed_data,
            masks=mask,
            mode="test",
            task_name="imputation",
        )
        
        # Impute
        trues, preds, _ = model.evaluate(dataset, task_name="imputation")
        
        # Update missing values with predictions
        imputed_data[mask == 0] = preds[mask == 0]
        
        print(f"Iteration {i+1} MSE: {np.mean((trues[mask==0] - preds[mask==0])**2)}")
    
    return imputed_data

# Apply iterative imputation
final_imputed = iterative_imputation(mmt, data, mask, iterations=3)

Multivariate Imputation

Leverage correlations between channels:
# Dataset with multiple correlated channels
multivar_dataset = MomentDataset(
    name="ett",
    path="data/ETTh1.csv",  # Columns: HUFL, HULL, MUFL, MULL, LUFL, LULL, OT
    datetime_col="date",
    mode="test",
    task_name="imputation",
)

# Impute all channels simultaneously
trues, preds, masks = mmt.evaluate(
    multivar_dataset, task_name="imputation"
)

# Model uses cross-channel information for better imputation
mse_per_channel = np.mean((trues - preds)**2, axis=(0, 2))
print(f"MSE per channel: {mse_per_channel}")

Evaluation Metrics

Mean Squared Error (MSE)

mse = np.mean((trues[masks==0] - preds[masks==0])**2)

Mean Absolute Error (MAE)

mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))

Root Mean Squared Error (RMSE)

rmse = np.sqrt(np.mean((trues[masks==0] - preds[masks==0])**2))

Mean Absolute Percentage Error (MAPE)

# Exclude near-zero true values to avoid division blow-ups
nz = np.abs(trues[masks==0]) > 1e-8
mape = np.mean(np.abs((trues[masks==0][nz] - preds[masks==0][nz]) / trues[masks==0][nz])) * 100
print(f"MAPE: {mape:.2f}%")

Per-Channel Metrics

for ch in range(trues.shape[1]):
    ch_mse = np.mean((trues[:, ch, :][masks[:, ch, :]==0] - 
                      preds[:, ch, :][masks[:, ch, :]==0])**2)
    print(f"Channel {ch} MSE: {ch_mse:.4f}")

Use Cases

Sensor Data

Fill gaps in IoT sensor readings due to transmission errors or sensor failures

Financial Data

Impute missing stock prices, trading volumes, or economic indicators

Healthcare

Reconstruct missing patient vitals or lab results in electronic health records

Weather Data

Fill gaps in meteorological measurements (temperature, humidity, pressure)

Tips for Better Imputation

Missing Completely at Random (MCAR): Easiest to impute
Missing at Random (MAR): Imputable with covariates
Missing Not at Random (MNAR): Most challenging, may need domain knowledge
Ensure enough observed values around missing points. Avoid imputing long consecutive gaps (>50% of context length).
If you have multiple channels, use them all. Cross-channel patterns improve imputation accuracy.
Fine-tuning on data from the same domain (same sensors, same patients, etc.) improves imputation by 15-25%.
Test your imputation on missing patterns similar to real-world scenarios, not just random masking.
Post-process imputations with domain-specific constraints (e.g., temperature ranges, physical laws).
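The tip about testing on realistic missing patterns can be made concrete: instead of purely random point masks, generate contiguous gaps that mimic sensor dropouts. A minimal numpy sketch (the helper name and gap parameters are illustrative, not part of the Samay API):

```python
import numpy as np

def block_missing_mask(n_timesteps, n_gaps=3, gap_len=20, seed=0):
    """Mask with contiguous gaps (1 = observed, 0 = missing),
    mimicking sensor dropouts rather than random point missingness."""
    rng = np.random.default_rng(seed)
    mask = np.ones(n_timesteps, dtype=int)
    for _ in range(n_gaps):
        start = rng.integers(0, n_timesteps - gap_len)
        mask[start:start + gap_len] = 0
    return mask

mask = block_missing_mask(512)
# Keep total missingness well below 50% of the context length
assert (mask == 0).mean() < 0.5
```

Evaluating on such masks gives a more honest estimate of performance on real dropout patterns than uniform random masking.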

Common Pitfalls

Imputing too much: Don't impute more than 50% of your data; predictions become unreliable at that point. Consider whether your analysis remains valid with that much data missing.
Ignoring uncertainty: Imputations are estimates, not ground truth. Quantify uncertainty (e.g., via ensembles) for critical applications.
Overfitting during fine-tuning: Use validation set to monitor imputation quality and prevent overfitting.
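For the uncertainty pitfall, one simple approach is to collect several independent imputations (for example, from different random masks or fine-tuning seeds) and report the per-point spread. A numpy sketch, assuming you already have such a stack (random data stands in for real model output here):

```python
import numpy as np

# imputations: (K, channels, timesteps) stack of K independent imputations
# (e.g., from different random masks or fine-tuning seeds); random data
# stands in for real model output in this sketch.
rng = np.random.default_rng(0)
imputations = rng.normal(loc=1.0, scale=0.1, size=(5, 7, 512))

point_estimate = imputations.mean(axis=0)   # use as the imputed value
uncertainty = imputations.std(axis=0)       # per-point spread

# Flag points where the ensemble disagrees strongly
unreliable = uncertainty > 2 * uncertainty.mean()
print(f"Flagged {unreliable.mean():.1%} of imputed points as unreliable")
```

Downstream analyses can then treat flagged points with extra caution or exclude them entirely.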

Next Steps

For more examples, see the MOMENT Imputation notebook.
