RetrievalTrainer trains the dense retrieval model at the heart of LeanDojo v2’s lifelong learning pipeline. Given a merged corpus of traced theorems and premises, it fine-tunes a PremiseRetriever, a bi-encoder that scores how relevant each premise in the corpus is to the current proof state. Training runs on PyTorch Lightning and adds an Elastic Weight Consolidation (EWC) regulariser to limit catastrophic forgetting as new repositories are added. In practice, LeanAgent constructs and calls this trainer automatically; you only need to interact with it directly when building a custom retrieval pipeline.
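To make the bi-encoder idea concrete, the following is a minimal sketch of how premises could be ranked against a proof state once both have been embedded. The toy vectors, the rank_premises helper, and the choice of cosine similarity are illustrative assumptions; the actual PremiseRetriever encoder and scoring function may differ.

```python
import math
from typing import Dict, List, Tuple


def cosine(u: List[float], v: List[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def rank_premises(
    state_vec: List[float],
    premise_vecs: Dict[str, List[float]],
    k: int,
) -> List[Tuple[str, float]]:
    """Hypothetical helper: score every premise and keep the top k."""
    scored = [(name, cosine(state_vec, vec)) for name, vec in premise_vecs.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Toy 2-dimensional embeddings (assumed values, for illustration only).
state = [1.0, 0.0]
premises = {
    "Nat.add_comm": [0.9, 0.1],
    "Nat.mul_comm": [0.1, 0.9],
    "List.length_append": [-1.0, 0.0],
}
top = rank_premises(state, premises, k=2)
top_names = [name for name, _ in top]
```

In a bi-encoder the premise embeddings can be computed once and reused for every query, which is what makes scoring a whole corpus per proof state practical.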

Class

lean_dojo_v2.trainer.retrieval_trainer.RetrievalTrainer

Constructor Parameters

config (Optional[TrainingConfig], default: TrainingConfig())
A lean_dojo_v2.lean_agent.config.TrainingConfig dataclass instance that controls every aspect of retrieval training. When None, a default TrainingConfig() is used. See TrainingConfig Fields below for the full list of options.

For most use cases you should use LeanAgent directly rather than constructing RetrievalTrainer manually. LeanAgent creates the trainer, wires the database, and calls train() in the right order as part of the progressive learning loop.

train Method

RetrievalTrainer.train(
    repos: List[Repository],
    database: DynamicDatabase,
    data_path: Path = Path("raid/data/merged"),
    model_checkpoint_path: Optional[str] = None,
) -> None
For each repo in repos, train() adds the repository to the set processed so far and then:
  1. Calls database.export_merged_data(repos_so_far, data_path) to produce corpus.jsonl and the random/ split directory.
  2. Resolves the latest checkpoint via find_latest_checkpoint() if model_checkpoint_path is None.
  3. Loads the PremiseRetriever from the checkpoint (unfrozen) and sets the EWC lambda with model.set_lambda(config.lambda_value).
  4. Instantiates a RetrievalDataModule from the corpus and random-split paths.
  5. Calls pl.Trainer.fit(model, datamodule=data_module) to train.
The resulting checkpoint — saved by PyTorch Lightning’s ModelCheckpoint callback to the directory defined in CHECKPOINT_DIR — is the file you pass to RetrievalProver(ret_ckpt_path=...).
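The cumulative loop described above can be sketched without any LeanDojo dependencies. In this sketch the export and fit steps are passed in as plain callables, and the checkpoint returned by one round is fed into the next; both the progressive_train name and that hand-off are assumptions made for illustration, not the library's implementation.

```python
from typing import Callable, List, Optional


def progressive_train(
    repos: List[str],
    export_merged_data: Callable[[List[str]], None],
    fit: Callable[[Optional[str]], str],
    model_checkpoint_path: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical stand-in for the per-repository loop in train()."""
    repos_so_far: List[str] = []
    checkpoint = model_checkpoint_path
    for repo in repos:
        # Each round sees every repository processed up to this point.
        repos_so_far.append(repo)
        export_merged_data(list(repos_so_far))
        # Assumption: the checkpoint from this round seeds the next one.
        checkpoint = fit(checkpoint)
    return checkpoint


# Record what each round exports and fake the training step.
exports: List[List[str]] = []
final_ckpt = progressive_train(
    ["repoA", "repoB"],
    export_merged_data=exports.append,
    fit=lambda prev: f"{prev or 'base'}+1",
)
```

The recorded exports show why the merged corpus grows with each repository: the second round re-exports the first repository together with the second.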

evaluate Method

RetrievalTrainer.evaluate(dataset_path: Optional[str] = None) -> None
Loads the latest checkpoint, runs inference over every subdirectory under dataset_path, and logs R@1, R@10, and MRR retrieval metrics. Useful for offline evaluation after a training run completes.
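For readers unfamiliar with the metrics, here is a minimal sketch of recall at k and mean reciprocal rank computed over ranked premise lists. The function names and the averaging over queries are assumptions; the library's own metric code may normalise differently.

```python
from typing import List, Set, Tuple


def recall_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of the relevant premises that appear in the top k results."""
    if not relevant:
        return 0.0
    hits = len(set(retrieved[:k]) & relevant)
    return hits / len(relevant)


def reciprocal_rank(retrieved: List[str], relevant: Set[str]) -> float:
    """1 / rank of the first relevant premise, or 0.0 if none is retrieved."""
    for rank, premise in enumerate(retrieved, start=1):
        if premise in relevant:
            return 1.0 / rank
    return 0.0


# Toy queries: (ranked retrieval output, ground-truth premises).
queries: List[Tuple[List[str], Set[str]]] = [
    (["a", "b", "c"], {"b"}),
    (["x", "y", "z"], {"x", "z"}),
]

r_at_1 = sum(recall_at_k(r, rel, 1) for r, rel in queries) / len(queries)
r_at_10 = sum(recall_at_k(r, rel, 10) for r, rel in queries) / len(queries)
mrr = sum(reciprocal_rank(r, rel) for r, rel in queries) / len(queries)
```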

TrainingConfig Fields

TrainingConfig is a Python dataclass defined in lean_dojo_v2.lean_agent.config. All fields have defaults and can be overridden:

Training Mode

| Field | Type | Default | Description |
|---|---|---|---|
| run_progressive_training | bool | True | Enable EWC regularisation for progressive / lifelong training. |
| single_repo | bool | True | Process one repository at a time. |
| num_repos | int | 1 | Total number of repositories in the curriculum. |
| max_epochs | int | 1 | Maximum epochs per repository. |

Model

| Field | Type | Default | Description |
|---|---|---|---|
| model_name | str | "kaiyuy/leandojo-lean4-retriever-byt5-small" | HuggingFace identifier for the retriever backbone. |
| lr | float | 1e-3 | Peak learning rate. |
| warmup_steps | int | 1000 | Linear warmup steps. |
| max_seq_len | int | 512 | Max token length for the retriever encoder. |
| num_retrieved | int | 100 | Number of premises retrieved per query at inference time. |
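As an illustration of how lr and warmup_steps interact, the sketch below ramps the learning rate linearly from zero to the peak value. What happens after warmup is not specified here, so the sketch simply holds the peak rate; that choice and the warmup_lr name are assumptions.

```python
def warmup_lr(step: int, peak_lr: float = 1e-3, warmup_steps: int = 1000) -> float:
    """Linear warmup to peak_lr, then constant (the post-warmup part is assumed)."""
    if warmup_steps <= 0:
        return peak_lr
    return peak_lr * min(1.0, step / warmup_steps)


lr_start = warmup_lr(0)
lr_halfway = warmup_lr(500)
lr_after = warmup_lr(2000)
```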

Hardware & Training

| Field | Type | Default | Description |
|---|---|---|---|
| batch_size | int | (from constants) | Per-device training batch size. |
| eval_batch_size | int | 64 | Evaluation batch size. |
| accumulate_grad_batches | int | 4 | Gradient accumulation steps. |
| num_gpus | int | 1 | Number of GPUs (DDP). |
| num_workers | int | 4 | DataLoader worker processes. |
| precision | str | "bf16-mixed" | PyTorch Lightning mixed-precision mode. |
| gradient_clip_val | float | 1.0 | Gradient clipping threshold. |
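When tuning these fields it helps to know how many examples contribute to each optimizer step. Under data-parallel training this is usually the per-device batch size multiplied by the accumulation steps and the number of devices. The sketch below assumes that relationship and uses a hypothetical batch_size of 8, since the real default comes from the library's constants.

```python
def effective_batch_size(
    batch_size: int,
    accumulate_grad_batches: int = 4,
    num_gpus: int = 1,
) -> int:
    """Examples per optimizer step under data-parallel training."""
    return batch_size * accumulate_grad_batches * num_gpus


# batch_size=8 is a made-up value for illustration.
single_gpu = effective_batch_size(batch_size=8)
two_gpus = effective_batch_size(batch_size=8, num_gpus=2)
```

If you change num_gpus, you may want to revisit lr or accumulate_grad_batches so that the effective batch size stays comparable across runs.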

Data

| Field | Type | Default | Description |
|---|---|---|---|
| num_negatives | int | 3 | Hard negatives per positive premise pair. |
| num_in_file_negatives | int | 1 | In-file negatives per pair. |
| data_max_seq_len | int | 1024 | Maximum sequence length in the data module. |
| tokenizer_model_name | str | "google/byt5-small" | Tokenizer for the data module. |

EWC Regularisation

| Field | Type | Default | Description |
|---|---|---|---|
| lambda_value | float | 0.1 | EWC penalty weight. Set to 0.0 to disable. |
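To show what the penalty weight controls, here is a minimal sketch of an EWC-style penalty on plain Python lists. It uses the common form lambda * sum(F_i * (theta_i - theta_star_i)^2); some formulations include an extra factor of 1/2, and the exact scaling used by PremiseRetriever is not documented here, so treat the constants as assumptions.

```python
from typing import List


def ewc_penalty(
    params: List[float],
    anchor_params: List[float],
    fisher: List[float],
    lambda_value: float = 0.1,
) -> float:
    """Quadratic penalty pulling params toward the previous task's solution.

    fisher holds per-parameter importance estimates (diagonal Fisher information).
    """
    return lambda_value * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher)
    )


# Toy values: the first parameter moved, the second did not.
params = [1.0, 2.0]
anchor = [0.0, 2.0]
fisher = [0.5, 10.0]

penalty = ewc_penalty(params, anchor, fisher, lambda_value=0.1)
disabled = ewc_penalty(params, anchor, fisher, lambda_value=0.0)
total_loss = 0.8 + penalty  # 0.8 stands in for the retrieval loss
```

Parameters with high Fisher values are expensive to move, which is how the penalty protects what was learned on earlier repositories while leaving unimportant parameters free to adapt.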

Timeout

| Field | Type | Default | Description |
|---|---|---|---|
| timeout_seconds | int | 31449600 (~1 year) | NCCL and DDP operation timeout in seconds. |

Callbacks & Checkpointing

| Field | Type | Default | Description |
|---|---|---|---|
| early_stopping_patience | int | 5 | Early stopping patience (epochs). |
| early_stopping_monitor | str | "Recall@10_val" | Metric to monitor for early stopping. |
| early_stopping_mode | str | "max" | Optimisation direction for early stopping ("max" or "min"). |
| checkpoint_monitor | str | "Recall@10_val" | Metric to monitor for ModelCheckpoint. |
| checkpoint_mode | str | "max" | Optimisation direction for ModelCheckpoint ("max" or "min"). |
| checkpoint_save_top_k | int | -1 | Save all checkpoints (-1) or only the top-k. |
| checkpoint_every_n_epochs | int | 1 | Checkpoint frequency. |
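The patience and mode fields can be understood through a simplified model of early stopping: training stops once the monitored metric has gone a full patience window without a strict improvement. The should_stop helper below is a sketch of that idea over a list of per-epoch values, not the PyTorch Lightning callback itself.

```python
from typing import List


def should_stop(history: List[float], patience: int = 5, mode: str = "max") -> bool:
    """True when the best value is at least `patience` epochs old."""
    if not history:
        return False
    best = max(history) if mode == "max" else min(history)
    # index() returns the first occurrence, so ties do not count as improvement.
    best_epoch = history.index(best)
    epochs_since_best = len(history) - 1 - best_epoch
    return epochs_since_best >= patience


# Hypothetical Recall@10_val values per epoch.
still_improving = should_stop([0.30, 0.35, 0.40, 0.42])
plateaued = should_stop([0.30, 0.40, 0.40, 0.40, 0.40, 0.40, 0.40])
loss_style = should_stop([0.9, 0.5, 0.6, 0.7], patience=2, mode="min")
```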

Logging & Sanity Checks

| Field | Type | Default | Description |
|---|---|---|---|
| log_every_n_steps | int | 1 | How often (in steps) to log training metrics. |
| num_sanity_val_steps | int | 0 | Number of sanity-check validation batches before training starts. |

Seed

| Field | Type | Default | Description |
|---|---|---|---|
| seed | int | 3407 | Global random seed (passed to seed_everything). |

Loading Config from YAML or JSON

TrainingConfig supports serialisation to and from YAML and JSON files:
from lean_dojo_v2.lean_agent.config import TrainingConfig

# Load from YAML
config = TrainingConfig.from_yaml("config/retrieval.yaml")

# Load from JSON
config = TrainingConfig.from_json("config/retrieval.json")

# Save to YAML
config.to_yaml("config/retrieval_saved.yaml")
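If you want to see the round-trip pattern without installing the library, the sketch below reproduces it with a small stand-in dataclass and the standard json module. MiniConfig and its methods are hypothetical and only mirror the naming used above; they are not the TrainingConfig implementation.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class MiniConfig:
    """Hypothetical stand-in with three of the documented fields."""

    max_epochs: int = 1
    lr: float = 1e-3
    lambda_value: float = 0.1

    def to_json(self, path: Path) -> None:
        # Serialise all dataclass fields as a flat JSON object.
        Path(path).write_text(json.dumps(asdict(self), indent=2))

    @classmethod
    def from_json(cls, path: Path) -> "MiniConfig":
        # Unknown keys raise TypeError, which surfaces typos in config files.
        return cls(**json.loads(Path(path).read_text()))


original = MiniConfig(max_epochs=3, lr=5e-4)
with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "retrieval.json"
    original.to_json(config_path)
    restored = MiniConfig.from_json(config_path)
```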

How LeanAgent Uses RetrievalTrainer

LeanAgent (in lean_dojo_v2.agent.lean_agent) constructs a RetrievalTrainer with its own config and calls trainer.train(repos, database, data_path) on each training cycle. After training, it locates the best checkpoint (highest Recall@10_val) and passes its path to RetrievalProver(ret_ckpt_path=...):
from lean_dojo_v2.agent.lean_agent import LeanAgent

agent = LeanAgent()
agent.setup_github_repository(
    url="https://github.com/durant42040/lean4-example",
    commit="b14fef0ceca29a65bc3122bf730406b33c7effe5",
)
agent.train()   # internally calls RetrievalTrainer.train(...)
agent.prove()   # internally calls RetrievalProver with the saved checkpoint
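Selecting the best checkpoint by a monitored metric amounts to an argmax (or argmin) over recorded scores. The sketch below assumes the scores are already available as a mapping from checkpoint file name to Recall@10_val; the best_checkpoint helper, the file names, and the values are all made up for illustration.

```python
from typing import Dict


def best_checkpoint(scores: Dict[str, float], mode: str = "max") -> str:
    """Return the checkpoint name with the best monitored score."""
    if not scores:
        raise ValueError("no checkpoints to choose from")
    pick = max if mode == "max" else min
    return pick(scores, key=scores.get)


# Hypothetical Recall@10_val per saved checkpoint.
recall_by_ckpt = {
    "epoch=0.ckpt": 0.31,
    "epoch=1.ckpt": 0.42,
    "epoch=2.ckpt": 0.39,
}
best = best_checkpoint(recall_by_ckpt)
```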

Advanced: Manual Construction

If you need direct control — for example, to evaluate a pre-trained retriever on a custom corpus without the full agent — you can construct RetrievalTrainer directly:
from lean_dojo_v2.trainer.retrieval_trainer import RetrievalTrainer
from lean_dojo_v2.lean_agent.config import TrainingConfig

config = TrainingConfig(
    max_epochs=3,
    lr=5e-4,
    lambda_value=0.05,   # lighter EWC regularisation
    num_gpus=2,
)

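trainer = RetrievalTrainer(config=config)

# repos (List[Repository]), database (DynamicDatabase), and data_path (Path)
# must be prepared beforehand; they are not defined in this snippet.
trainer.train(repos=repos, database=database, data_path=data_path)

# Logs retrieval metrics for the latest checkpoint.
trainer.evaluate()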
RetrievalTrainer relies on a pre-existing PremiseRetriever checkpoint to initialise from. If model_checkpoint_path is not supplied, the checkpoint path is discovered automatically via find_latest_checkpoint(). Ensure that a valid checkpoint exists in CHECKPOINT_DIR before calling train(), or pass the path explicitly.

Set run_progressive_training=False and lambda_value=0.0 in TrainingConfig to disable EWC and train the retriever from a checkpoint without any regularisation penalty. This is useful when you have a single large repository and catastrophic forgetting is not a concern.
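Because train() needs an existing checkpoint, it can be worth checking for one up front. The sketch below picks the most recently modified .ckpt file under a directory. The latest_checkpoint helper, the .ckpt extension, and the use of modification time are assumptions; find_latest_checkpoint() may apply different criteria.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoint_dir: Path) -> Optional[Path]:
    """Hypothetical helper: newest *.ckpt file by modification time, or None."""
    candidates = list(Path(checkpoint_dir).rglob("*.ckpt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # An empty directory yields None, which a caller can turn into a clear error.
    empty_result = latest_checkpoint(root)

    for i, name in enumerate(["epoch0.ckpt", "epoch1.ckpt"]):
        ckpt = root / name
        ckpt.write_bytes(b"")
        # Set explicit timestamps so the ordering is deterministic.
        os.utime(ckpt, (1_000 + i, 1_000 + i))

    newest_name = latest_checkpoint(root).name
```

A guard such as checking for None and raising FileNotFoundError before constructing the trainer gives a clearer failure than an error raised deep inside the training loop.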
