
Overview

Base dataset abstraction extended by model-specific dataset classes. Subclasses must implement the preprocess() and get_data_loader() methods.

Class signature

class BaseDataset:
    def __init__(
        self,
        name: str = None,
        datetime_col: str = None,
        path: str = None,
        batchsize: int = 8,
        mode: str = "train",
        **kwargs,
    )

Parameters

name
str
default:"None"
Name of the dataset. If provided, a corresponding get_{name}_dataset function may be used to populate data.
datetime_col
str
default:"None"
Name of the datetime column in the source CSV.
path
str
default:"None"
Path to the dataset file. If omitted, the loader function for name is called.
batchsize
int
default:"8"
Batch size used by the data loaders.
mode
str
default:"train"
Mode of dataset usage, e.g. 'train' or 'test'.
kwargs
dict
Extra backend-specific options.
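The interaction between `path` and `name` can be sketched with a stand-in class; this is an illustration of the documented precedence (an explicit `path` wins, otherwise the loader for `name` would be used), not samay's actual implementation:

```python
# Illustrative stand-in (NOT samay's real __init__): shows the documented
# rule that `path` takes priority and the get_{name}_dataset loader is
# only consulted when `path` is omitted.
class SketchDataset:
    def __init__(self, name=None, datetime_col=None, path=None,
                 batchsize=8, mode="train", **kwargs):
        self.name = name
        self.datetime_col = datetime_col
        self.path = path
        self.batchsize = batchsize
        self.mode = mode
        self.extra = kwargs  # backend-specific options
        if path is not None:
            self.source = f"file:{path}"
        elif name is not None:
            self.source = f"loader:get_{name}_dataset"
        else:
            self.source = None

ds = SketchDataset(name="ett", datetime_col="date", batchsize=16)
print(ds.source)  # -> loader:get_ett_dataset
```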

Methods

__len__()

Return the number of items in the dataset.
def __len__(self) -> int
Returns: int - Number of samples in the dataset (len(self.data)).
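A quick sketch of this contract, assuming the dataset stores its samples in a `self.data` container (class name is a stand-in):

```python
# Minimal sketch: __len__ delegates to the underlying sample container.
class Sketch:
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

print(len(Sketch([10, 20, 30])))  # -> 3
```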

preprocess()

Preprocess method to be implemented by subclasses.
def preprocess(self, **kwargs)
This method raises NotImplementedError and must be overridden in subclasses.
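The abstract-method pattern described here can be sketched as follows (class names are stand-ins, not samay's):

```python
# The base class raises NotImplementedError; a subclass that overrides
# the method works normally.
class Base:
    def preprocess(self, **kwargs):
        raise NotImplementedError("Subclasses must implement preprocess()")

class Child(Base):
    def preprocess(self, **kwargs):
        return "preprocessed"

try:
    Base().preprocess()
except NotImplementedError as exc:
    print(exc)               # -> Subclasses must implement preprocess()
print(Child().preprocess())  # -> preprocessed
```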

get_data_loader()

Get data loader method to be implemented by subclasses.
def get_data_loader(self)
This method raises NotImplementedError and must be overridden in subclasses.

save()

Save the dataset to disk.
def save(self, path)
Parameters:
  • path (str): Path where the dataset will be saved.

Example usage

from samay.dataset import BaseDataset

# BaseDataset is meant to be subclassed, not used directly
class CustomDataset(BaseDataset):
    def preprocess(self, **kwargs):
        # Implement preprocessing logic
        pass
    
    def get_data_loader(self):
        # Implement data loader creation
        pass
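Once both methods are implemented, end-to-end usage could look like the self-contained sketch below. The base class and the toy preprocessing/batching logic are fill-ins for illustration, not samay's actual behavior:

```python
# Self-contained sketch: a BaseDataset-like stand-in plus a subclass with
# working preprocess() / get_data_loader(), so the full flow is visible.
class _Base:
    def __init__(self, name=None, datetime_col=None, path=None,
                 batchsize=8, mode="train", **kwargs):
        self.batchsize = batchsize
        self.mode = mode
        self.data = []

class DemoDataset(_Base):
    def preprocess(self, **kwargs):
        # Toy preprocessing; a real subclass would parse the CSV at `path`.
        self.data = list(range(20))
        return self

    def get_data_loader(self):
        # Yield fixed-size batches from self.data.
        for i in range(0, len(self.data), self.batchsize):
            yield self.data[i:i + self.batchsize]

ds = DemoDataset(batchsize=8).preprocess()
batches = list(ds.get_data_loader())
print(len(batches))  # -> 3 (two full batches of 8 and one of 4)
```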
