train_basis.py is the central training script for LoRe. It loads the pre-computed embeddings produced by prepare.py, partitions users into seen and unseen groups, and then sweeps over the regularization strengths in alpha_list and the basis ranks in K_list. For each rank K it jointly learns, on the seen users, a shared reward basis matrix V (shape [hidden_dim, K]) and per-user weight vectors W (shape [N, K]). It then evaluates seen users on their training prompts and on unseen prompts, adapts W to unseen users via few-shot learning on a small set of their prompts, and evaluates those unseen users on their few-shot prompts and on unseen prompts. The result is an accuracy profile across all four evaluation settings for every rank K.
What the pipeline does
The core orchestration is handled by run() (or run_regularized() for PRISM):
```python
import torch

# solve_regularized, eval_multiple and learn_multiple_few_shot are defined in the repository
def run(K_list, alpha_list, V_final, train_features, test_features_sparse,
        train_features_unseen, test_features_sparse_unseen, N, N_unseen, device):
    for alpha in alpha_list:
        for K in K_list:
            # 1. Joint basis + weight learning on the seen users
            if K == 0:
                V_joint = V_final  # fixed reference head
                W_joint = [torch.tensor([1.0]).to(device) for i in range(N)]
            else:
                W_joint, V_joint = solve_regularized(
                    V_final, alpha, train_features, K,
                    num_iterations=1000, learning_rate=0.5
                )
            # 2. Seen-user evaluation (training prompts)
            accuracies_train = eval_multiple(
                W_joint, [V_joint.detach() for i in range(N)], train_features
            )
            # 3. Seen-user evaluation (unseen prompts)
            accuracies_seen_user_unseen_prompts = eval_multiple(
                W_joint, [V_joint.detach() for i in range(N)], test_features_sparse
            )
            # 4. Few-shot weight learning for unseen users (V stays frozen)
            if K <= 1:
                # K = 0 or 1: every unseen user gets a fixed weight of 1.0
                W_few_shot = [torch.tensor([1.0]).to(device) for i in range(N_unseen)]
            else:
                W_few_shot = learn_multiple_few_shot(
                    train_features_unseen, V_joint.detach(),
                    num_iterations=500, learning_rate=0.1
                )
            # 5. Unseen-user evaluation (unseen prompts)
            accuracies_unseen_user_unseen_prompts = eval_multiple(
                W_few_shot, [V_joint.detach() for i in range(N_unseen)],
                test_features_sparse_unseen
            )
            # ... per-K means and standard deviations are collected and returned (omitted)
```
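Steps 4 and 5 rest on two operations: fitting a weight vector for a new user while V stays frozen, and scoring pairwise accuracy. The plain-Python sketch below illustrates both under stated assumptions: each example is a (chosen, rejected) pair of embeddings, the reward is φᵀVw, and the few-shot objective is the Bradley-Terry loss with an unconstrained w. The helper names (fit_user_weights, pairwise_accuracy, project_pairs) are hypothetical. This is not the repository's learn_multiple_few_shot or eval_multiple, which may parameterize or constrain the weights differently; only the default num_iterations=500 and learning_rate=0.1 are taken from the call above.

```python
import math

# Assumption: examples are (chosen, rejected) embedding pairs and reward = phi^T V w.
# Hypothetical helpers for illustration, not the repository's implementation.

def sigmoid(m):
    # Numerically stable logistic function
    if m >= 0:
        return 1.0 / (1.0 + math.exp(-m))
    e = math.exp(m)
    return e / (1.0 + e)

def project_pairs(V, pairs):
    """For each pair, return z = V^T (phi_chosen - phi_rejected).

    V is stored as hidden_dim rows of length K, so the reward margin
    for a user with weights w is simply w . z.
    """
    K = len(V[0])
    zs = []
    for chosen, rejected in pairs:
        diff = [c - r for c, r in zip(chosen, rejected)]
        zs.append([sum(V[j][k] * diff[j] for j in range(len(diff))) for k in range(K)])
    return zs

def fit_user_weights(V, pairs, num_iterations=500, learning_rate=0.1):
    """Few-shot fit: gradient descent on the mean loss -log sigmoid(w . z), V frozen."""
    zs = project_pairs(V, pairs)
    w = [0.0] * len(V[0])
    for _ in range(num_iterations):
        grad = [0.0] * len(w)
        for z in zs:
            margin = sum(wk * zk for wk, zk in zip(w, z))
            # d/dw of -log sigmoid(margin) is -(1 - sigmoid(margin)) * z
            coeff = -(1.0 - sigmoid(margin)) / len(zs)
            grad = [g + coeff * zk for g, zk in zip(grad, z)]
        w = [wk - learning_rate * g for wk, g in zip(w, grad)]
    return w

def pairwise_accuracy(V, w, pairs):
    """Fraction of pairs whose chosen response gets the strictly higher reward."""
    zs = project_pairs(V, pairs)
    return sum(sum(wk * zk for wk, zk in zip(w, z)) > 0 for z in zs) / len(zs)

# Toy example: hidden_dim = 3, K = 2, and a user who prefers basis direction 0
V = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
pairs = [([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]), ([0.5, 0.0, 1.0], [0.0, 0.5, 1.0])]
w = fit_user_weights(V, pairs)
print(w, pairwise_accuracy(V, w, pairs))
```

Because V is frozen, each unseen user's fit is a small K-dimensional logistic regression, which is why a few examples per user can suffice when K is small.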
The K parameter: basis rank
K controls the rank of the shared reward decomposition:
| K | Interpretation |
|---|---|
| 0 | Reference model: the fixed pre-trained head V_final, with a fixed weight of 1.0 for every user |
| 1 | Bradley-Terry baseline: a single learned reward direction with one scalar weight per seen user (unseen users get a fixed weight of 1.0) |
| ≥ 2 | LoRe: a learned K-dimensional basis; each user gets a K-vector weight |
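The table above can be read as a scoring rule. As a sketch, assuming the reward for user i is linear in an embedding φ(x, y) of a prompt-response pair (d = hidden_dim) and that preferences follow a Bradley-Terry model, as the K = 1 baseline's name suggests; this summarizes the decomposition, not the repository's exact loss:

$$
r_i(x, y) = \phi(x, y)^\top V\, w_i, \qquad V \in \mathbb{R}^{d \times K},\; w_i \in \mathbb{R}^{K}
$$

$$
P_i(y_c \succ y_r \mid x) = \sigma\big(r_i(x, y_c) - r_i(x, y_r)\big)
$$

For K = 0 the script substitutes V = V_final ∈ ℝ^{d×1} and w_i = 1, so every user shares the reward φ(x, y)ᵀ V_final.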
Start with K_list = [0, 1, 2, 3] to quickly compare the reference, BT
baseline, and LoRe at ranks 2 and 3 before running the full sweep.
The alpha parameter: regularization strength
alpha_list controls how strongly V is regularized toward the pre-trained
head direction:
| alpha | Effect |
|---|---|
| 0 | No regularization: V is free to rotate arbitrarily |
| 1e4 | Strong cosine regularization: the columns of V stay close to V_final (used by PRISM) |
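One plausible form of the cosine regularizer is sketched below. It assumes the penalty is alpha times the sum over columns of V of one minus the cosine similarity to V_final; the repository's exact penalty may differ (for example in normalization or averaging), and cosine_penalty is a hypothetical name.

```python
import math

# Assumption: penalty = alpha * sum_k (1 - cos(V[:, k], V_final)); hypothetical helper.

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def cosine_penalty(V_columns, v_final, alpha):
    """Zero when every column points along v_final; grows as columns rotate away."""
    return alpha * sum(1.0 - cosine(col, v_final) for col in V_columns)

v_final = [1.0, 0.0]
print(cosine_penalty([[2.0, 0.0], [3.0, 0.0]], v_final, alpha=1e4))  # columns parallel to v_final
print(cosine_penalty([[1.0, 0.0], [0.0, 1.0]], v_final, alpha=1e4))  # second column orthogonal
print(cosine_penalty([[0.0, 1.0]], v_final, alpha=0))                # alpha = 0 disables the term
```

Because alpha only scales the penalty, alpha = 0 recovers unregularized training, matching the first row of the table, while a cosine (rather than Euclidean) penalty constrains direction but leaves column norms free.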
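The LoRe_regularized class schedules the regularization weight in three phases: it stays at 0 for the first 20% of training steps, rises linearly to alpha over the middle 60%, and holds at alpha for the final 20%: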
```python
def _alpha_at_step(self, step: int) -> float:
    warmup_start = int(0.2 * self.num_iterations)
    warmup_end = int(0.8 * self.num_iterations)
    if step < warmup_start:
        return 0.0
    if step >= warmup_end:
        return float(self.alpha)
    return float(self.alpha) * (step - warmup_start) / (warmup_end - warmup_start)
```
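To make the schedule concrete, the standalone function below copies the method's arithmetic outside the class (alpha_at_step is a hypothetical name) and evaluates it for num_iterations=1000, the value run() passes to solve_regularized, and alpha=1e4, PRISM's setting.

```python
# Standalone copy of the _alpha_at_step arithmetic, for illustration only.
def alpha_at_step(step: int, num_iterations: int, alpha: float) -> float:
    warmup_start = int(0.2 * num_iterations)  # regularization is off before this step
    warmup_end = int(0.8 * num_iterations)    # full strength from this step on
    if step < warmup_start:
        return 0.0
    if step >= warmup_end:
        return float(alpha)
    # Linear ramp across the middle 60% of steps
    return float(alpha) * (step - warmup_start) / (warmup_end - warmup_start)

for step in (0, 199, 200, 500, 799, 800, 999):
    print(step, alpha_at_step(step, num_iterations=1000, alpha=1e4))
```

With these values the ramp spans steps 200 to 800 and reaches half strength (5000) at step 500, so V can first fit the data freely before being pulled back toward V_final.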
All train_basis.py scripts extract the final linear layer of the reward
model to use as V_final:
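```python
import torch

# Walk the reward model and keep the last nn.Linear encountered (the reward head)
last_linear_layer = None
for name, module in rm.named_modules():
    if isinstance(module, torch.nn.Linear):
        last_linear_layer = module
# nn.Linear stores weight as [out_features, in_features]; for a scalar reward head
# (out_features == 1), row 0 is the hidden_dim-sized reward direction
V_final = last_linear_layer.weight[0, :].to(device).to(torch.float32).reshape(-1, 1)
```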
PRISM sets V_final = None initially and populates it after loading the model,
then passes it to run_regularized().
Dataset-specific commands and defaults
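Each dataset directory (RedditTLDR, PRISM, PersonalLLM) has its own train_basis.py; the defaults and launch commands for each are listed below.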
RedditTLDR
Defaults:
```python
K_list = [0, 1, 2, 3, 4, 5, 6]
alpha_list = [0]
```
Embeddings are loaded from pickled worker-result files:
```python
import pickle

with open('tldr_embeddings_train.pkl', 'rb') as f:
    worker_results_train = pickle.load(f)
with open('tldr_embeddings_val.pkl', 'rb') as f:
    worker_results_test = pickle.load(f)
```
Workers are split 50/50 into seen and unseen user sets. Up to 150 training examples are sampled per seen worker and up to 50 per unseen worker.
```bash
cd LoRe/RedditTLDR
python train_basis.py
```
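The seen/unseen partition can be sketched as follows. This is a hypothetical helper, not code from the repository: it assumes the worker results can be viewed as a dict mapping each worker ID to its list of examples, and that the 50/50 split and the per-worker caps are applied by random sampling with a fixed seed.

```python
import random

# Hypothetical sketch; assumes worker_examples maps worker_id -> list of examples.
def split_workers(worker_examples, max_seen=150, max_unseen=50, seed=0):
    rng = random.Random(seed)          # fixed seed for a reproducible split
    workers = sorted(worker_examples)  # deterministic order before shuffling
    rng.shuffle(workers)
    half = len(workers) // 2
    seen_ids, unseen_ids = workers[:half], workers[half:]

    def sample(ids, cap):
        # Keep at most `cap` examples per worker, sampled without replacement
        return {w: rng.sample(worker_examples[w], min(cap, len(worker_examples[w])))
                for w in ids}

    return sample(seen_ids, max_seen), sample(unseen_ids, max_unseen)

workers = {f"w{i}": list(range(200)) for i in range(10)}
seen, unseen = split_workers(workers)
print(len(seen), len(unseen))
```

Splitting by worker rather than by example keeps every unseen user's preferences completely out of the basis-learning step.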
PRISM
Defaults:
```python
K_list = [0, 1, 5, 10, 15, 20, 25, 50]
alpha_list = [1e4]
```
PRISM uses run_regularized() instead of run(). Embeddings are grouped by user_id and split into seen and unseen users for both the train and test sets:
```python
import torch

train_embeddings = torch.load("data/prism/train_embeddings.pkl")
test_embeddings = torch.load("data/prism/test_embeddings.pkl")
train_seen, train_unseen, test_seen, test_unseen = group_embeddings_by_user(
    train_embeddings, test_embeddings, device
)
```
For K ≥ 2, run_regularized() saves both the V and W checkpoints to disk:
```python
filename = f"PRISM_V_lore_K_{K}_alpha_{alpha}.pt"
torch.save(V_joint, filename)
filename = f"PRISM_W_lore_seen_{K}_{alpha}.pt"
torch.save(W_joint.detach().cpu(), filename)
```
```bash
cd LoRe/PRISM
python train_basis.py
```
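A minimal stand-in for the grouping step is sketched below, assuming each embedding record is a dict carrying a user_id, a boolean seen flag, and an embedding. The real group_embeddings_by_user also takes a device argument, which this plain-Python sketch (named group_by_user to keep it distinct) omits.

```python
from collections import defaultdict

# Hypothetical sketch; assumes records are dicts with 'user_id', 'seen' and 'embedding'.
def group_by_user(train_records, test_records):
    """Return (train_seen, train_unseen, test_seen, test_unseen), each user_id -> embeddings."""
    def split(records):
        seen, unseen = defaultdict(list), defaultdict(list)
        for rec in records:
            bucket = seen if rec["seen"] else unseen
            bucket[rec["user_id"]].append(rec["embedding"])
        return dict(seen), dict(unseen)

    train_seen, train_unseen = split(train_records)
    test_seen, test_unseen = split(test_records)
    return train_seen, train_unseen, test_seen, test_unseen

train = [
    {"user_id": "u1", "seen": True, "embedding": [0.1, 0.2]},
    {"user_id": "u1", "seen": True, "embedding": [0.3, 0.4]},
    {"user_id": "u2", "seen": False, "embedding": [0.5, 0.6]},
]
test = [{"user_id": "u1", "seen": True, "embedding": [0.7, 0.8]}]
train_seen, train_unseen, test_seen, test_unseen = group_by_user(train, test)
print(train_seen, train_unseen, test_seen, test_unseen)
```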
PersonalLLM
Defaults:
```python
K_list = [0, 1, 2, 3, 4, 5]
alpha_list = [0]
```
Embeddings are loaded from safetensors and reshaped into per-prompt feature lists:
```python
from safetensors.torch import load_file

embeddings = load_file("train.safetensors")["embeddings"]
features = []
for i in range(num_prompts):
    # Rows i*8 through i*8 + 7 belong to prompt i
    temp = []
    for j in range(8):
        temp.append(embeddings[i * 8 + j])
    features.append(temp)
```
```bash
cd LoRe/PersonalLLM
python train_basis.py
```
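The nested loop cuts the flat embedding rows into consecutive groups of 8, one group per prompt. The plain-Python sketch below (group_rows is a hypothetical name) reproduces the same indexing. Assuming embeddings is a 2-D tensor of shape [num_prompts * 8, hidden_dim], embeddings.reshape(num_prompts, 8, -1) would give the equivalent grouping as a single tensor instead of nested lists.

```python
# Hypothetical helper mirroring the loop: features[i][j] == flat_rows[i * 8 + j]
def group_rows(flat_rows, group_size=8):
    if len(flat_rows) % group_size != 0:
        raise ValueError("number of rows must be a multiple of group_size")
    return [flat_rows[i:i + group_size] for i in range(0, len(flat_rows), group_size)]

rows = [f"emb{k}" for k in range(16)]  # stand-ins for 16 embedding rows (2 prompts)
features = group_rows(rows)
print(len(features), features[1][3])
```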
Output
run() and run_regularized() both return eight arrays (means and standard
deviations for four evaluation settings):
| Array | Evaluation setting |
|---|---|
| train_accuracies_joint | Seen users on training prompts |
| seen_user_unseen_prompts_accuracies_joint | Seen users on test prompts |
| few_shot_train_accuracies_few_shot | Unseen users on their few-shot prompts |
| unseen_user_unseen_prompts_accuracies_few_shot | Unseen users on test prompts |
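How a mean and a standard deviation per setting can be derived from per-user results is sketched below. It assumes each setting yields a list of per-user accuracies for every K and that the standard deviation is taken across users in its population form (what np.std computes by default); summarize is a hypothetical name and the inputs are toy numbers, not results.

```python
import statistics

# Hypothetical helper; assumes one list of per-user accuracies per rank K.
def summarize(per_user_accuracies_by_K):
    """Collapse per-user accuracies into one mean and one std per rank K."""
    means, stds = [], []
    for accuracies in per_user_accuracies_by_K:
        means.append(statistics.mean(accuracies))
        stds.append(statistics.pstdev(accuracies))  # population std across users
    return means, stds

# Two ranks, three users each (toy inputs, only to exercise the function)
means, stds = summarize([[0.5, 0.6, 0.7], [0.8, 0.8, 0.8]])
print(means, stds)
```

Applying this once per evaluation setting gives four mean arrays and four std arrays, i.e. the eight arrays described above.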
PRISM’s train_basis.py also generates an accuracy-vs-rank plot:
```python
import matplotlib.pyplot as plt

plt.plot(K_list, seen_user_unseen_prompts_accuracies_joint, marker='o', label="Seen Users")
plt.plot(K_list, unseen_user_unseen_prompts_accuracies_few_shot, marker='o', label="Unseen Users")
plt.xlabel('rank')
plt.ylabel('Accuracies')
plt.title('Generalization Accuracy vs. Rank')
# Label rank 0 as "ref" (the reference head V_final)
plt.xticks(K_list, labels=["ref" if k == 0 else str(k) for k in K_list])
plt.savefig(f'generalization_accuracy_vs_rank_lore_alpha_{alpha}.png', dpi=300, bbox_inches='tight')
```