Traditional reward models produce a single scalar score shared across all users, which loses the diversity of human preferences. LoRe replaces that scalar with a low-rank factorization — a matrix V of shared reward directions and a per-user weight vector w_i — enabling personalized reward signals while keeping the parameter count small.

The core factorization

A conventional Bradley-Terry reward model assigns a scalar reward to a response embedding x:
r(x) = x · v
Here v is a single direction in feature space, learned from aggregated preference data and shared by everyone. Any user-level variation is discarded. LoRe generalizes this to a rank-K factorization:
r_i(x) = x · V · w_i
where:
  • x is the response feature vector (shape [features])
  • V is the shared basis matrix (shape [features × K])
  • w_i is user i’s mixture weight vector (shape [K])
  • K is the rank — the number of independent reward directions
The matrix V captures the K most important axes of reward variation across the user population. Each user’s weight vector w_i then selects a mixture over those axes, personalizing their reward without training a separate model.
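The factorization above can be sketched in plain Python with small lists. The function name `personalized_reward` and the toy numbers are illustrative assumptions, not part of the LoRe codebase.

```python
def personalized_reward(x, V, w):
    """Compute r_i(x) = x · V · w_i for one user, using plain lists."""
    K = len(w)
    # Score x against each of the K shared basis directions (columns of V).
    basis_scores = [sum(x[f] * V[f][k] for f in range(len(x))) for k in range(K)]
    # Mix the K basis scores with this user's weights.
    return sum(s * wk for s, wk in zip(basis_scores, w))

x = [1.0, 2.0, 0.5]              # response features, shape [3]
V = [[1.0, 0.0],
     [0.0, 1.0],
     [2.0, -2.0]]                # shared basis, shape [3 x 2]
w_a = [0.75, 0.25]               # hypothetical user A leans on direction 0
w_b = [0.25, 0.75]               # hypothetical user B leans on direction 1

r_a = personalized_reward(x, V, w_a)
r_b = personalized_reward(x, V, w_b)
print(r_a, r_b)
```

Both users share the same V; only the mixture weights differ, which is why the same response receives a different reward per user.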

Rank as a spectrum of expressiveness

| K value | Interpretation |
| --- | --- |
| K=0 | Reference model: V is frozen to the pretrained V_sft; all weights are 1. No personalization. |
| K=1 | Equivalent to a standard Bradley-Terry single reward: one direction, all users share it. |
| K≥2 | Full LoRe: multiple diverse directions; each user has a distinct mixture. |
K=0 and K=1 serve as baselines in experiments. When K <= 1, the run and run_regularized functions assign all users the same all-ones weight vector rather than optimizing anything.

Softmax constraint on W

User weights are stored as raw logits and passed through a softmax before use:
W_row = F.softmax(W_logits, dim=1)  # [C, K] — positive, sums to 1
This places each w_i on the probability simplex: weights are non-negative and sum to 1. The simplex constraint makes w_i interpretable as a mixture distribution over reward basis directions.
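As a rough illustration of the simplex constraint, the sketch below applies a row-wise softmax in plain Python rather than with `F.softmax`. The helper `softmax` and the example logits are assumptions made for this illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)                      # subtract the max to avoid overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Two hypothetical users, K = 3 basis directions.
W_logits = [[0.0, 0.0, 0.0],
            [2.0, -1.0, 0.5]]
W_row = [softmax(row) for row in W_logits]

for row in W_row:
    print(row, sum(row))
```

Equal logits give a uniform mixture, and any real-valued logits map to strictly positive weights that sum to 1, so the optimizer can update the logits without an explicit projection step.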

Loss function

LoRe uses the Bradley-Terry pairwise loss: given a preference pair where x is the feature difference between the chosen and rejected response, the model should predict a positive score. The loss is the negative log-sigmoid (negative log-likelihood):
logits = (X_cat @ Vw) / 100.0          # scale for numerical stability
nll = -F.logsigmoid(logits).mean()
The division by 100 is a temperature scaling factor that prevents logits from saturating the sigmoid early in training.
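The loss can be sketched without PyTorch as follows. The helpers `log_sigmoid` and `bt_nll` and the raw scores are illustrative assumptions; each score stands for x · V · w_i evaluated on a chosen-minus-rejected feature difference.

```python
import math

def log_sigmoid(z):
    """Stable log(sigmoid(z)) = min(z, 0) - log(1 + exp(-|z|))."""
    return min(z, 0.0) - math.log1p(math.exp(-abs(z)))

def bt_nll(scores, scale=100.0):
    """Mean negative log-likelihood of the observed preferences."""
    return -sum(log_sigmoid(s / scale) for s in scores) / len(scores)

raw_scores = [250.0, -40.0, 10.0]                  # hypothetical pair scores
loss_scaled = bt_nll(raw_scores)                   # divides by 100, as in the text
loss_unscaled = bt_nll(raw_scores, scale=1.0)      # no temperature scaling

print(loss_scaled, loss_unscaled)
```

Without scaling, the pair scored -40 dominates the mean while the pair scored 250 contributes almost nothing, which is the saturation the division by 100 is meant to soften.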

Regularization toward the pretrained model

A key risk in learning V from scratch is that it drifts far from the pretrained supervised fine-tuning (SFT) reward model’s final layer V_sft. LoRe adds a cosine similarity regularization to keep learned basis directions aligned:
V_norm     = F.normalize(self.V,     dim=0)
V_sft_norm = F.normalize(self.V_sft, dim=0)
cos_sim    = (V_norm * V_sft_norm).sum(dim=0)   # per-column cosine similarity
reg        = torch.mean(1 - cos_sim)            # 0 = aligned, 1 = orthogonal
The penalty 1 - cos_sim is 0 when a learned column of V is perfectly aligned with the corresponding pretrained column, 1 when the two are orthogonal, and 2 when they point in opposite directions. The scalar alpha controls the regularization strength.
Cosine regularization is direction-aware but scale-invariant — it penalizes rotation away from V_sft without constraining the magnitude of V, which is learned freely.
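The scale-invariance claim can be checked with a small plain-Python sketch. It assumes V and V_sft have the same shape and no all-zero columns; the function name `cosine_penalty` and the matrices are made up for illustration.

```python
import math

def cosine_penalty(V, V_sft):
    """Mean over columns k of 1 - cos(V[:, k], V_sft[:, k])."""
    num_features, K = len(V), len(V[0])
    penalties = []
    for k in range(K):
        col = [V[f][k] for f in range(num_features)]
        ref = [V_sft[f][k] for f in range(num_features)]
        dot = sum(a * b for a, b in zip(col, ref))
        norms = math.sqrt(sum(a * a for a in col)) * math.sqrt(sum(b * b for b in ref))
        penalties.append(1.0 - dot / norms)   # assumes non-zero columns
    return sum(penalties) / K

V_sft = [[1.0, 0.0], [0.0, 1.0]]
V_scaled = [[5.0, 0.0], [0.0, 5.0]]    # same directions, 5x the magnitude
V_rotated = [[0.0, 0.0], [1.0, 1.0]]   # column 0 rotated to be orthogonal
V_flipped = [[-1.0, 0.0], [0.0, -1.0]] # both columns reversed

print(cosine_penalty(V_scaled, V_sft))
print(cosine_penalty(V_rotated, V_sft))
print(cosine_penalty(V_flipped, V_sft))
```

Rescaling V leaves the penalty at zero, while rotating or flipping columns increases it, so the regularizer constrains direction only.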

The LoRe_regularized forward pass

The full forward computation inside LoRe_regularized._forward_from_packed:
def _forward_from_packed(self, X_cat, y, alpha_curr):
    # Excerpt: W_logits, V_used, and entropy_loss are set in code omitted here.
    W_row = F.softmax(W_logits, dim=1)    # [C, K] — simplex weights
    Vw    = V_used @ W_row.T              # [features, C]

    logits_all = (X_cat @ Vw) / 100.0    # [N, C]
    logits = logits_all.gather(1, y.unsqueeze(1)).squeeze(1)
    nll = -F.logsigmoid(logits).mean()

    reg = 0.0
    if alpha_curr > 0:
        V_norm     = F.normalize(self.V,     dim=0)
        V_sft_norm = F.normalize(self.V_sft, dim=0)
        cos_sim    = (V_norm * V_sft_norm).sum(dim=0)
        reg        = torch.mean(1 - cos_sim)

    return nll, reg, entropy_loss
The batch X_cat is a concatenation of all users’ preference differences; y carries the user index for each row so the correct column of Vw is selected via gather.
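The per-row selection performed by `gather` can be mimicked with plain lists. The toy values below are assumptions, and the division by 100 is left out to keep the numbers easy to read.

```python
# N = 3 packed preference rows, 2 features each.
X_cat = [[1.0, 0.0],
         [0.0, 2.0],
         [1.0, 1.0]]
y = [0, 1, 0]                   # user index for each row
Vw = [[0.5, -1.0],
      [1.5, 2.0]]               # shape [features, C]: column c is V @ w_c

num_users = len(Vw[0])
# Score every row against every user's reward direction: shape [N, C].
logits_all = [
    [sum(x[f] * Vw[f][c] for f in range(len(x))) for c in range(num_users)]
    for x in X_cat
]
# Keep only the entry belonging to the user who produced each row: shape [N].
logits = [row[user] for row, user in zip(logits_all, y)]

print(logits_all)
print(logits)
```

Computing all N × C scores and then selecting one per row trades some extra arithmetic for a single batched matrix product.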

Alternating minimization training

LoRe trains W and V with two separate Adam optimizers, updating them in alternating steps within each iteration:
optimizer_W = optim.Adam([self.W], lr=self.learning_rate)
optimizer_V = optim.Adam([self.V], lr=self.learning_rate)

for step in range(self.num_iterations):
    alpha_curr = self._alpha_at_step(step)

    # Update W: regularization disabled so V's gradient is not involved
    optimizer_W.zero_grad()
    nll_W, _, _ = self._forward_from_packed(X_cat, y, alpha_curr=0.0)
    nll_W.backward()
    optimizer_W.step()

    # Update V: regularization active, scaled by alpha_curr
    optimizer_V.zero_grad()
    nll_V, reg, _ = self._forward_from_packed(X_cat, y, alpha_curr=alpha_curr)
    total_loss_V = nll_V + alpha_curr * reg
    total_loss_V.backward()
    optimizer_V.step()
When updating W, the regularization term is skipped (alpha_curr=0.0): the cosine penalty depends only on V, so it contributes no gradient to W. When updating V, the full regularized loss nll_V + alpha_curr * reg is used.
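The alternating pattern itself, stripped of the LoRe specifics, can be illustrated on a toy problem. This sketch swaps in plain gradient descent for Adam and minimizes the made-up objective (w·v - 1)², so it shows only the update order, not the actual training loop.

```python
def alternating_descent(w, v, lr=0.1, num_iterations=200):
    """Alternately take one gradient step on w, then one on v."""
    for _ in range(num_iterations):
        # Step on w with v held fixed: d/dw (w*v - 1)^2 = 2*(w*v - 1)*v
        w -= lr * 2.0 * (w * v - 1.0) * v
        # Step on v using the freshly updated w: d/dv = 2*(w*v - 1)*w
        v -= lr * 2.0 * (w * v - 1.0) * w
    return w, v

w, v = alternating_descent(0.5, 0.5)
residual = abs(w * v - 1.0)
print(w, v, residual)
```

Each block of parameters sees the other as a constant during its own step, which mirrors how the W update and the V update use separate optimizers and separate forward passes.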

The alpha warmup schedule

Regularization is not applied immediately. _alpha_at_step implements a linear warmup between 20% and 80% of total iterations:
def _alpha_at_step(self, step: int) -> float:
    warmup_start = int(0.2 * self.num_iterations)
    warmup_end   = int(0.8 * self.num_iterations)
    if step < warmup_start:
        return 0.0
    if step >= warmup_end:
        return float(self.alpha)
    return float(self.alpha) * (step - warmup_start) / (warmup_end - warmup_start)
This lets V move freely early in training (exploring the loss landscape) before the regularization gradually pulls it toward V_sft.
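To see the schedule's shape, here is a standalone version of the same logic with `num_iterations` and `alpha` passed as arguments instead of read from `self`. The values 1000 and 2.0 are arbitrary choices for illustration.

```python
def alpha_at_step(step, num_iterations, alpha):
    """Linear warmup of alpha between 20% and 80% of training."""
    warmup_start = int(0.2 * num_iterations)
    warmup_end = int(0.8 * num_iterations)
    if step < warmup_start:
        return 0.0
    if step >= warmup_end:
        return float(alpha)
    return float(alpha) * (step - warmup_start) / (warmup_end - warmup_start)

for step in (0, 199, 200, 500, 799, 800, 999):
    print(step, alpha_at_step(step, 1000, 2.0))
```

With these settings alpha stays at 0 through step 200, reaches half strength at step 500, and holds at its full value from step 800 onward.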

The solve_regularized_simplex entry point

The top-level function that instantiates LoRe_regularized and returns the trained weights:
def solve_regularized_simplex(V_sft, alpha, train_features, num_basis_vectors,
                               num_iterations=1000, learning_rate=0.01):
    num_classes = len(train_features)
    num_features = 4096
    am = LoRe_regularized(
        V_sft, alpha, num_classes, num_features,
        num_basis_vectors, num_iterations, learning_rate
    )
    W, V = am.train(train_features)
    return W, V.detach()
After training, LoRe_regularized.train also prunes basis directions whose maximum softmax weight across all users falls below 1e-2, keeping only directions that at least one user meaningfully uses.
solve_regularized_simplex hard-codes num_features = 4096, matching the embedding dimension of the reward model backbone used in the paper’s experiments. If you use a different backbone, you must update this value.
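The pruning rule described above (drop a basis direction when no user's softmax weight on it reaches 1e-2) might look roughly like this in plain Python. The function `prune_basis` is a hypothetical stand-in, not the library's code, and it does not renormalize the remaining weights because the text does not say whether LoRe does.

```python
import math

def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def prune_basis(W_logits, V, threshold=1e-2):
    """Keep basis columns whose max softmax weight over users is >= threshold."""
    W_row = [softmax(row) for row in W_logits]
    K = len(W_row[0])
    keep = [k for k in range(K) if max(row[k] for row in W_row) >= threshold]
    W_kept = [[row[k] for k in keep] for row in W_row]
    V_kept = [[v_row[k] for k in keep] for v_row in V]
    return keep, W_kept, V_kept

# Two hypothetical users, K = 3; nobody puts meaningful weight on direction 2.
W_logits = [[5.0, 0.0, -10.0],
            [0.0, 5.0, -10.0]]
V = [[1.0, 0.0, 0.3],
     [0.0, 1.0, 0.7]]           # shape [features x K]

keep, W_kept, V_kept = prune_basis(W_logits, V)
print(keep)
```

Using the maximum over users rather than the mean keeps a direction alive even if only a single user relies on it.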