These top-level functions wrap LoRe’s model classes into end-to-end training and evaluation pipelines. run() and run_regularized() sweep over a grid of (K, alpha) values. For each pair they train a shared reward basis, evaluate how well it generalizes for seen users, and measure few-shot accuracy for new users. run_few_shot_vary_shots() repeats the few-shot evaluation across several shot counts. The lower-level solve* functions construct and train a single model for one basis size K and, for the regularized variants, one alpha. learn_multiple_few_shot() wraps PersonalizeBatch to adapt a fixed basis to new users.

run()

Full pipeline covering basis training, seen-user evaluation, few-shot adaptation, and unseen-user evaluation across every (K, alpha) combination. Uses solve_regularized (LoRe) internally.
from utils import run

(
    train_acc, seen_unseen_prompt_acc,
    fewshot_train_acc, unseen_acc,
    train_std, seen_unseen_prompt_std,
    fewshot_train_std, unseen_std,
) = run(
    K_list=[0, 1, 2, 3, 4, 5],
    alpha_list=[0],
    V_final=V_final,
    train_features=train_features,
    test_features_sparse=test_features_sparse,
    train_features_unseen=train_features_unseen,
    test_features_sparse_unseen=test_features_sparse_unseen,
    N=1000,
    N_unseen=500,
    device=device,
)

Parameters

K_list
list[int]
required
List of basis sizes to evaluate. K=0 skips basis training and uses V_final directly with unit weights. K=1 trains a single direction (equivalent to Bradley-Terry without personalization). K>=2 uses solve_regularized.
alpha_list
list[float]
required
List of regularization strengths. alpha=0 disables regularization. Results are accumulated with alpha as the outer loop and K as the inner loop.
V_final
torch.Tensor
required
Reward head weights from the pretrained SFT reward model, shape [num_features, 1]. Used both as the regularization anchor and as the full reward basis when K=0.
train_features
list[Tensor]
required
Per-user preference pairs for seen users. List of N tensors, each shape [m_i, F].
test_features_sparse
list[Tensor]
required
Held-out preference pairs for seen users on unseen prompts. Same list structure as train_features.
train_features_unseen
list[Tensor]
required
Few-shot preference pairs for unseen users, used to learn per-user weights w while the learned basis V_joint is held fixed.
test_features_sparse_unseen
list[Tensor]
required
Held-out preference pairs for unseen users on unseen prompts.
N
int
required
Number of seen (training) users.
N_unseen
int
required
Number of unseen (test) users.
device
torch.device
required
Target device, e.g., torch.device("cuda:0").
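With K=0, run() scores pairs with V_final and unit weights instead of a trained basis, so each accuracy metric reduces to counting correctly ranked pairs. A minimal numpy sketch, assuming each row of a user's [m_i, F] tensor is the feature difference between the chosen and the rejected response, and assuming the reported mean is an average of per-user accuracies. pairwise_accuracy and mean_user_accuracy are hypothetical helpers, not part of utils.

```python
import numpy as np


def pairwise_accuracy(pairs, V, w):
    """Fraction of a user's preference pairs ranked correctly by V @ w.

    Assumption: each row of `pairs` is phi(chosen) - phi(rejected), so a
    pair counts as correct when its reward difference is positive.
    pairs: [m_i, F], V: [F, K], w: [K]
    """
    margins = pairs @ V @ w  # [m_i] reward differences
    return float(np.mean(margins > 0))


def mean_user_accuracy(features, V, weights):
    """Average of per-user accuracies (assumed aggregation)."""
    return float(np.mean([pairwise_accuracy(X, V, w)
                          for X, w in zip(features, weights)]))


# K = 0: V_final is the only basis column and every user gets weight 1.
rng = np.random.default_rng(0)
F = 8
V_final = rng.normal(size=(F, 1))
users = [rng.normal(size=(m, F)) for m in (5, 7, 3)]
unit_weights = [np.ones(1) for _ in users]
print(mean_user_accuracy(users, V_final, unit_weights))
```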

Return value

Returns a tuple of 8 numpy arrays. All arrays are ordered by the sweep index (outer loop: alpha, inner loop: K), so the length is len(alpha_list) * len(K_list).
train_acc
ndarray
Mean accuracy of seen users on their training preference pairs.
seen_unseen_prompt_acc
ndarray
Mean accuracy of seen users on held-out (unseen) prompts.
fewshot_train_acc
ndarray
Mean accuracy of unseen users on their few-shot training pairs after adapting w.
unseen_acc
ndarray
Mean accuracy of unseen users on fully held-out prompts after few-shot adaptation.
train_std / seen_unseen_prompt_std / fewshot_train_std / unseen_std
ndarray
Standard deviations of the above, scaled by 0.25 to represent standard error for N=4 sub-groups.
run() uses num_iterations=1000 and learning_rate=0.5 for solve_regularized, and num_iterations=500, learning_rate=0.1 for learn_multiple_few_shot. Use run_regularized() when you need LoRe_regularized with cosine regularization and num_iterations=20000.
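Every returned array is flattened with alpha as the outer loop and K as the inner loop, so entry i belongs to alpha_list[i // len(K_list)] and K_list[i % len(K_list)]. The hypothetical helper below is not part of utils. It keys a result array by (alpha, K) and works on numpy arrays or plain lists.

```python
from itertools import product


def index_by_config(values, alpha_list, K_list):
    """Key a flat sweep result by (alpha, K).

    Assumes the documented ordering: alpha is the outer loop and K the inner
    loop, so entry i belongs to (alpha_list[i // len(K_list)],
    K_list[i % len(K_list)]).
    """
    configs = list(product(alpha_list, K_list))  # alpha varies slowest
    if len(values) != len(configs):
        raise ValueError(f"expected {len(configs)} entries, got {len(values)}")
    return dict(zip(configs, values))


alpha_list = [0, 0.01]
K_list = [1, 5, 10]
placeholder = list(range(6))  # stands in for a returned array such as train_acc
by_config = index_by_config(placeholder, alpha_list, K_list)
print(by_config[(0.01, 5)])  # 4, i.e. the fifth entry
```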

run_regularized()

Same pipeline as run() but uses solve_regularized_simplex (LoRe_regularized) for basis training. Saves V and W checkpoints to disk for each (K, alpha) pair and runs for 20,000 iterations.
from utils import run_regularized

(
    train_acc, seen_unseen_prompt_acc,
    fewshot_train_acc, unseen_acc,
    train_std, seen_unseen_prompt_std,
    fewshot_train_std, unseen_std,
) = run_regularized(
    K_list=[0, 1, 5, 10, 15, 20, 25, 50],
    alpha_list=[1e4],
    V_final=V_final,
    train_features=train_seen,
    test_features_sparse=test_seen,
    train_features_unseen=train_unseen,
    test_features_sparse_unseen=test_unseen,
    N=N,
    N_unseen=N_unseen,
    device=device,
)

Parameters

Parameters are identical to run(). See above.

Key differences from run()

| Aspect | run() | run_regularized() |
|---|---|---|
| Basis solver | solve_regularized (LoRe) | solve_regularized_simplex (LoRe_regularized) |
| Regularization | L2 toward V_sft | Cosine similarity toward V_sft |
| num_iterations | 1000 | 20000 |
| Pruning | No | Yes: drops basis vectors with max weight < 1e-2 |
| Checkpointing | No | Yes: saves V and W per (K, alpha) |
run_regularized() saves checkpoints to /checkpoint/ai_society/representative_llms/data/lore/community/. Ensure this path exists and is writable, or patch the paths in the source before running.

run_few_shot_vary_shots()

Evaluates few-shot accuracy across multiple shot counts. For each (alpha, K, shots) combination, it draws trials independent random sub-samples (see the trials parameter) and reports their mean and standard deviation.
from utils import run_few_shot_vary_shots

fewshot_train_means, fewshot_train_stds, unseen_means, unseen_stds = run_few_shot_vary_shots(
    trials=10,
    alpha_list=[0],
    K_list=[5, 10],
    num_shots=[1, 5, 10, 20, 50],
    train_features=train_features,
    train_features_unseen=train_features_unseen,
    test_features_sparse_unseen=test_features_sparse_unseen,
    V_final=V_final,
    N=1000,
    N_unseen=500,
    device=device,
)

Parameters

trials
int
required
Number of random sub-sampling trials per (alpha, K, shots) combination. Results are averaged over trials to reduce the variance introduced by random shot selection.
alpha_list
list[float]
required
Regularization strengths. Same semantics as run().
K_list
list[int]
required
Basis sizes to sweep.
num_shots
list[int]
required
Shot counts to evaluate. Each value is passed to sample_shots() to create a sub-sampled few-shot set.
train_features
list[Tensor]
required
Full seen-user training features (used to train the basis).
train_features_unseen
list[Tensor]
required
Full unseen-user preference pairs. Sub-sampled per trial using sample_shots().
test_features_sparse_unseen
list[Tensor]
required
Held-out prompts for unseen users, evaluated after few-shot w adaptation.
V_final
torch.Tensor
required
Reference reward head weights. Used when K=0.
N
int
required
Number of seen users.
N_unseen
int
required
Number of unseen users.
device
torch.device
required
Target device.
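Each num_shots value is passed to sample_shots(), whose sampling rule this page does not spell out. The sketch below assumes that each user's pairs are sampled uniformly without replacement, capped at the number of pairs that user has. sample_shots_sketch is a hypothetical stand-in, not the utils function.

```python
import numpy as np


def sample_shots_sketch(features, shots, rng):
    """Sub-sample up to `shots` preference pairs per user, without replacement.

    features: list of [m_i, F] arrays, one per unseen user
    Returns a list of [min(shots, m_i), F] arrays.
    """
    sampled = []
    for X in features:
        k = min(shots, X.shape[0])
        idx = rng.choice(X.shape[0], size=k, replace=False)
        sampled.append(X[idx])
    return sampled


rng = np.random.default_rng(0)
unseen = [np.arange(40, dtype=float).reshape(10, 4), np.zeros((3, 4))]
for trial in range(2):  # each trial draws a fresh sub-sample
    few = sample_shots_sketch(unseen, shots=5, rng=rng)
    print(trial, [x.shape for x in few])
```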

Return value

fewshot_train_means
list[float]
Mean few-shot training accuracy per shot count in num_shots.
fewshot_train_stds
list[float]
Standard deviation of few-shot training accuracy across trials.
unseen_means
list[float]
Mean accuracy on unseen prompts per shot count.
unseen_stds
list[float]
Standard deviation of unseen accuracy across trials.
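The per-shot statistics amount to a mean and a standard deviation over the trial axis. A minimal sketch, assuming accuracies are collected in a [trials, len(num_shots)] grid. It uses numpy's default population standard deviation (ddof=0), which may differ from what utils computes.

```python
import numpy as np


def summarize_trials(acc_grid):
    """Collapse a [trials, num_shot_counts] accuracy grid to per-shot stats.

    Returns (means, stds) as Python lists, one entry per shot count.
    np.std uses ddof=0 (population std); pass ddof=1 for the sample std.
    """
    acc = np.asarray(acc_grid, dtype=float)
    return acc.mean(axis=0).tolist(), acc.std(axis=0).tolist()


# Rows are trials, columns follow num_shots (placeholder numbers).
grid = [[0.25, 0.50],
        [0.75, 1.00]]
means, stds = summarize_trials(grid)
print(means, stds)
```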

solve()

Trains a reward basis without regularization.
from utils import solve

W, V = solve(
    train_features=train_features,
    num_basis_vectors=5,
    num_iterations=1000,
    learning_rate=0.01,
)

Parameters

train_features
list[Tensor]
required
Per-user preference pairs. List of N tensors, each [m_i, F].
num_basis_vectors
int
required
Rank K of the factorization.
num_iterations
int
default:"1000"
Adam gradient steps.
learning_rate
float
default:"0.01"
Adam learning rate.

Return value

W
Tensor
Softmax-normalized user weight matrix, shape [N, K].
V
Tensor
Detached reward basis matrix, shape [F, K].
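This page lists solve()'s outputs but not its loss. One plausible objective consistent with those outputs is a Bradley-Terry negative log-likelihood in which user i's reward difference on a pair x is x · V · w_i, with w_i = softmax(logits_i). The numpy sketch below evaluates that loss as a forward pass only, with no Adam updates. The objective itself, the mean over all pairs, and the chosen-minus-rejected feature layout are all assumptions.

```python
import numpy as np


def softmax_rows(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def low_rank_bt_loss(features, logits, V):
    """Mean Bradley-Terry NLL over every user's pairs (assumed objective).

    features: list of N arrays, each [m_i, F] (chosen minus rejected)
    logits:   [N, K] unnormalized user weights; W = softmax(logits)
    V:        [F, K] shared reward basis
    Returns (loss, W).
    """
    W = softmax_rows(logits)                       # [N, K], rows sum to 1
    per_pair = [np.logaddexp(0.0, -(X @ V @ w))    # -log sigmoid(margin)
                for X, w in zip(features, W)]
    return float(np.mean(np.concatenate(per_pair))), W


rng = np.random.default_rng(0)
F, K, N = 6, 3, 4
V = rng.normal(size=(F, K))
logits = np.zeros((N, K))                          # uniform user weights
feats = [rng.normal(size=(m, F)) for m in (4, 2, 5, 3)]
loss, W = low_rank_bt_loss(feats, logits, V)
print(loss, W.shape)
```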

solve_regularized()

Trains a basis with L2 regularization toward V_sft using LoRe.
from utils import solve_regularized

W, V = solve_regularized(
    V_sft=V_final,
    alpha=0.01,
    train_features=train_features,
    num_basis_vectors=5,
    num_iterations=1000,
    learning_rate=0.5,
)

Parameters

V_sft
torch.Tensor
required
Reference SFT reward direction, shape [F] or [F, 1].
alpha
float
required
L2 regularization strength. alpha=0 disables regularization.
train_features
list[Tensor]
required
Per-user preference pairs.
num_basis_vectors
int
required
Rank K.
num_iterations
int
default:"1000"
Adam gradient steps.
learning_rate
float
default:"0.01"
Adam learning rate.
Returns (W, V.detach()) with the same shapes as solve().
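The exact form of the L2 term is not given here. Below is a minimal sketch of one plausible version, alpha * sum_k ||V[:, k] - V_sft||^2. It also shows how a V_sft of shape [F] or [F, 1] can be reshaped to broadcast against every basis column. l2_anchor_penalty is hypothetical.

```python
import numpy as np


def l2_anchor_penalty(V, V_sft, alpha):
    """alpha * sum_k ||V[:, k] - V_sft||^2 (assumed form of the L2 anchor).

    V:     [F, K] reward basis
    V_sft: [F] or [F, 1]; reshaped to [F, 1] so it broadcasts across columns
    """
    anchor = np.asarray(V_sft, dtype=float).reshape(-1, 1)  # [F, 1]
    if anchor.shape[0] != V.shape[0]:
        raise ValueError("V_sft and V must share the feature dimension F")
    return float(alpha * np.sum((V - anchor) ** 2))


V = np.zeros((3, 2))
print(l2_anchor_penalty(V, np.ones(3), 0.5))       # 3.0 (V_sft of shape [F])
print(l2_anchor_penalty(V, np.ones((3, 1)), 0.0))  # 0.0 (alpha=0 disables it)
```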

solve_regularized_simplex()

Trains a basis with cosine similarity regularization using LoRe_regularized. Applies the warmup schedule and basis pruning described in the LoRe_regularized class.
from utils import solve_regularized_simplex

W, V = solve_regularized_simplex(
    V_sft=V_final,
    alpha=1e4,
    train_features=train_features,
    num_basis_vectors=25,
    num_iterations=20000,
    learning_rate=0.5,
)

Parameters

V_sft
torch.Tensor
required
Reference SFT reward direction.
alpha
float
required
Maximum cosine regularization coefficient (reached after warmup).
train_features
list[Tensor]
required
Per-user preference pairs.
num_basis_vectors
int
required
Initial rank K before pruning.
num_iterations
int
default:"1000"
Adam gradient steps.
learning_rate
float
default:"0.01"
Adam learning rate.
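The cosine regularizer and its warmup are defined in the LoRe_regularized class, not on this page. The sketch below assumes a penalty of the form coeff * sum_k (1 - cos(V[:, k], V_sft)), with coeff ramped linearly from 0 to alpha over a hypothetical warmup_steps. Both the functional form and the schedule are assumptions. Unlike an L2 anchor, this penalty ignores column length and constrains only direction.

```python
import numpy as np


def cosine_anchor_penalty(V, V_sft, coeff, eps=1e-8):
    """coeff * sum_k (1 - cos(V[:, k], V_sft)) (assumed form).

    Depends (up to eps) only on the angle between each column and V_sft.
    """
    v = np.asarray(V_sft, dtype=float).reshape(-1)               # [F]
    cos = (v @ V) / (np.linalg.norm(V, axis=0) * np.linalg.norm(v) + eps)
    return float(coeff * np.sum(1.0 - cos))                      # cos: [K]


def warmup_coeff(step, alpha, warmup_steps):
    """Hypothetical linear ramp from 0 to alpha, then held at alpha."""
    if warmup_steps <= 0:
        return alpha
    return alpha * min(1.0, step / warmup_steps)


alpha, warmup_steps = 1e4, 1000  # warmup_steps is a made-up value
for step in (0, 500, 1000, 5000):
    print(step, warmup_coeff(step, alpha, warmup_steps))
```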
num_features is hard-coded to 4096 inside solve_regularized_simplex. If your embeddings have a different dimension, instantiate LoRe_regularized directly with the correct num_features.
Returns (W_kept, V_kept.detach()) with shapes [N, K_kept] and [F, K_kept], where K_kept <= K after pruning.
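Pruning drops basis vectors whose max weight is below 1e-2, and that rule can be sketched as a column mask. This sketch assumes "max weight" means the largest weight any user assigns to a basis column in W. prune_basis is hypothetical. It does not renormalize the surviving weights, and utils may or may not do so.

```python
import numpy as np


def prune_basis(W, V, threshold=1e-2):
    """Drop basis columns that no user weights at or above `threshold`.

    W: [N, K] softmax-normalized user weights
    V: [F, K] reward basis
    Returns (W_kept [N, K_kept], V_kept [F, K_kept]) with K_kept <= K.
    """
    keep = W.max(axis=0) >= threshold  # [K] boolean mask over columns
    return W[:, keep], V[:, keep]


W = np.array([[0.995, 0.004, 0.001],
              [0.600, 0.395, 0.005]])  # column 2 never reaches 1e-2
V = np.arange(12.0).reshape(4, 3)
W_kept, V_kept = prune_basis(W, V)
print(W_kept.shape, V_kept.shape)
```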

learn_multiple_few_shot()

Adapts a fixed reward basis V to new users by learning per-user weight vectors w. Wraps PersonalizeBatch.
from utils import learn_multiple_few_shot

W_few_shot = learn_multiple_few_shot(
    train_features=train_features_unseen,
    V=V_joint.detach(),
    num_iterations=500,
    learning_rate=0.1,
)
# W_few_shot: list of N_unseen tensors, each shape [K]

Parameters

train_features
list[Tensor]
required
Few-shot preference pairs for new users. List of N_unseen tensors, each [shots, F].
V
torch.Tensor
required
Fixed reward basis from a prior solve* call, shape [F, K]. Pass .detach() to prevent accidental gradient computation.
num_iterations
int
default:"1000"
Adam gradient steps for w adaptation.
learning_rate
float
default:"0.01"
Adam learning rate.

Return value

W
list[Tensor]
List of N_unseen softmax-normalized weight vectors, each shape [K]. Detached from the computation graph.
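The adaptation step learns only K weight logits per user while V stays frozen. A minimal single-user sketch, assuming a Bradley-Terry negative log-likelihood over chosen-minus-rejected pair features with w = softmax(z). It swaps Adam for plain gradient descent so the gradient stays explicit. adapt_user_weights is hypothetical and is not the PersonalizeBatch implementation.

```python
import numpy as np


def sigmoid(x):
    """Numerically stable logistic function via tanh."""
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def adapt_user_weights(X, V, num_iterations=500, learning_rate=0.1):
    """Learn w = softmax(z) for one user with the basis V held fixed.

    X: [m, F] few-shot pairs (assumed chosen minus rejected), V: [F, K].
    Plain gradient descent stands in for the Adam optimizer used by utils.
    """
    A = X @ V                        # [m, K]; V never changes
    z = np.zeros(V.shape[1])         # uniform initial weights
    for _ in range(num_iterations):
        w = np.exp(z - z.max())
        w /= w.sum()                             # softmax
        s = A @ w                                # [m] reward differences
        g_s = -(1.0 - sigmoid(s)) / len(s)       # dL/ds for the mean NLL
        g_w = A.T @ g_s                          # dL/dw
        g_z = w * (g_w - w @ g_w)                # chain rule through softmax
        z -= learning_rate * g_z
    w = np.exp(z - z.max())
    return w / w.sum()


V = np.eye(3)  # toy basis: F = K = 3
X = np.array([[1.0, 0.5, -0.5],
              [1.0, -0.5, 0.5],
              [2.0, 0.3, -0.3],
              [1.5, -0.3, 0.3]])  # pairs favor basis direction 0
w = adapt_user_weights(X, V)
print(w, (X @ V @ w > 0).mean())
```

With numpy arrays, applying it per user, as in [adapt_user_weights(X, V) for X in train_features_unseen], yields a list of [K] weight vectors analogous to what learn_multiple_few_shot() returns.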
