Reward functions are the core signal that drives Group Relative Policy Optimization (GRPO). During each training step,
GRPOTrainer generates num_generations completions per prompt, calls every registered reward function with the full batch, and normalises the scores within each group to compute relative advantages before updating the policy. This project implements reward functions as subclasses of BaseReward — a thin abstract class that satisfies TRL’s callable interface while giving every reward a stable name for logging.
The BaseReward abstract class
BaseReward lives in src/llm_finetuning/core/reward.py and is the only class you need to subclass to add a new reward.
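A minimal sketch of the interface, assuming only what is described on this page (the actual class in src/llm_finetuning/core/reward.py may differ in detail):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class RewardConfig:
    name: str  # used as __name__ when TRL logs reward scalars


class BaseReward(ABC):
    def __init__(self, config: RewardConfig) -> None:
        self.config = config
        # TRL reads reward_func.__name__ for logging, so expose the configured name.
        self.__name__ = config.name

    @abstractmethod
    def __call__(self, prompts, completions, **kwargs) -> list[float]:
        """Return one score per completion, in completion order."""
```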
RewardConfig
RewardConfig is a frozen dataclass with a single field: name. TRL uses reward.__name__ when logging reward scalars to Weights & Biases or TensorBoard. Setting a meaningful name (e.g. "answer_correctness") keeps training runs readable.
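For example, a concrete reward can set its logging name in its constructor. The class name below mirrors the AnswerCorrectnessReward instance used later on this page, but the scoring logic is illustrative only, not the project's actual implementation:

```python
class AnswerCorrectnessReward(BaseReward):
    def __init__(self) -> None:
        super().__init__(RewardConfig(name="answer_correctness"))

    def __call__(self, prompts, completions, **kwargs) -> list[float]:
        # kwargs["answer"] holds the ground-truth answers forwarded from the dataset.
        answers = kwargs["answer"]
        return [
            1.0 if str(answers[i]).strip() in completions[i][0]["content"] else 0.0
            for i in range(len(completions))
        ]
```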
Passing rewards to GRPOTrainer
Pass a list of BaseReward instances directly to the reward_funcs argument. TRL accepts any callable, and BaseReward instances satisfy that contract.
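A sketch of the wiring, assuming TRL's standard GRPOTrainer and GRPOConfig API; the model name and dataset are placeholders, and the reward classes are the examples sketched elsewhere on this page:

```python
from trl import GRPOConfig, GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",            # placeholder base model
    args=GRPOConfig(output_dir="outputs", num_generations=8),
    train_dataset=train_dataset,                    # placeholder dataset with an "answer" column
    reward_funcs=[
        AnswerCorrectnessReward(),                  # instances are plain callables
        ReasoningTagFormatReward(),                 # see the format reward example below
    ],
)
trainer.train()
```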
How GRPO uses reward scores
For each prompt in the batch, GRPOTrainer generates num_generations completions (controlled by the num_generations hyperparameter in config.yaml). All rewards are called once with the full batch:
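Conceptually, each reward sees the whole flattened batch at once. This is a simplified sketch of the call the trainer makes, not TRL's literal internal code:

```python
# prompts and completions both have length num_prompts_in_batch * num_generations.
scores = reward(
    prompts=prompts,
    completions=completions,
    **extra_dataset_columns,   # e.g. answer=[...], one entry per completion
)
assert len(scores) == len(completions)
```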
Scores are then normalised within each group (the num_generations completions that share a prompt) to produce relative advantages. A completion that scores above the group mean receives a positive advantage; one that scores below the mean receives a negative advantage. The policy is updated to increase the probability of high-advantage completions.
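In simplified form, the per-group normalisation looks like the following; TRL's exact epsilon and standard-deviation handling may differ:

```python
group_mean = sum(group_scores) / len(group_scores)
group_std = (sum((s - group_mean) ** 2 for s in group_scores) / len(group_scores)) ** 0.5
# Above-mean completions get positive advantages, below-mean completions negative ones.
advantages = [(s - group_mean) / (group_std + 1e-4) for s in group_scores]
```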
The __call__ signature
Every reward must implement this exact signature:
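A sketch of that signature, with type hints inferred from the parameter descriptions below:

```python
def __call__(
    self,
    prompts: list[list[dict[str, str]]],
    completions: list[list[dict[str, str]]],
    **kwargs,
) -> list[float]:
    ...
```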
- prompts: Batch of prompt histories. Each element is a list of chat message dicts (e.g. [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]) representing the full conversation leading up to the assistant turn.
- completions: Batch of generated assistant message lists, aligned one-to-one with prompts. Each element is a list containing a single dict: [{"role": "assistant", "content": "..."}]. Access the response text with completion[0]["content"].
- **kwargs: Extra per-row dataset columns forwarded by GRPOTrainer. For example, if the dataset has an answer column, it is available as kwargs["answer"] — a list of ground-truth values, one per completion in the batch.
- Returns: list[float] — one score per completion, in the same order as completions.
completions structure
Each inner list always contains exactly one assistant message dict:
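For example (the content string is illustrative):

```python
completion = completions[i]          # one batch element
# -> [{"role": "assistant", "content": "<reasoning>...</reasoning><answer>42</answer>"}]
text = completion[0]["content"]      # the generated response text
```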
The as_fn() method
as_fn() wraps the reward instance in a plain function with the same signature and __name__:
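Usage might look like this, assuming the answer_correctness reward sketched earlier and placeholder batch variables:

```python
reward = AnswerCorrectnessReward()
reward_fn = reward.as_fn()    # plain function with the same signature and __name__

print(reward_fn.__name__)     # "answer_correctness"
scores = reward_fn(prompts=prompts, completions=completions, answer=answers)
```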
Use as_fn() when:
- A framework requires a picklable callable (e.g. multiprocessing-based evaluation harnesses that pickle reward functions before dispatching to worker processes).
- A framework performs a strict type check for callable or FunctionType and rejects class instances.
For ordinary GRPOTrainer usage, passing the instance directly (AnswerCorrectnessReward()) is preferred.
Implementing a custom format reward
The following example adds a format reward that checks for the presence and correct nesting of <reasoning> and <answer> XML tags. It is adapted from the production implementation in math_reasoning/reward_functions/format/reasoning_tags.py.
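A minimal sketch in the spirit of that implementation; the exact class name and scoring scheme in the repository may differ:

```python
import re


class ReasoningTagFormatReward(BaseReward):
    # Full credit only when <reasoning>...</reasoning> is followed by <answer>...</answer>.
    _PATTERN = re.compile(
        r"^\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>\s*$", re.DOTALL
    )

    def __init__(self) -> None:
        super().__init__(RewardConfig(name="reasoning_tag_format"))

    def __call__(self, prompts, completions, **kwargs) -> list[float]:
        scores = []
        for completion in completions:
            text = completion[0]["content"]
            scores.append(1.0 if self._PATTERN.match(text) else 0.0)
        return scores
```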
Implementing a DeepEval correctness reward
Use this pattern when you want an LLM judge to score completions. Requires deepeval and OPENAI_API_KEY.
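A sketch using DeepEval's GEval metric as the judge; the class name, criteria string, and prompt handling are illustrative, not the repository's exact implementation:

```python
from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams


class DeepEvalCorrectnessReward(BaseReward):
    def __init__(self) -> None:
        super().__init__(RewardConfig(name="deepeval_correctness"))
        self._metric = GEval(
            name="Correctness",
            criteria="Judge whether the actual output reaches the same result as the expected output.",
            evaluation_params=[
                LLMTestCaseParams.INPUT,
                LLMTestCaseParams.ACTUAL_OUTPUT,
                LLMTestCaseParams.EXPECTED_OUTPUT,
            ],
        )

    def __call__(self, prompts, completions, **kwargs) -> list[float]:
        answers = kwargs["answer"]
        scores = []
        for prompt, completion, answer in zip(prompts, completions, answers):
            test_case = LLMTestCase(
                input=prompt[-1]["content"],          # last message of the prompt (typically the user turn)
                actual_output=completion[0]["content"],
                expected_output=str(answer),
            )
            self._metric.measure(test_case)           # calls the judge model via OPENAI_API_KEY
            scores.append(float(self._metric.score))  # GEval scores fall in [0, 1]
        return scores
```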
The built-in AbstractDeepEvalGEvalRAGReward base class (in llm_finetuning.core.llm_judges.deepeval) provides automatic retry on rate-limit errors and bounded concurrency via asyncio.Semaphore. For production use, prefer subclassing it over the manual approach shown above.

Common pitfalls
Not returning list[float]
GRPOTrainer expects exactly list[float] with one entry per completion.
Returning a generator, a NumPy array, or a list of wrong length will raise a
runtime error during training. Always construct a plain Python list and verify
its length equals len(completions) before returning.
Wrong completions indexing
Each element of
completions is a list with one dict. A common mistake is
treating it as a flat dict or indexing into prompts instead.
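A quick contrast of correct and incorrect access:

```python
text = completions[i][0]["content"]       # correct: inner list holds one assistant message dict
text = completions[i]["content"]          # wrong: the element is a list, so this raises TypeError
prompt_text = prompts[i][-1]["content"]   # this is the prompt, not the completion
```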
Mutating kwargs in-place
**kwargs values are shared across all reward functions in the same batch.
Modifying a list in kwargs (e.g. kwargs["answer"].pop()) will corrupt the
data seen by subsequent reward functions. Always read from kwargs without
mutating it.
Relying on completion order within a group
GRPO normalises scores within each group of
num_generations completions.
Your reward should score each completion independently — do not rank or sort
completions inside __call__, as GRPO handles the relative comparison itself.