train_basis.py is the central training script for LoRe. It loads the pre-computed embeddings produced by prepare.py, partitions users into seen and unseen groups, and iterates over a list of basis ranks, K_list. For each rank K it jointly learns a shared reward basis matrix V (shape [hidden_dim, K]) and per-user weight vectors W (shape [N, K]), evaluates seen users on both training prompts and unseen prompts, adapts W to unseen users via few-shot learning, and evaluates unseen users on unseen prompts. The result is an accuracy profile across all four evaluation settings for every rank K.
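As a rough sketch of the model these shapes suggest (the linear form and the symbols $e$, $w_i$, and $\sigma$ are assumptions made for illustration, not taken from the script), user $i$'s reward can be read as a weighted combination of K basis rewards, with preferences following a Bradley-Terry model:

$$
r_i(x, y) = w_i^\top V^\top e(x, y), \qquad P(y_a \succ y_b \mid x, i) = \sigma\big(r_i(x, y_a) - r_i(x, y_b)\big)
$$

Here $e(x, y) \in \mathbb{R}^{\text{hidden\_dim}}$ is the pre-computed embedding of prompt $x$ and response $y$, $V \in \mathbb{R}^{\text{hidden\_dim} \times K}$ is the shared basis, $w_i \in \mathbb{R}^{K}$ is row $i$ of W, and $\sigma$ is the logistic function.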

What the pipeline does

The core orchestration is handled by run() (or run_regularized() for PRISM):
def run(K_list, alpha_list, V_final, train_features, test_features_sparse,
        train_features_unseen, test_features_sparse_unseen, N, N_unseen, device):
    for alpha in alpha_list:
        for K in K_list:
            # 1. Joint basis + weight learning
            if K == 0:
                V_joint = V_final          # fixed reference head
                W_joint = [torch.tensor([1.0]).to(device) for i in range(N)]
            else:
                W_joint, V_joint = solve_regularized(
                    V_final, alpha, train_features, K,
                    num_iterations=1000, learning_rate=0.5
                )

            # 2. Seen-user evaluation (training prompts)
            accuracies_train = eval_multiple(
                W_joint, [V_joint.detach() for i in range(N)], train_features
            )

            # 3. Seen-user evaluation (unseen prompts)
            accuracies_seen_user_unseen_prompts = eval_multiple(
                W_joint, [V_joint.detach() for i in range(N)], test_features_sparse
            )

            # 4. Few-shot weight learning for unseen users
            if K <= 1:
                W_few_shot = [torch.tensor([1.0]).to(device) for i in range(N_unseen)]
            else:
                W_few_shot = learn_multiple_few_shot(
                    train_features_unseen, V_joint.detach(),
                    num_iterations=500, learning_rate=0.1
                )

            # 5. Unseen-user evaluation (unseen prompts)
            accuracies_unseen_user_unseen_prompts = eval_multiple(
                W_few_shot, [V_joint.detach() for i in range(N_unseen)],
                test_features_sparse_unseen
            )

The K parameter: basis rank

K controls the rank of the shared reward decomposition:
| K | Interpretation |
| --- | --- |
| 0 | Reference model: fixed pre-trained head V_final, single scalar weight per user |
| 1 | Bradley-Terry baseline: single reward direction, one weight per user |
| ≥ 2 | LoRe: K-dimensional basis; each user gets a K-vector weight |
Start with K_list = [0, 1, 2, 3] to quickly compare the reference, BT baseline, and LoRe at ranks 2 and 3 before running the full sweep.
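To make the three regimes concrete, here is a minimal sketch that labels each entry of a K_list. The helper describe_rank and its label strings are hypothetical and do not appear in train_basis.py; they only restate the table above.

```python
def describe_rank(K: int) -> str:
    """Map a basis rank to the regime it selects (hypothetical helper)."""
    if K < 0:
        raise ValueError("K must be non-negative")
    if K == 0:
        return "reference"        # fixed pre-trained head V_final
    if K == 1:
        return "bradley-terry"    # single learned reward direction
    return f"lore-rank-{K}"       # K-dimensional basis, K-vector weight per user


K_list = [0, 1, 2, 3]
for K in K_list:
    print(K, describe_rank(K))
```

A mapping like this can be handy for labelling log lines or plot legends when comparing a short sweep against the full one.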

The alpha parameter: regularization strength

alpha_list controls how strongly V is regularized toward the pre-trained head direction:
| alpha | Effect |
| --- | --- |
| 0 | No regularization: V is free to rotate arbitrarily |
| 1e4 | Strong cosine regularization: V columns stay close to V_final (used by PRISM) |
The LoRe_regularized class ramps the regularization weight from 0 to alpha over the middle 60% of training steps:
def _alpha_at_step(self, step: int) -> float:
    warmup_start = int(0.2 * self.num_iterations)
    warmup_end   = int(0.8 * self.num_iterations)
    if step < warmup_start: return 0.0
    if step >= warmup_end:  return float(self.alpha)
    return float(self.alpha) * (step - warmup_start) / (warmup_end - warmup_start)
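To see the ramp in numbers, the sketch below restates the same schedule as a standalone function, so it can be run without the LoRe_regularized class. The function name alpha_at_step and the choice of num_iterations=1000 with alpha=1e4 are illustrative assumptions.

```python
def alpha_at_step(step: int, num_iterations: int, alpha: float) -> float:
    """Standalone restatement of the ramp: 0 before 20%, linear until 80%, then alpha."""
    warmup_start = int(0.2 * num_iterations)
    warmup_end = int(0.8 * num_iterations)
    if step < warmup_start:
        return 0.0
    if step >= warmup_end:
        return float(alpha)
    # Linear interpolation between the two warm-up boundaries.
    return float(alpha) * (step - warmup_start) / (warmup_end - warmup_start)


# Print the regularization weight at a few representative steps.
for step in (0, 199, 200, 500, 799, 800, 999):
    print(step, alpha_at_step(step, num_iterations=1000, alpha=1e4))
```

With these settings the weight stays at 0 for the first 200 steps, reaches half of alpha at step 500, and holds at alpha from step 800 onward.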

Extracting the pre-trained head

All train_basis.py scripts extract the final linear layer of the reward model to use as V_final:
last_linear_layer = None
for name, module in rm.named_modules():
    if isinstance(module, torch.nn.Linear):
        last_linear_layer = module

V_final = last_linear_layer.weight[:, 0].to(device).to(torch.float32).reshape(-1, 1)
PRISM sets V_final = None initially and populates it after loading the model, then passes it to run_regularized().

Dataset-specific commands and defaults

Defaults for RedditTLDR:
K_list     = [0, 1, 2, 3, 4, 5, 6]
alpha_list = [0]
Embeddings are loaded from pickled worker-result files:
with open('tldr_embeddings_train.pkl', 'rb') as f:
    worker_results_train = pickle.load(f)
with open('tldr_embeddings_val.pkl', 'rb') as f:
    worker_results_test = pickle.load(f)
Workers are split 50/50 into seen and unseen user sets. Up to 150 training examples per seen worker and up to 50 per unseen worker are sampled.
cd LoRe/RedditTLDR
python train_basis.py
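The seen/unseen split and the per-worker sampling caps described above could look roughly like the following. This is a sketch under stated assumptions: the dict-of-lists layout, the random shuffle, the seed, and the function name split_and_sample are all hypothetical, and the actual script may order or sample workers differently.

```python
import random


def split_and_sample(worker_examples, max_seen=150, max_unseen=50, seed=0):
    """Split workers 50/50 into seen/unseen and cap the examples kept per worker.

    worker_examples: dict mapping worker id -> list of examples (hypothetical layout).
    """
    rng = random.Random(seed)
    worker_ids = sorted(worker_examples)
    rng.shuffle(worker_ids)
    half = len(worker_ids) // 2
    seen_ids, unseen_ids = worker_ids[:half], worker_ids[half:]

    def cap(examples, limit):
        # Keep everything when the worker has at most `limit` examples.
        if len(examples) <= limit:
            return list(examples)
        return rng.sample(examples, limit)

    seen = {w: cap(worker_examples[w], max_seen) for w in seen_ids}
    unseen = {w: cap(worker_examples[w], max_unseen) for w in unseen_ids}
    return seen, unseen


# Toy data: 6 workers with 200 dummy examples each.
toy = {f"worker_{i}": list(range(200)) for i in range(6)}
seen, unseen = split_and_sample(toy)
print(len(seen), len(unseen))
```

Fixing the seed keeps the split reproducible across runs, which matters when comparing accuracies between ranks.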

Output

run() and run_regularized() both return eight arrays (means and standard deviations for four evaluation settings):
| Array | Description |
| --- | --- |
| train_accuracies_joint | Seen users on training prompts |
| seen_user_unseen_prompts_accuracies_joint | Seen users on test prompts |
| few_shot_train_accuracies_few_shot | Unseen users on their few-shot prompts |
| unseen_user_unseen_prompts_accuracies_few_shot | Unseen users on test prompts |
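One plausible way a mean and standard deviation pair arises for a single rank and setting is sketched below. It assumes the statistics are taken over per-user accuracies and uses the population standard deviation; neither detail is confirmed by the text above, and summarize is a hypothetical helper.

```python
from statistics import mean, pstdev


def summarize(per_user_accuracies):
    """Return (mean, std) over users; population std is an assumption here."""
    return mean(per_user_accuracies), pstdev(per_user_accuracies)


# Hypothetical per-user accuracies for one rank K in one evaluation setting.
accs = [0.50, 0.75, 1.00]
m, s = summarize(accs)
print(f"mean={m:.3f} std={s:.3f}")
```

Repeating this for every K in K_list and every setting would yield four mean arrays and four standard-deviation arrays, each indexed by rank.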
PRISM’s train_basis.py also generates an accuracy-vs-rank plot:
plt.plot(K_list, seen_user_unseen_prompts_accuracies_joint,   marker='o', label="Seen Users")
plt.plot(K_list, unseen_user_unseen_prompts_accuracies_few_shot, marker='o', label="Unseen Users")
plt.xlabel('rank')
plt.ylabel('Accuracies')
plt.title('Generalization Accuracy vs. Rank')
plt.xticks(K_list, labels=["ref" if k==0 else str(k) for k in K_list])
plt.savefig(f'generalization_accuracy_vs_rank_lore_alpha_{alpha}.png', dpi=300, bbox_inches='tight')
