These top-level functions wrap LoRe’s model classes into end-to-end training and evaluation pipelines.
run() and run_regularized() sweep over (K, alpha) grids, train a shared reward basis, evaluate seen-user generalization, and measure few-shot accuracy for new users. The lower-level solve* family constructs and trains a single model for a given (K, alpha) pair. learn_multiple_few_shot() wraps PersonalizeBatch for new-user adaptation with a fixed basis.
run()
Full pipeline covering basis training, seen-user evaluation, few-shot adaptation, and unseen-user evaluation across every (K, alpha) combination. Uses solve_regularized (LoRe) internally.
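Because the sweep covers every (K, alpha) pair and the results are stored flat, it can help to map a pair back to its position. The following is a minimal sketch, assuming alpha is the outer loop and K the inner loop as stated on this page; the helper name sweep_index and the example grids are hypothetical and not part of LoRe.

```python
from itertools import product


def sweep_index(alpha_list, K_list, alpha, K):
    """Flat position of (alpha, K) when alpha is the outer loop and K the inner loop."""
    return alpha_list.index(alpha) * len(K_list) + K_list.index(K)


# Hypothetical grids, chosen only for illustration.
alpha_list = [0.0, 0.1, 1.0]
K_list = [0, 1, 2, 5]

# Enumerate the sweep in the same order: alpha outer, K inner.
sweep = list(product(alpha_list, K_list))

idx = sweep_index(alpha_list, K_list, alpha=0.1, K=2)
print(idx, sweep[idx])
```

The same index can be used on each of the returned arrays to pull out the metrics for one (alpha, K) setting.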
Parameters
List of basis sizes to evaluate. K=0 skips basis training and uses V_final directly with unit weights. K=1 trains a single direction (equivalent to Bradley-Terry without personalization). K>=2 uses solve_regularized.
List of regularization strengths. alpha=0 disables regularization. Results are accumulated with alpha iterated first (outer loop), then K.
Reward head weights from the pretrained SFT reward model, shape [num_features, 1]. Used both as the regularization anchor and as the full reward basis when K=0.
Per-user preference pairs for seen users. List of N tensors, each of shape [m_i, F].
Held-out preference pairs for seen users on unseen prompts. Same list structure as train_features.
Few-shot preference pairs for unseen users, used to learn w with fixed V_joint.
Held-out preference pairs for unseen users on unseen prompts.
Number of seen (training) users.
Number of unseen (test) users.
Target device, e.g., torch.device("cuda:0").
Return value
Returns a tuple of 8 numpy arrays. All arrays are ordered by the sweep index (outer loop: alpha, inner loop: K), so each array has length len(alpha_list) * len(K_list).
Mean accuracy of seen users on their training preference pairs.
Mean accuracy of seen users on held-out (unseen) prompts.
Mean accuracy of unseen users on their few-shot training pairs after adapting w.
Mean accuracy of unseen users on fully held-out prompts after few-shot adaptation.
Standard deviations of the above, scaled by 0.25 to represent standard error for N=4 sub-groups.
run() uses num_iterations=1000 and learning_rate=0.5 for solve_regularized, and num_iterations=500 and learning_rate=0.1 for learn_multiple_few_shot. Use run_regularized() when you need LoRe_regularized with cosine regularization and num_iterations=20000.
run_regularized()
Same pipeline as run() but uses solve_regularized_simplex (LoRe_regularized) for basis training. Saves V and W checkpoints to disk for each (K, alpha) pair and runs for 20,000 iterations.
Parameters
Parameters are identical to run(). See above.
Key differences from run()
| Aspect | run() | run_regularized() |
|---|---|---|
| Basis solver | solve_regularized (LoRe) | solve_regularized_simplex (LoRe_regularized) |
| Regularization | L2 toward V_sft | Cosine similarity toward V_sft |
| num_iterations | 1000 | 20000 |
| Pruning | No | Yes — drops basis vectors with max weight < 1e-2 |
| Checkpointing | No | Yes — saves V and W per (K, alpha) |
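The pruning row can be illustrated with a small sketch. It assumes W is stored as N rows of K weights and V as F rows of K entries, and it drops every basis column whose largest user weight is below 1e-2. The helper prune_basis is hypothetical, uses plain Python lists rather than tensors, and does not renormalize the remaining weights, since this page does not say whether the library does.

```python
def prune_basis(W, V, threshold=1e-2):
    """Keep only basis columns whose largest user weight reaches the threshold.

    W: list of N rows, each with K user weights.
    V: list of F rows, each with K basis entries.
    """
    K = len(W[0])
    # A column survives if at least one user puts weight >= threshold on it.
    keep = [k for k in range(K) if max(row[k] for row in W) >= threshold]
    W_kept = [[row[k] for k in keep] for row in W]
    V_kept = [[row[k] for k in keep] for row in V]
    return W_kept, V_kept, keep


# Toy example: the third basis vector is barely used by any user.
W = [[0.70, 0.295, 0.005],
     [0.10, 0.899, 0.001]]
V = [[1.0, 0.0, 0.5],
     [0.0, 1.0, 0.5]]

W_kept, V_kept, keep = prune_basis(W, V)
print(keep)
```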
run_few_shot_vary_shots()
Evaluates few-shot accuracy across multiple shot counts. For each (alpha, K, shots) triple, it draws trials independent random sub-samples and reports the mean and standard deviation of the resulting accuracies.
Parameters
Number of random sub-sampling trials per (K, shots) combination. Results are averaged over trials to reduce variance from random shot selection.
Regularization strengths. Same semantics as run().
Basis sizes to sweep.
Shot counts to evaluate. Each value is passed to sample_shots() to create a sub-sampled few-shot set.
Full seen-user training features (used to train the basis).
Full unseen-user preference pairs. Sub-sampled per trial using sample_shots().
Held-out prompts for unseen users, evaluated after few-shot w adaptation.
Reference reward head weights. Used when K=0.
Number of seen users.
Number of unseen users.
Target device.
Return value
Mean few-shot training accuracy per shot count in num_shots.
Standard deviation of few-shot training accuracy across trials.
Mean accuracy on unseen prompts per shot count.
Standard deviation of unseen accuracy across trials.
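The trial-averaging idea can be sketched without the library. The code below is an illustration only: subsample_pairs is a hypothetical stand-in and does not reproduce the signature or behavior of sample_shots(), the scoring function is a toy, and the use of the population standard deviation is an assumption. In the real pipeline each trial would also re-learn w on the sub-sample before scoring.

```python
import random
import statistics


def subsample_pairs(pairs, shots, rng):
    """Hypothetical stand-in for shot sub-sampling: draw `shots` pairs without replacement."""
    return rng.sample(pairs, min(shots, len(pairs)))


def mean_std_over_trials(pairs, shots, trials, score_fn, seed=0):
    """Score `trials` independent sub-samples and return (mean, std) of the scores."""
    rng = random.Random(seed)
    scores = [score_fn(subsample_pairs(pairs, shots, rng)) for _ in range(trials)]
    return statistics.mean(scores), statistics.pstdev(scores)


def accuracy(sample):
    """Toy score: fraction of pairs flagged as correctly ranked."""
    return sum(sample) / len(sample)


# Toy data: 1 means a fixed model ranks the pair correctly, 0 means it does not.
pairs = [1, 1, 1, 0, 1, 0, 1, 1]

for shots in [2, 4, 8]:
    mean, std = mean_std_over_trials(pairs, shots, trials=5, score_fn=accuracy)
    print(shots, round(mean, 3), round(std, 3))
```

With shots equal to the full set size, every trial sees the same pairs, so the spread across trials collapses to zero; smaller shot counts show how much of the variance comes from shot selection alone.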
solve()
Trains a reward basis without regularization.
Parameters
Per-user preference pairs. List of N tensors, each [m_i, F].
Rank K of the factorization.
Adam gradient steps.
Adam learning rate.
Return value
Softmax-normalized user weight matrix, shape [N, K].
Detached reward basis matrix, shape [F, K].
solve_regularized()
Trains a basis with L2 regularization toward V_sft using LoRe.
Parameters
Reference SFT reward direction, shape [F] or [F, 1].
L2 regularization strength. alpha=0 disables regularization.
Per-user preference pairs.
Rank K.
Adam gradient steps.
Adam learning rate.
Returns (W, V.detach()) with the same shapes as solve().
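To make the returned shapes concrete, the sketch below combines one row of W (length K) with V (F by K) to score a user's preference pairs. It rests on two assumptions that this page does not confirm: each feature row is a chosen-minus-rejected difference, and a pair counts as correct when its personalized score is positive. The helper user_accuracy is hypothetical and works on plain lists rather than tensors.

```python
def user_accuracy(features, w, V):
    """Fraction of pairs with a positive personalized score x @ V @ w.

    features: list of m rows, each of length F (assumed chosen-minus-rejected).
    w: list of K weights for one user (one row of W).
    V: list of F rows, each of length K.
    """
    # Collapse the basis with the user's weights: v_user[f] = sum_k V[f][k] * w[k].
    v_user = [sum(v_fk * w_k for v_fk, w_k in zip(row, w)) for row in V]
    correct = sum(
        1 for x in features
        if sum(x_f * v_f for x_f, v_f in zip(x, v_user)) > 0
    )
    return correct / len(features)


# Toy values with F=3 features, K=2 basis vectors, N=2 users.
V = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
W = [[0.9, 0.1], [0.2, 0.8]]
features = [
    [[1.0, -0.5, 0.0], [0.2, 0.1, 1.0], [-1.0, 0.2, 0.0]],  # user 0
    [[-0.5, 1.0, 0.0], [1.0, -1.0, 0.0]],                   # user 1
]

accs = [user_accuracy(f, w, V) for f, w in zip(features, W)]
print(accs)
```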
solve_regularized_simplex()
Trains a basis with cosine similarity regularization using LoRe_regularized. Applies the warmup schedule and basis pruning described in the LoRe_regularized class.
Parameters
Reference SFT reward direction.
Maximum cosine regularization coefficient (reached after warmup).
Per-user preference pairs.
Initial rank K before pruning.
Adam gradient steps.
Adam learning rate.
num_features is hard-coded to 4096 inside solve_regularized_simplex. If your embeddings have a different dimension, instantiate LoRe_regularized directly with the correct num_features.
Returns (W_kept, V_kept.detach()), where K_kept <= K after pruning.
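One way to check how closely each surviving basis vector stays aligned with the reference direction is to compute the cosine similarity of every column of V with V_sft. This is a diagnostic sketch only, not the regularizer used inside LoRe_regularized, whose exact form is not given on this page; column_cosines is a hypothetical helper operating on plain lists.

```python
import math


def column_cosines(V, v_ref):
    """Cosine similarity between each column of V (F x K) and v_ref (length F)."""
    ref_norm = math.sqrt(sum(r * r for r in v_ref))
    sims = []
    for k in range(len(V[0])):
        col = [row[k] for row in V]
        col_norm = math.sqrt(sum(c * c for c in col))
        dot = sum(c * r for c, r in zip(col, v_ref))
        sims.append(dot / (col_norm * ref_norm))
    return sims


# Toy values: the first column is parallel to the reference, the second is not.
V = [[1.0, 0.0],
     [1.0, 1.0]]
v_sft = [1.0, 1.0]

sims = column_cosines(V, v_sft)
print([round(s, 4) for s in sims])
```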
learn_multiple_few_shot()
Adapts a fixed reward basis V to new users by learning per-user weight vectors w. Wraps PersonalizeBatch.
Parameters
Few-shot preference pairs for new users. List of N_unseen tensors, each [shots, F].
Fixed reward basis from a prior solve* call, shape [F, K]. Pass .detach() to prevent accidental gradient computation.
Adam gradient steps for w adaptation.
Adam learning rate.
Return value
List of N_unseen softmax-normalized weight vectors, each shape [K]. Detached from the computation graph.
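The adaptation step can be sketched in plain Python to show what learning w with a fixed V involves. This is not PersonalizeBatch: it assumes a Bradley-Terry style loss on chosen-minus-rejected feature differences, parameterizes w as a softmax over unconstrained logits, and swaps Adam for plain gradient descent to stay dependency-free. The names adapt_weights and softmax are hypothetical helpers.

```python
import math


def softmax(theta):
    """Numerically stable softmax over a list of logits."""
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def adapt_weights(features, V, num_iterations=500, learning_rate=0.1):
    """Learn simplex weights w for one user with V held fixed.

    Minimizes the mean of -log(sigmoid(x @ V @ w)) by gradient descent on
    logits theta, where w = softmax(theta). V itself is never updated.
    """
    F, K = len(V), len(V[0])
    # Project every pair onto the basis once: z[j][k] = sum_f x[j][f] * V[f][k].
    z = [[sum(x[f] * V[f][k] for f in range(F)) for k in range(K)] for x in features]
    theta = [0.0] * K
    for _ in range(num_iterations):
        w = softmax(theta)
        grad_w = [0.0] * K
        for zj in z:
            s = sum(zk * wk for zk, wk in zip(zj, w))
            coeff = -1.0 / (1.0 + math.exp(s))  # derivative of -log(sigmoid(s))
            for k in range(K):
                grad_w[k] += coeff * zj[k] / len(z)
        # Chain rule through softmax: dL/dtheta_k = w_k * (g_k - sum_l w_l * g_l).
        avg = sum(wk * gk for wk, gk in zip(w, grad_w))
        theta = [t - learning_rate * wk * (gk - avg)
                 for t, wk, gk in zip(theta, w, grad_w)]
    return softmax(theta)


# Toy basis (identity, F=2, K=2) and one new user whose pairs favor the first direction.
V = [[1.0, 0.0], [0.0, 1.0]]
few_shot_features = [[[1.0, -1.0], [0.5, -1.5]]]

weights = [adapt_weights(f, V) for f in few_shot_features]
print([[round(wk, 3) for wk in w] for w in weights])
```

Because w is produced by a softmax, each returned vector is non-negative and sums to one, matching the shape and normalization described for the return value.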