
RetrievalProver implements retrieval-augmented generation (RAG) for Lean 4 proof search. At each step of the DFS loop it queries a RetrievalAugmentedGenerator (built on kaiyuy/leandojo-lean4-retriever-tacgen-byt5-small). The generator first retrieves the most relevant premises from a pre-indexed corpus, then generates tactics conditioned on both the goal and those premises. One tactic is sampled from the candidates with probability proportional to its softmax-normalized log-probability, so higher-confidence suggestions are more likely to be tried, while lower-ranked ones still have a chance.
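The generator returns log-probabilities, so the softmax step amounts to renormalizing the candidates' probabilities over the sampled set. This assumes each log_prob is the log of that candidate's sequence probability $p_i$:

```latex
\mathrm{softmax}(\ell)_i = \frac{e^{\ell_i}}{\sum_{j=1}^{k} e^{\ell_j}} = \frac{p_i}{\sum_{j=1}^{k} p_j},
\qquad \ell_i = \log p_i,\quad k = 10
```

For example, candidates with probabilities 0.5, 0.3 and 0.1 receive sampling weights 5/9, 3/9 and 1/9.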

Class

lean_dojo_v2.prover.retrieval_prover.RetrievalProver(BaseProver)
Exported from lean_dojo_v2.prover. Internally wraps lean_dojo_v2.lean_agent.generator.model.RetrievalAugmentedGenerator.

Constructor

RetrievalProver(
    ret_ckpt_path,
    gen_ckpt_path,
    indexed_corpus_path,
)
Loads the RAG model from gen_ckpt_path, injects the retriever checkpoint from ret_ckpt_path, loads the premise corpus from indexed_corpus_path, and rebuilds the retrieval index with batch_size=32. Construction is blocking and may take a minute or more on first run while the index is built.
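Building the index follows the general dense-retrieval pattern: embed every premise once, in fixed-size batches, then score each goal against the stored embeddings. The sketch below illustrates that pattern only. embed_batch is a hypothetical stand-in for the ByT5 retriever encoder that uses L2-normalized byte-count vectors. build_index and retrieve are illustrative names, not LeanDojo-v2 APIs.

```python
import numpy as np

DIM = 128  # hypothetical embedding size for the toy encoder


def embed_batch(texts: list[str]) -> np.ndarray:
    """Hypothetical encoder: L2-normalized byte-count vectors (stand-in for the real retriever)."""
    out = np.zeros((len(texts), DIM), dtype=np.float32)
    for row, text in enumerate(texts):
        for b in text.encode("utf-8"):
            out[row, b % DIM] += 1.0
    norms = np.linalg.norm(out, axis=1, keepdims=True)
    return out / np.maximum(norms, 1e-12)


def build_index(premises: list[str], batch_size: int = 32) -> np.ndarray:
    """Embed the corpus in batches of `batch_size` and stack the rows into one matrix."""
    chunks = [embed_batch(premises[i:i + batch_size]) for i in range(0, len(premises), batch_size)]
    return np.vstack(chunks)


def retrieve(goal: str, premises: list[str], index: np.ndarray, k: int = 10) -> list[str]:
    """Return the k premises whose embeddings have the highest cosine similarity to the goal."""
    scores = index @ embed_batch([goal])[0]
    top = np.argsort(-scores)[:k]
    return [premises[i] for i in top]


# Made-up corpus: 70 filler premises plus one relevant one (71 rows, 3 batches)
premises = [f"theorem lemma_{i} : {i} + 0 = {i}" for i in range(70)]
premises.append("theorem Nat.add_zero (n : Nat) : n + 0 = n")
index = build_index(premises)
print(index.shape)  # (71, 128)
top3 = retrieve("n + 0 = n", premises, index, k=3)
print(top3)
```

Embedding the corpus once and reusing it across DFS steps is what makes per-step retrieval cheap. It is also why construction, which rebuilds the index, is the slow part.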
ret_ckpt_path (str, required): Path to the retriever checkpoint produced by a RetrievalTrainer run (typically a .ckpt file such as retriever-epoch=5.ckpt). This checkpoint is injected into the generator’s retriever module.
gen_ckpt_path (str, required): Path to the combined generator Lightning checkpoint (e.g. model_lightning.ckpt). When LeanAgent builds the prover, it passes os.path.join(RAID_DIR, "model_lightning.ckpt").
indexed_corpus_path (str, required): Path to the corpus.jsonl file produced by DynamicDatabase.export_merged_data(). This file contains the premise strings that the retriever indexes and searches at inference time.
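Construction is slow, so validating the corpus file first can save a wasted wait. The sketch below assumes corpus.jsonl follows the original LeanDojo layout: one JSON object per line, each with a "premises" list. That schema is an assumption; check it against your own export. count_premises is a hypothetical helper.

```python
import json
import os
import tempfile


def count_premises(corpus_path: str) -> int:
    """Count premise entries, assuming each line looks like {"path": ..., "premises": [...]}."""
    if not os.path.isfile(corpus_path):
        raise FileNotFoundError(corpus_path)
    total = 0
    with open(corpus_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            premises = record.get("premises", [])
            if not isinstance(premises, list):
                raise ValueError(f"line {line_no}: 'premises' is not a list")
            total += len(premises)
    return total


# Demo on a tiny synthetic file (made-up records, not real corpus data)
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False, encoding="utf-8") as tmp:
    tmp.write(json.dumps({"path": "A.lean", "premises": [{"full_name": "a"}, {"full_name": "b"}]}) + "\n")
    tmp.write(json.dumps({"path": "B.lean", "premises": [{"full_name": "c"}]}) + "\n")
n = count_premises(tmp.name)
print(n)  # 3
os.remove(tmp.name)
```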

Hardcoded configuration

The following model hyperparameters are fixed at construction and cannot be overridden without subclassing:
Parameter             Value
model_name            "kaiyuy/leandojo-lean4-retriever-tacgen-byt5-small"
num_beams             5
eval_num_retrieved    10
max_inp_seq_len       512
max_oup_seq_len       128
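The eval_num_retrieved and max_inp_seq_len settings in the table above interact: up to 10 retrieved premises share a 512-unit input budget with the goal. The sketch below shows one plausible packing scheme under stated assumptions, not the library's code. It assumes the budget is counted in UTF-8 bytes (ByT5 is byte-level), that premises are prepended, and that the goal is always kept. pack_input is a hypothetical name.

```python
MAX_INP_SEQ_LEN = 512    # value from the table (assumed to count bytes)
EVAL_NUM_RETRIEVED = 10  # value from the table


def pack_input(goal: str, premises: list[str], max_len: int = MAX_INP_SEQ_LEN) -> str:
    """Hypothetical packing: prepend ranked premises while the byte budget allows.

    The goal is always kept; if the goal alone exceeds max_len, it is returned
    unchanged and any truncation is left to the tokenizer.
    """
    budget = max_len - len(goal.encode("utf-8"))
    used = 0
    packed = ""
    for premise in premises[:EVAL_NUM_RETRIEVED]:
        chunk = premise + "\n\n"
        size = len(chunk.encode("utf-8"))
        if used + size > budget:
            continue  # skip premises that no longer fit, but keep trying shorter ones
        used += size
        packed = chunk + packed  # prepending leaves the top-ranked premise nearest the goal
    return packed + goal


goal = "n : Nat\n⊢ n + 0 = n"
ranked = [
    "theorem Nat.add_zero (n : Nat) : n + 0 = n",
    "x" * 600,  # made-up oversized premise that cannot fit
    "theorem Nat.zero_add (n : Nat) : 0 + n = n",
]
text = pack_input(goal, ranked)
print(text)
```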
For most workflows, use LeanAgent rather than constructing RetrievalProver directly. LeanAgent._setup_prover() handles checkpoint path resolution, device selection, and corpus wiring automatically.

Methods

next_tactic

def next_tactic(state: GoalState, goal_id: int) -> Optional[Tactic]
Generates a single tactic for the given goal using the RAG model. Called automatically by BaseProver.search() on every DFS step. Generation process:
  1. Calls tactic_generator.generate(state=..., file_path=..., theorem_full_name=..., theorem_pos=..., num_samples=10) to produce 10 (tactic, log_prob) pairs.
  2. Computes probs = softmax(log_probs) using NumPy.
  3. Samples one tactic using random.choices(..., weights=probs, k=1).
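Steps 2 and 3 can be sketched with NumPy and the standard library's random.choices. sample_tactic is a hypothetical name, and the candidate list is made up. Subtracting the maximum before exponentiating is a standard overflow guard that does not change the weights. Whether the real implementation does this is not stated here.

```python
import random

import numpy as np


def softmax(log_probs: list[float]) -> np.ndarray:
    """Convert log-probabilities to normalized sampling weights."""
    x = np.asarray(log_probs, dtype=np.float64)
    e = np.exp(x - x.max())  # max-subtraction avoids overflow; result is unchanged
    return e / e.sum()


def sample_tactic(candidates: list[tuple[str, float]], rng: random.Random) -> str:
    """Draw one tactic with probability proportional to softmax(log_probs)."""
    tactics = [tactic for tactic, _ in candidates]
    weights = softmax([log_prob for _, log_prob in candidates])
    return rng.choices(tactics, weights=weights.tolist(), k=1)[0]


# Made-up (tactic, log_prob) pairs standing in for tactic_generator.generate(...) output
candidates = [("simp", -0.2), ("exact Nat.add_zero n", -0.9), ("rfl", -1.1), ("omega", -2.3)]
weights = softmax([lp for _, lp in candidates])
print(np.round(weights, 3))

rng = random.Random(0)  # seeded only so repeated runs are reproducible
picks = [sample_tactic(candidates, rng) for _ in range(1000)]
print({t: picks.count(t) for t, _ in candidates})
```

Sampling rather than always taking the argmax means a DFS branch that revisits the same goal can try a different tactic instead of repeating the same failing one.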
state (GoalState, required): The Pantograph GoalState at the current DFS node. Converted to a string and passed to the generator as the state argument.
goal_id (int, required): Index of the active goal in state.goals. Present for interface compatibility; the generator conditions on the full goal-state string rather than on a single goal.
Returns Optional[Tactic]: A single tactic string sampled from the 10 generated candidates, or None if self.theorem has not been set.

generate_whole_proof

def generate_whole_proof(theorem: Theorem) -> str
generate_whole_proof() is not implemented in RetrievalProver and raises NotImplementedError when called. Only proof search via search() is supported. Use HFProver or ExternalProver if you need whole-proof generation.
theorem (Theorem, required): Not used; calling this method always raises NotImplementedError.
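Code that accepts any prover can treat NotImplementedError as a signal to fall back to tactic search. The sketch below is minimal. prove_with_fallback, fake_search, and both stub classes are hypothetical and only mimic the behavior described on this page. fake_search stands in for a real prover.search(...) call.

```python
from typing import Callable


class StubRetrievalProver:
    """Mimics RetrievalProver: whole-proof generation is unsupported."""

    def generate_whole_proof(self, theorem: str) -> str:
        raise NotImplementedError


class StubWholeProofProver:
    """Mimics a prover that can emit a complete proof (e.g. HFProver)."""

    def generate_whole_proof(self, theorem: str) -> str:
        return "by simp"


def prove_with_fallback(prover, theorem: str, run_search: Callable[[object, str], str]) -> tuple[str, str]:
    """Return (mode, proof): try whole-proof generation, else fall back to search."""
    try:
        return "whole_proof", prover.generate_whole_proof(theorem)
    except NotImplementedError:
        return "search", run_search(prover, theorem)


def fake_search(prover, theorem: str) -> str:
    # Placeholder for prover.search(...); returns a canned tactic script
    return "intro n\nsimp"


r1 = prove_with_fallback(StubRetrievalProver(), "∀ n : Nat, n + 0 = n", fake_search)
r2 = prove_with_fallback(StubWholeProofProver(), "∀ n : Nat, n + 0 = n", fake_search)
print(r1)  # ('search', 'intro n\nsimp')
print(r2)  # ('whole_proof', 'by simp')
```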

Example

The following shows the typical construction pattern used inside LeanAgent._setup_prover():
```python
from lean_dojo_v2.prover.retrieval_prover import RetrievalProver

# Example paths; construction blocks while the premise index is rebuilt
prover = RetrievalProver(
    ret_ckpt_path="raid/checkpoints/retriever-epoch=5.ckpt",
    gen_ckpt_path="raid/model_lightning.ckpt",
    indexed_corpus_path="raid/data/merged/corpus.jsonl",
)
```
After construction, prover is a fully initialized BaseProver and can be passed directly to search():
```python
from pantograph.server import Server

# Start a Pantograph server that can import Mathlib from the given project
server = Server(imports=["Mathlib"], project_path="/path/to/mathlib4")

# Run DFS proof search; next_tactic() is called at each step
result, tactics = prover.search(
    server=server,
    goal="⊢ ∀ (n : Nat), n + 0 = n",
    verbose=True,
)

if result.success:
    print(f"Proved in {result.steps} steps:")
    for tactic in tactics:
        print(" ", tactic)
else:
    print("Search failed.")
```

Device Handling

RetrievalProver automatically selects cuda if a GPU is available and otherwise falls back to cpu. There is no device constructor argument. To force CPU execution, set CUDA_VISIBLE_DEVICES="" in the environment before torch initializes CUDA, or patch torch.cuda.is_available() to return False.
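The sketch below forces CPU from Python itself. The environment variable must be set before torch initializes CUDA, so it is set here before torch (or anything that imports it, such as lean_dojo_v2) is imported. The cuda-or-cpu expression mirrors the selection rule described above but is not copied from the library. The ImportError branch exists only so the sketch runs without torch installed.

```python
import os

# Hide all GPUs from CUDA; must happen before torch is imported
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def pick_device() -> str:
    """Mirror the 'cuda if available, else cpu' rule described above."""
    try:
        import torch
    except ImportError:
        return "cpu"  # no torch installed: CPU is the only option
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(device)  # cpu
```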

See Also

  • BaseProver — search loop and abstract interface
  • HFProver — local HuggingFace model with whole-proof support
  • ExternalProver — HF Inference API tactic generator
