
Documentation Index

Fetch the complete documentation index at: https://mintlify.com/avnlp/llm-finetuning/llms.txt

Use this file to discover all available pages before exploring further.

Every trainer in the TRL ecosystem expects its dataset in a specific column schema. SFTTrainer wants a single text column; GRPOTrainer wants a prompt list and an answer string; PPOTrainer needs pre-tokenized input_ids. Handling these schema differences in ad-hoc scripts leads to duplication and subtle bugs. BaseDatasetLoader solves this by giving every dataset a single load() entry point and a format_example() hook where the schema is enforced. All built-in loaders in this project subclass BaseDatasetLoader, and you can do the same for any new dataset you add.

DatasetConfig

DatasetConfig is a frozen dataclass that bundles everything needed to download and format a HuggingFace dataset split.
@dataclass(frozen=True, slots=True)
class DatasetConfig:
    """Configuration for loading and mapping a HuggingFace dataset."""

    dataset_id: str
    subset: str | None = None
    cache_dir: str | None = None
    remove_source_columns: bool = True
    num_proc: int = 1
    extra_load_kwargs: dict[str, Any] = field(default_factory=dict)
    """Extra keyword arguments forwarded verbatim to load_dataset."""
dataset_id (str, required)
The HuggingFace Hub dataset identifier, e.g. "allenai/ai2_arc" or "openai/gsm8k". Passed directly to datasets.load_dataset.

subset (str | None, default: None)
The dataset configuration name (second positional argument to load_dataset), e.g. "ARC-Challenge" or "rc". Set to None for datasets with a single configuration.

cache_dir (str | None, default: None)
Local directory where the downloaded dataset files are cached. Defaults to the HuggingFace cache (~/.cache/huggingface/datasets).

remove_source_columns (bool, default: True)
When True, all original dataset columns are removed after format_example runs, leaving only the columns produced by the formatter. Set to False to keep source columns alongside the formatted ones.

num_proc (int, default: 1)
Number of processes used by datasets.Dataset.map when applying format_example. Increase for large datasets on multi-core machines.

extra_load_kwargs (dict[str, Any], default: {})
Additional keyword arguments forwarded verbatim to load_dataset. Use this to pass flags like trust_remote_code=True or download_mode="force_redownload".

BaseDatasetLoader

class BaseDatasetLoader(ABC):
    """Loads a HuggingFace dataset split and formats it for training.

    Output column schema by trainer type:
      - SFT  (SFTTrainer):   {"text": str}
      - GRPO (GRPOTrainer):  {"prompt": list[dict], "answer": str}
      - DPO/ORPO:            {"prompt": str, "chosen": str, "rejected": str}
      - PPO:                 {"prompt": str} → tokenized to {"input_ids": list[int]}
      - KTO:                 {"prompt": str, "completion": str, "label": bool}
    """

    def __init__(self, config: DatasetConfig) -> None:
        """Create a loader bound to a concrete dataset configuration."""
        self.config = config

Output column schemas by trainer type

Each trainer requires a specific set of columns. Your format_example implementation must return a dict that matches the schema for the trainer you are targeting.
| Trainer | Required columns | Types |
| --- | --- | --- |
| SFTTrainer | text | str |
| GRPOTrainer | prompt, answer | list[dict], str |
| DPOTrainer / ORPOTrainer | prompt, chosen, rejected | str, str, str |
| PPOTrainer | input_ids (after tokenization) | list[int] |
| KTOTrainer | prompt, completion, label | str, str, bool |
PPO loaders first produce {"prompt": str} from format_example, then call tokenize() to convert the prompt strings to input_ids. See the tokenize() section below.
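Full SFT and GRPO formatters appear later on this page; as a sketch of one of the remaining schemas, a DPO/ORPO format_example could look like the following. The source column names (question, good_answer, bad_answer) are hypothetical — substitute the columns of your actual preference dataset:

```python
from typing import Any


def format_dpo_example(example: dict[str, Any]) -> dict[str, Any]:
    """Map a raw preference row to the DPO/ORPO schema.

    Assumes hypothetical source columns 'question', 'good_answer', 'bad_answer'.
    """
    return {
        "prompt": example["question"],
        "chosen": example["good_answer"],
        "rejected": example["bad_answer"],
    }


row = {"question": "2+2?", "good_answer": "4", "bad_answer": "5"}
print(format_dpo_example(row))  # {'prompt': '2+2?', 'chosen': '4', 'rejected': '5'}
```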

The load() method

load() is the primary entry point. Call it with a split name to get a formatted Dataset ready for training:
from llm_finetuning.core import DatasetConfig
from llm_finetuning.supervised_finetuning.loaders import ARCLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
loader = ARCLoader(tokenizer)
train_dataset = loader.load("train")
test_dataset = loader.load("test")
Internally, load() calls datasets.load_dataset with the parameters from DatasetConfig, then pipes the result through _format_dataset, which maps format_example over every row.
def load(self, split: str) -> Dataset:
    """Load a dataset split and map it to the trainer's expected schema."""
    ds = load_dataset(
        self.config.dataset_id,
        self.config.subset,
        split=split,
        cache_dir=self.config.cache_dir,
        **self.config.extra_load_kwargs,
    )
    return self._format_dataset(ds)
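The body of _format_dataset is not shown on this page. As a rough sketch of its behavior, the following pure-Python stand-in operates on a list of dicts instead of a datasets.Dataset and reproduces the effect of the remove_source_columns flag; the real helper would use Dataset.map with the config's num_proc:

```python
from typing import Any, Callable


def format_dataset(
    rows: list[dict[str, Any]],
    format_example: Callable[[dict[str, Any]], dict[str, Any]],
    remove_source_columns: bool = True,
) -> list[dict[str, Any]]:
    """Plain-dict stand-in for _format_dataset's Dataset.map call."""
    formatted_rows = []
    for row in rows:
        new_columns = format_example(row)
        if remove_source_columns:
            # keep only the columns the formatter produced
            formatted_rows.append(new_columns)
        else:
            # keep source columns alongside the formatted ones
            formatted_rows.append({**row, **new_columns})
    return formatted_rows


raw = [{"question": "2+2?", "answer": "4"}]
fmt = lambda ex: {"text": f"Q: {ex['question']}\nA: {ex['answer']}"}
print(format_dataset(raw, fmt))         # only the formatter's 'text' column survives
print(format_dataset(raw, fmt, False))  # source columns retained alongside 'text'
```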

The format_example() method

format_example is the only abstract method you must implement in a subclass. It receives one raw dataset row and returns a dict whose keys match the target trainer’s schema.
@abstractmethod
def format_example(self, example: dict[str, Any]) -> dict[str, Any]:
    """Convert a raw example into the schema expected by the trainer.

    Args:
        example: A single row from the raw HuggingFace dataset.

    Returns:
        A dict whose keys are the columns expected by the trainer.
    """
    raise NotImplementedError
The method is called by _format_dataset via Dataset.map, so it must be stateless — it cannot accumulate state across rows.

The tokenize() method

tokenize() is used exclusively for PPO, where PPOTrainer expects input_ids rather than raw strings:
def tokenize(
    self,
    ds: Dataset,
    tokenizer: PreTrainedTokenizerBase,
    text_column: str = "prompt",
) -> Dataset:
    """Pre-tokenize a text column into input_ids for PPO."""
ds (Dataset, required)
A formatted Dataset containing a text column (typically the output of load()).

tokenizer (PreTrainedTokenizerBase, required)
The tokenizer to encode the text. Must match the model being trained.

text_column (str, default: "prompt")
Name of the column to tokenize. Defaults to "prompt".
Usage for PPO:
loader = MyPPOLoader(config)
ds = loader.load("train")                 # {"prompt": str}
ds = loader.tokenize(ds, tokenizer)       # {"input_ids": list[int]}
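The mapping that tokenize() performs can be sketched without the datasets or transformers libraries. The list-of-dicts stand-in and toy whitespace tokenizer below are illustrative only; the real method maps over a Dataset and would call the HF tokenizer (e.g. tokenizer(text)["input_ids"]) instead:

```python
from typing import Any, Callable


def tokenize_rows(
    rows: list[dict[str, Any]],
    tokenizer: Callable[[str], list[int]],
    text_column: str = "prompt",
) -> list[dict[str, Any]]:
    """Stand-in for tokenize(): encode text_column into input_ids."""
    return [{"input_ids": tokenizer(row[text_column])} for row in rows]


# Deterministic toy tokenizer standing in for a real PreTrainedTokenizerBase:
# each new word gets the next integer id.
vocab: dict[str, int] = {}


def toy_tokenizer(text: str) -> list[int]:
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


rows = [{"prompt": "solve this problem"}, {"prompt": "solve another"}]
print(tokenize_rows(rows, toy_tokenizer))
# [{'input_ids': [0, 1, 2]}, {'input_ids': [0, 3]}]
```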

Overriding the dataset via config.yaml

Train scripts read the dataset ID and subset from config.yaml and pass them to DatasetConfig, falling back to the loader’s class-level CONFIG constant when the keys are absent. This means you can point any loader at a different dataset by adding two lines to config.yaml — no code change required.
# config.yaml
model_id: "meta-llama/Llama-3.2-3B"
dataset_id: "my-org/my-custom-arc-dataset"   # overrides ARCLoader.CONFIG.dataset_id
dataset_subset: "challenge"                   # overrides ARCLoader.CONFIG.subset
split: "train"
The train script pattern that enables this:
loader_config = DatasetConfig(
    dataset_id=config.get("dataset_id", ARCLoader.CONFIG.dataset_id),
    subset=config.get("dataset_subset", ARCLoader.CONFIG.subset),
)
dataset = ARCLoader(tokenizer, config=loader_config).load(config["split"])
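The fallback behavior of dict.get is what makes the override optional. The snippet below simulates a parsed config.yaml with and without the override keys; the default values standing in for ARCLoader.CONFIG are illustrative:

```python
# Simulated result of yaml.safe_load on config.yaml; the dataset keys are
# absent here, so the loader's class-level defaults apply.
file_config = {"model_id": "meta-llama/Llama-3.2-3B", "split": "train"}

DEFAULT_DATASET_ID = "allenai/ai2_arc"   # stands in for ARCLoader.CONFIG.dataset_id
DEFAULT_SUBSET = "ARC-Challenge"         # stands in for ARCLoader.CONFIG.subset

dataset_id = file_config.get("dataset_id", DEFAULT_DATASET_ID)
subset = file_config.get("dataset_subset", DEFAULT_SUBSET)
print(dataset_id, subset)  # allenai/ai2_arc ARC-Challenge

# With the override key present, the YAML value wins:
file_config["dataset_id"] = "my-org/my-custom-arc-dataset"
dataset_id = file_config.get("dataset_id", DEFAULT_DATASET_ID)
print(dataset_id)  # my-org/my-custom-arc-dataset
```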

Implementing a custom loader subclass

The following example creates an SFT loader for a new dataset. The pattern mirrors the built-in loaders in supervised_finetuning/loaders.py.
from typing import Any

from transformers import PreTrainedTokenizerBase

from llm_finetuning.core import BaseDatasetLoader, DatasetConfig, PromptTemplate


# Define a prompt template for the dataset
MY_TEMPLATE = PromptTemplate(
    system_prompt="You are a helpful assistant. Answer the question concisely.",
    user_template="Question: {question}",
    response_field="answer",
)


class MyDatasetLoader(BaseDatasetLoader):
    """SFT loader for my custom QA dataset."""

    CONFIG = DatasetConfig(dataset_id="my-org/my-dataset")

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        config: DatasetConfig | None = None,
    ) -> None:
        super().__init__(config or self.CONFIG)
        self.template = MY_TEMPLATE
        self.tokenizer = tokenizer

    def format_example(self, example: dict[str, Any]) -> dict[str, Any]:
        """Format a raw row into the {'text': str} schema for SFTTrainer."""
        messages = [
            {"role": "system", "content": self.template.system_prompt},
            {"role": "user", "content": self.template.render_user(**example)},
            {"role": "assistant", "content": str(example[self.template.response_field])},
        ]
        return {
            "text": self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
        }
For GRPO loaders, return {"prompt": list[dict], "answer": str} instead of {"text": str}. Build the prompt message list directly in format_example rather than using PromptTemplate.render_user(), since GRPO loaders work with raw message dicts.
def format_example(self, example: dict[str, Any]) -> dict[str, Any]:
    """Format a raw row into the {'prompt', 'answer'} schema for GRPOTrainer."""
    return {
        "prompt": [
            {"role": "system", "content": "Solve the problem step by step."},
            {"role": "user", "content": example["question"]},
        ],
        "answer": example["answer"],
    }
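Applied to a toy row (the question/answer source columns are hypothetical), the GRPO formatter yields exactly the two-column schema GRPOTrainer expects:

```python
from typing import Any


def format_example(example: dict[str, Any]) -> dict[str, Any]:
    """The GRPO formatter from above, exercised on a sample row."""
    return {
        "prompt": [
            {"role": "system", "content": "Solve the problem step by step."},
            {"role": "user", "content": example["question"]},
        ],
        "answer": example["answer"],
    }


row = {"question": "What is 3 * 7?", "answer": "21"}
out = format_example(row)
print(out["prompt"][1]["content"], "->", out["answer"])  # What is 3 * 7? -> 21
```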
