RetrievalTrainer trains the dense retrieval model at the heart of LeanDojo v2’s lifelong learning pipeline. Given a merged corpus of traced theorems and premises, it fine-tunes a PremiseRetriever — a bi-encoder that scores how relevant each premise in the corpus is to the current proof state — using PyTorch Lightning and an Elastic Weight Consolidation (EWC) regulariser to prevent catastrophic forgetting as new repositories are added. In practice, LeanAgent constructs and calls this trainer automatically; you only need to interact with it directly when building a custom retrieval pipeline.
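To make the bi-encoder idea concrete, here is a minimal sketch of how a proof-state embedding could be scored against premise embeddings and ranked. It assumes cosine similarity as the scoring function and uses hypothetical 3-dimensional vectors and hand-picked premise names; the actual PremiseRetriever encodes text with a learned encoder, and its scoring may differ.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_premises(
    state_vec: Sequence[float],
    premise_vecs: Dict[str, Sequence[float]],
    k: int,
) -> List[Tuple[str, float]]:
    """Score every premise against the proof state and keep the top k."""
    scored = [(name, cosine(state_vec, vec)) for name, vec in premise_vecs.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical embeddings for illustration only.
state = [1.0, 0.0, 1.0]
premises = {
    "Nat.add_comm": [1.0, 0.0, 1.0],
    "Nat.mul_comm": [0.0, 1.0, 0.0],
    "Nat.add_assoc": [1.0, 0.0, 0.0],
}
top = rank_premises(state, premises, k=2)
```

In this toy setup, k plays the role that num_retrieved plays at inference time: it caps how many premises are handed on to the prover.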
Class
lean_dojo_v2.trainer.retrieval_trainer.RetrievalTrainer
Constructor Parameters
config (Optional[TrainingConfig], default: TrainingConfig())
A lean_dojo_v2.lean_agent.config.TrainingConfig dataclass instance that
controls every aspect of retrieval training. When None, a default
TrainingConfig() is used. See TrainingConfig Fields
below for the full list of options.
For most use cases you should use LeanAgent directly rather than
constructing RetrievalTrainer manually. LeanAgent creates the trainer,
wires the database, and calls train() in the right order as part of the
progressive learning loop.
train Method
RetrievalTrainer.train(
    repos: List[Repository],
    database: DynamicDatabase,
    data_path: Path = Path("raid/data/merged"),
    model_checkpoint_path: Optional[str] = None,
) -> None
For each repo in repos (processed cumulatively):
- Calls database.export_merged_data(repos_so_far, data_path) to produce corpus.jsonl and the random/ split directory.
- Resolves the latest checkpoint via find_latest_checkpoint() if model_checkpoint_path is None.
- Loads the PremiseRetriever from the checkpoint (unfrozen) and sets the EWC lambda with model.set_lambda(config.lambda_value).
- Instantiates a RetrievalDataModule from the corpus and random-split paths.
- Calls pl.Trainer.fit(model, datamodule=data_module) to train.
The resulting checkpoint — saved by PyTorch Lightning’s ModelCheckpoint
callback to the directory defined in CHECKPOINT_DIR — is the file you pass to
RetrievalProver(ret_ckpt_path=...).
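The cumulative processing described above can be pictured with a small sketch. It reflects one reading of "processed cumulatively", namely that each step sees all repositories up to and including the current one. The helper cumulative_prefixes and the string repository names are hypothetical placeholders, not part of the library.

```python
from typing import List


def cumulative_prefixes(repos: List[str]) -> List[List[str]]:
    """Return the growing list of repositories seen at each training step."""
    return [repos[: i + 1] for i in range(len(repos))]


# Placeholder names standing in for Repository objects.
repos = ["repo_a", "repo_b", "repo_c"]

for step, repos_so_far in enumerate(cumulative_prefixes(repos), start=1):
    # In the real trainer, this is where the merged export, checkpoint
    # lookup, model load, data module setup, and fit would take place.
    print(f"step {step}: export and train on {repos_so_far}")
```

Under this reading, the merged corpus grows with every step, which is why a regulariser against forgetting earlier repositories becomes relevant.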
evaluate Method
RetrievalTrainer.evaluate(dataset_path: Optional[str] = None) -> None
Loads the latest checkpoint, runs inference over every subdirectory under
dataset_path, and logs R@1, R@10, and MRR retrieval metrics. Useful for
offline evaluation after a training run completes.
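For readers unfamiliar with the logged metrics, the sketch below computes Recall@k and MRR under one common definition: Recall@k is the fraction of ground-truth premises found in the top k results, and MRR averages the reciprocal rank of the first relevant result. The function names and toy data are hypothetical, and the library's exact computation may differ.

```python
from typing import List, Sequence, Set, Tuple


def recall_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of relevant premises that appear in the top-k retrieved list."""
    if not relevant:
        return 0.0
    hits = len(set(retrieved[:k]) & relevant)
    return hits / len(relevant)


def reciprocal_rank(retrieved: Sequence[str], relevant: Set[str]) -> float:
    """1 / rank of the first relevant premise, or 0.0 if none is retrieved."""
    for rank, premise in enumerate(retrieved, start=1):
        if premise in relevant:
            return 1.0 / rank
    return 0.0


def mean(values: List[float]) -> float:
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


# Each entry pairs a ranked retrieval list with its ground-truth premises.
queries: List[Tuple[List[str], Set[str]]] = [
    (["p1", "p2", "p3"], {"p1", "p3"}),
    (["p4", "p5", "p6"], {"p5"}),
]

recall_1 = mean([recall_at_k(r, g, 1) for r, g in queries])
recall_3 = mean([recall_at_k(r, g, 3) for r, g in queries])
mrr = mean([reciprocal_rank(r, g) for r, g in queries])
```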
TrainingConfig Fields
TrainingConfig is a Python dataclass defined in
lean_dojo_v2.lean_agent.config. All fields have defaults and can be overridden:
Training Mode
| Field | Type | Default | Description |
|---|---|---|---|
| run_progressive_training | bool | True | Enable EWC regularisation for progressive / lifelong training. |
| single_repo | bool | True | Process one repository at a time. |
| num_repos | int | 1 | Total number of repositories in the curriculum. |
| max_epochs | int | 1 | Maximum epochs per repository. |
Model
| Field | Type | Default | Description |
|---|---|---|---|
| model_name | str | "kaiyuy/leandojo-lean4-retriever-byt5-small" | HuggingFace identifier for the retriever backbone. |
| lr | float | 1e-3 | Peak learning rate. |
| warmup_steps | int | 1000 | Linear warmup steps. |
| max_seq_len | int | 512 | Max token length for the retriever encoder. |
| num_retrieved | int | 100 | Number of premises retrieved per query at inference time. |
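To show how lr and warmup_steps interact, here is a minimal sketch of a linear warmup ramp using the documented defaults. The function warmup_lr is hypothetical, and it simply holds the peak rate after warmup; whatever decay the trainer applies afterwards is not modelled here.

```python
def warmup_lr(step: int, peak_lr: float = 1e-3, warmup_steps: int = 1000) -> float:
    """Learning rate during a linear warmup ramp; held at peak_lr afterwards."""
    if warmup_steps <= 0:
        return peak_lr
    # Scale linearly from 0 at step 0 up to peak_lr at warmup_steps.
    return peak_lr * min(1.0, step / warmup_steps)


# Sample the schedule at a few optimiser steps.
schedule = {step: warmup_lr(step) for step in (0, 250, 500, 1000, 2000)}
```

With the defaults, the rate is half of its peak at step 500 and reaches 1e-3 at step 1000.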
Hardware & Training
| Field | Type | Default | Description |
|---|---|---|---|
| batch_size | int | (from constants) | Per-device training batch size. |
| eval_batch_size | int | 64 | Evaluation batch size. |
| accumulate_grad_batches | int | 4 | Gradient accumulation steps. |
| num_gpus | int | 1 | Number of GPUs (DDP). |
| num_workers | int | 4 | DataLoader worker processes. |
| precision | str | "bf16-mixed" | PyTorch Lightning mixed-precision mode. |
| gradient_clip_val | float | 1.0 | Gradient clipping threshold. |
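When tuning these fields, it can help to compute the effective batch size per optimiser step. The sketch below uses the usual relationship for data-parallel training with gradient accumulation. The helper name is hypothetical, and batch_size=8 is an illustrative value, since the documented default comes from the library's constants.

```python
def effective_batch_size(
    batch_size: int,
    accumulate_grad_batches: int = 4,
    num_gpus: int = 1,
) -> int:
    """Examples contributing to one optimiser step across all devices."""
    return batch_size * accumulate_grad_batches * num_gpus


# Hypothetical per-device batch size of 8.
single_gpu = effective_batch_size(batch_size=8)
two_gpus = effective_batch_size(batch_size=8, num_gpus=2)
```

Doubling num_gpus doubles the effective batch size, so you may want to revisit lr or accumulate_grad_batches when changing the device count.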
Data
| Field | Type | Default | Description |
|---|---|---|---|
| num_negatives | int | 3 | Hard negatives per positive premise pair. |
| num_in_file_negatives | int | 1 | In-file negatives per pair. |
| data_max_seq_len | int | 1024 | Maximum sequence length in the data module. |
| tokenizer_model_name | str | "google/byt5-small" | Tokenizer for the data module. |
EWC Regularisation
| Field | Type | Default | Description |
|---|---|---|---|
| lambda_value | float | 0.1 | EWC penalty weight. Set to 0.0 to disable. |
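As a rough illustration of what lambda_value scales, the sketch below computes the textbook EWC penalty, (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2, over plain Python lists. The function name, the toy parameter values, and the factor of one half are assumptions taken from the standard formulation; the library's internal scaling may differ.

```python
from typing import Sequence


def ewc_penalty(
    params: Sequence[float],
    anchor_params: Sequence[float],
    fisher: Sequence[float],
    lambda_value: float = 0.1,
) -> float:
    """Textbook EWC term: (lambda / 2) * sum_i F_i * (theta_i - theta_star_i)^2."""
    quadratic = sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher)
    )
    return 0.5 * lambda_value * quadratic


# Toy values: current weights, weights after the previous task, and
# per-parameter Fisher information estimates.
params = [1.0, 2.0]
anchor = [0.0, 2.0]
fisher = [4.0, 1.0]

penalty = ewc_penalty(params, anchor, fisher, lambda_value=0.1)
disabled = ewc_penalty(params, anchor, fisher, lambda_value=0.0)
```

Only the first parameter has moved from its anchor, and it carries a high Fisher weight, so it dominates the penalty. Setting lambda_value to 0.0 removes the term entirely.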
Timeout
| Field | Type | Default | Description |
|---|---|---|---|
| timeout_seconds | int | 31449600 (~1 year) | NCCL and DDP operation timeout in seconds. |
Callbacks & Checkpointing
| Field | Type | Default | Description |
|---|---|---|---|
| early_stopping_patience | int | 5 | Early stopping patience (epochs). |
| early_stopping_monitor | str | "Recall@10_val" | Metric to monitor for early stopping. |
| early_stopping_mode | str | "max" | Optimisation direction for early stopping ("max" or "min"). |
| checkpoint_monitor | str | "Recall@10_val" | Metric to monitor for ModelCheckpoint. |
| checkpoint_mode | str | "max" | Optimisation direction for ModelCheckpoint ("max" or "min"). |
| checkpoint_save_top_k | int | -1 | Save all checkpoints (-1) or only the top-k. |
| checkpoint_every_n_epochs | int | 1 | Checkpoint frequency. |
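To illustrate how patience and mode interact, here is a simplified, self-contained sketch of patience-based early stopping. It is not PyTorch Lightning's implementation (which has further options such as a minimum improvement threshold); the function name and the metric values are hypothetical.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    scores: Sequence[float],
    patience: int = 5,
    mode: str = "max",
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None."""
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best: Optional[float] = None
    wait = 0  # consecutive epochs without improvement
    for epoch, score in enumerate(scores):
        improved = best is None or (score > best if mode == "max" else score < best)
        if improved:
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation Recall@10 values, one per epoch.
recall_per_epoch = [0.30, 0.40, 0.40, 0.39, 0.38]
stop_with_patience_2 = early_stop_epoch(recall_per_epoch, patience=2)
stop_with_patience_5 = early_stop_epoch(recall_per_epoch, patience=5)
```

Note that with max_epochs at its default of 1, a patience of 5 epochs would never be reached within a single repository's training run.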
Logging & Sanity Checks
| Field | Type | Default | Description |
|---|---|---|---|
| log_every_n_steps | int | 1 | How often (in steps) to log training metrics. |
| num_sanity_val_steps | int | 0 | Number of sanity-check validation batches before training starts. |
Seed
| Field | Type | Default | Description |
|---|---|---|---|
| seed | int | 3407 | Global random seed (passed to seed_everything). |
Loading Config from YAML or JSON
TrainingConfig supports serialisation to and from YAML and JSON files:
from lean_dojo_v2.lean_agent.config import TrainingConfig

# Load from YAML
config = TrainingConfig.from_yaml("config/retrieval.yaml")

# Load from JSON
config = TrainingConfig.from_json("config/retrieval.json")

# Save to YAML
config.to_yaml("config/retrieval_saved.yaml")
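If you want to generate a config file programmatically, a sketch like the following writes a handful of overrides to JSON using only the standard library. It assumes that from_json accepts a flat JSON object keyed by field name, which the source does not confirm; the temporary directory and file name are illustrative.

```python
import json
import tempfile
from pathlib import Path

# Overrides keyed by TrainingConfig field names documented above.
overrides = {
    "max_epochs": 3,
    "lr": 5e-4,
    "lambda_value": 0.05,
    "num_gpus": 2,
}

# Write the overrides to a JSON file in a temporary directory.
config_dir = Path(tempfile.mkdtemp())
config_path = config_dir / "retrieval.json"
config_path.write_text(json.dumps(overrides, indent=2))

# Read the file back to confirm it round-trips cleanly.
loaded = json.loads(config_path.read_text())

# With lean_dojo_v2 installed, and assuming a flat layout is accepted,
# the file could then be loaded like this:
# config = TrainingConfig.from_json(str(config_path))
```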
How LeanAgent Uses RetrievalTrainer
LeanAgent (in lean_dojo_v2.agent.lean_agent) constructs a
RetrievalTrainer with its own config and calls trainer.train(repos, database, data_path) on each training cycle. After training, it locates the best
checkpoint (highest Recall@10_val) and passes its path to
RetrievalProver(ret_ckpt_path=...):
from lean_dojo_v2.agent.lean_agent import LeanAgent

agent = LeanAgent()
agent.setup_github_repository(
    url="https://github.com/durant42040/lean4-example",
    commit="b14fef0ceca29a65bc3122bf730406b33c7effe5",
)
agent.train()  # internally calls RetrievalTrainer.train(...)
agent.prove()  # internally calls RetrievalProver with the saved checkpoint
Advanced: Manual Construction
If you need direct control — for example, to evaluate a pre-trained retriever on
a custom corpus without the full agent — you can construct RetrievalTrainer
directly:
from lean_dojo_v2.trainer.retrieval_trainer import RetrievalTrainer
from lean_dojo_v2.lean_agent.config import TrainingConfig

config = TrainingConfig(
    max_epochs=3,
    lr=5e-4,
    lambda_value=0.05,  # lighter EWC regularisation
    num_gpus=2,
)
trainer = RetrievalTrainer(config=config)

# repos (List[Repository]), database (DynamicDatabase), and data_path (Path)
# are assumed to have been created earlier in your pipeline.
trainer.train(repos=repos, database=database, data_path=data_path)
trainer.evaluate()
RetrievalTrainer relies on a pre-existing PremiseRetriever checkpoint to
initialise from. The checkpoint path is discovered automatically via
find_latest_checkpoint() if model_checkpoint_path is not supplied. Ensure
that a valid checkpoint exists in CHECKPOINT_DIR before calling train(),
or pass the path explicitly.
Set run_progressive_training=False and lambda_value=0.0 in TrainingConfig
to disable EWC and train the retriever from a checkpoint without any
regularisation penalty — useful when you have a single large repository and
catastrophic forgetting is not a concern.