After training the shared basis V, a key question is how quickly a brand-new user can be personalized with only a handful of preference examples. vary_fewshot.py answers this by sweeping over a range of shot counts, re-running few-shot weight learning for each count across multiple random trials, and reporting mean and standard deviation of accuracy. The result is a learning curve that reveals both the adaptation speed (low shots) and the ceiling accuracy (high shots) for each basis rank K.
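For reference, the two summary statistics can be written out. Assume T trials per shot count s (T = 20 for RedditTLDR), and let a_{t,s} be the accuracy in trial t averaged over the N_unseen unseen users. Using the population form of the standard deviation (dividing by T) is an assumption here. With those assumptions, the reported numbers are:

$$
a_{t,s} = \frac{1}{N_{\text{unseen}}}\sum_{u=1}^{N_{\text{unseen}}} \mathrm{acc}_{u}^{(t,s)},
\qquad
\mu_s = \frac{1}{T}\sum_{t=1}^{T} a_{t,s},
\qquad
\sigma_s = \sqrt{\frac{1}{T}\sum_{t=1}^{T}\bigl(a_{t,s}-\mu_s\bigr)^2}
$$

Plotting \(\mu_s \pm \sigma_s\) against s gives the learning curve for one basis rank K.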

Which datasets support vary_fewshot.py

RedditTLDR

RedditTLDR/vary_fewshot.py: a standalone script with its own embedding-loading and user-split logic. It sweeps shots from 5 to 50, with 20 trials per shot count.

PRISM / PersonalLLM

Neither PRISM nor PersonalLLM has a vary_fewshot.py. Their few-shot evaluation runs inside train_basis.py using fixed shot counts.

run_few_shot_vary_shots

The core function lives in utils.py:
```python
import numpy as np
import torch

# solve_regularized, learn_multiple_few_shot and eval_multiple are defined elsewhere
# in utils.py; sample_shots is described below.


def run_few_shot_vary_shots(
    trials, alpha_list, K_list, num_shots,
    train_features, train_features_unseen, test_features_sparse_unseen,
    V_final, N, N_unseen, device
):
    for alpha in alpha_list:
        for K in K_list:
            # Train shared basis on seen users (K == 0 reuses the reward model's V_final)
            if K == 0:
                V_joint = V_final
                W_joint = [torch.tensor([1.0]).to(device) for _ in range(N)]
            else:
                W_joint, V_joint = solve_regularized(
                    V_final, alpha, train_features, K,
                    num_iterations=500, learning_rate=0.5
                )

            # Per-shot-count summaries over trials. These lists are reset for every
            # (alpha, K) pair, so the returned values describe the last pair only.
            few_shot_train_accuracies_few_shot_means = []
            few_shot_train_accuracies_few_shot_stds = []
            unseen_user_unseen_prompts_accuracies_few_shot_means = []
            unseen_user_unseen_prompts_accuracies_few_shot_stds = []

            for shots in num_shots:
                few_shot_train_accuracies_few_shot = []
                unseen_user_unseen_prompts_accuracies_few_shot = []

                for _ in range(trials):
                    # Sample up to `shots` preference pairs per unseen user
                    train_features_unseen_shots = sample_shots(train_features_unseen, shots)

                    # Fit per-user weights on the sampled shots (fixed to 1.0 when K <= 1)
                    if K <= 1:
                        W_few_shot = [torch.tensor([1.0]).to(device) for _ in range(N_unseen)]
                    else:
                        W_few_shot = learn_multiple_few_shot(
                            train_features_unseen_shots, V_joint.detach(),
                            num_iterations=500, learning_rate=0.1
                        )

                    # Evaluate on the sampled shot pairs themselves
                    accuracies_few_shot_train = eval_multiple(
                        W_few_shot,
                        [V_joint.detach() for _ in range(N_unseen)],
                        train_features_unseen_shots
                    )
                    few_shot_train_accuracies_few_shot.append(np.mean(accuracies_few_shot_train))

                    # Evaluate on fully held-out prompts of the unseen users
                    accuracies_unseen = eval_multiple(
                        W_few_shot,
                        [V_joint.detach() for _ in range(N_unseen)],
                        test_features_sparse_unseen
                    )
                    unseen_user_unseen_prompts_accuracies_few_shot.append(np.mean(accuracies_unseen))

                # Aggregate over trials: mean and standard deviation per shot count
                few_shot_train_accuracies_few_shot_means.append(
                    np.mean(few_shot_train_accuracies_few_shot)
                )
                few_shot_train_accuracies_few_shot_stds.append(
                    np.std(few_shot_train_accuracies_few_shot)
                )
                unseen_user_unseen_prompts_accuracies_few_shot_means.append(
                    np.mean(unseen_user_unseen_prompts_accuracies_few_shot)
                )
                unseen_user_unseen_prompts_accuracies_few_shot_stds.append(
                    np.std(unseen_user_unseen_prompts_accuracies_few_shot)
                )

    return (
        few_shot_train_accuracies_few_shot_means,
        few_shot_train_accuracies_few_shot_stds,
        unseen_user_unseen_prompts_accuracies_few_shot_means,
        unseen_user_unseen_prompts_accuracies_few_shot_stds,
    )
```
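The bodies of learn_multiple_few_shot and eval_multiple are not shown on this page. As a rough illustration of what "a new user locating themselves in the basis" can mean, the sketch below fits one user's K weights against a frozen basis and scores pairwise accuracy. Everything about it is an assumption rather than the utils.py implementation: each row of a user's tensor is taken to be the embedding of the preferred response minus that of the rejected one, V has shape [hidden_dim, K], the reward margin is x @ V @ w with a Bradley-Terry (logistic) loss, and w is kept on the probability simplex through a softmax. fit_user_weights and pairwise_accuracy are hypothetical names.

```python
import torch
import torch.nn.functional as F


def fit_user_weights(x, V, num_iterations=500, learning_rate=0.1):
    """Hypothetical few-shot fit: learn simplex weights w (length K) for one new user.

    x: [shots, hidden_dim] preferred-minus-rejected embedding differences (assumed layout)
    V: [hidden_dim, K] shared basis, kept frozen
    """
    V = V.detach()  # only the user's weights are trained; the basis stays fixed
    logits = torch.zeros(V.shape[1], requires_grad=True)  # uniform weights at start
    optimizer = torch.optim.Adam([logits], lr=learning_rate)
    for _ in range(num_iterations):
        w = torch.softmax(logits, dim=0)  # non-negative weights that sum to 1
        margins = x @ V @ w  # reward(preferred) - reward(rejected) for each pair
        loss = F.softplus(-margins).mean()  # Bradley-Terry NLL: -log sigmoid(margin)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return torch.softmax(logits, dim=0).detach()


def pairwise_accuracy(w, V, x):
    """Fraction of pairs in which the preferred response receives the higher reward."""
    with torch.no_grad():
        return (x @ V.detach() @ w > 0).float().mean().item()
```

Only K numbers are learned per user, which is why a handful of pairs can be enough; the K <= 1 branch above skips this step entirely because a single weight of 1.0 leaves nothing to fit.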

sample_shots

Random shot sampling is done without replacement by shuffling and truncating each user’s tensor:
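```python
import torch


def sample_shots(train_features_unseen, shots):
    # Shuffle each user's rows and keep the first `shots` of them (no replacement)
    sampled_features = [
        tensor[torch.randperm(tensor.size(0))[:shots]]
        for tensor in train_features_unseen
    ]
    return sampled_features
```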
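Each element of train_features_unseen is an [M, hidden_dim] tensor for one user, where M (the number of preference pairs) can differ between users. The function returns a list of [min(M, shots), hidden_dim] tensors: a user with fewer than shots pairs keeps all of their rows.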
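The shape contract can be checked with random tensors. The three users and hidden_dim = 8 below are made up for illustration; the third user has only 4 pairs, which shows what slicing does when shots exceeds M.

```python
import torch


def sample_shots(train_features_unseen, shots):
    # Shuffle each user's rows and keep the first `shots` (sampling without replacement)
    return [t[torch.randperm(t.size(0))[:shots]] for t in train_features_unseen]


# Three illustrative users with 30, 12 and 4 preference pairs, hidden_dim = 8
torch.manual_seed(0)
users = [torch.randn(m, 8) for m in (30, 12, 4)]
sampled = sample_shots(users, shots=10)
print([tuple(t.shape) for t in sampled])  # [(10, 8), (10, 8), (4, 8)]
```

Because each call draws a fresh permutation, repeating it across trials gives a different subset of pairs each time, which is what produces the trial-to-trial spread reported as the standard deviation.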

RedditTLDR vary_fewshot configuration

```python
K_list     = [5]
alpha_list = [0]
trials     = 20
num_shots  = [5 * (i + 1) for i in range(10)]  # [5, 10, 15, ..., 50]

few_shot_train_means, few_shot_train_stds, \
unseen_unseen_means, unseen_unseen_stds = run_few_shot_vary_shots(
    trials, alpha_list, K_list, num_shots,
    train_features, train_features_unseen, test_features_sparse_unseen,
    V_final, N, N_unseen, device
)
```
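The RedditTLDR vary_fewshot.py fixes K=5 (the highest rank from train_basis.py) and sweeps shots from 5 to 50 in steps of 5, with 20 trials per shot count. To compare multiple ranks, extend K_list. Note that run_few_shot_vary_shots, as shown above, resets its result lists for every (alpha, K) pair and returns only the curves for the last pair, so the results for the other ranks have to be collected separately (for example, by calling the function once per rank).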
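One way to keep a curve for every rank is to call the sweep once per K and key the results by rank. The wrapper below is a hypothetical sketch (sweep_ranks is not part of the repository); it takes the sweep function as an argument so it can be exercised without the repo, and it assumes a single alpha value, as in the RedditTLDR configuration.

```python
def sweep_ranks(run_fn, trials, alpha, K_values, num_shots, *data_args):
    """Hypothetical wrapper: run the shot sweep once per rank and keep every curve.

    run_fn is expected to have the signature of run_few_shot_vary_shots; data_args are
    the remaining positional arguments (train_features, ..., device) passed through as-is.
    """
    curves = {}
    for K in K_values:
        # One rank per call, so the returned lists belong to exactly this K
        train_means, train_stds, heldout_means, heldout_stds = run_fn(
            trials, [alpha], [K], num_shots, *data_args
        )
        curves[K] = {
            "train_means": train_means,
            "train_stds": train_stds,
            "heldout_means": heldout_means,
            "heldout_stds": heldout_stds,
        }
    return curves
```

Inside vary_fewshot.py it could be called as sweep_ranks(run_few_shot_vary_shots, trials, 0, [1, 2, 5], num_shots, train_features, train_features_unseen, test_features_sparse_unseen, V_final, N, N_unseen, device). Since run_few_shot_vary_shots already retrains the basis for each K inside its loop, one call per rank should cost about the same as a single call with a longer K_list.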

Running the script

```bash
cd LoRe/RedditTLDR
python vary_fewshot.py
```
vary_fewshot.py loads the reward model at startup to extract V_final. Make sure prepare.py has already been run, so that tldr_embeddings_train.pkl and tldr_embeddings_val.pkl exist in the RedditTLDR/ directory, before executing.
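A minimal preflight check, assuming it is run from LoRe/RedditTLDR. The helper missing_embeddings is hypothetical and only tests whether the two pickle files named above exist; it does not check that they are valid or that the reward model can be loaded.

```python
from pathlib import Path

# Embedding files that prepare.py is expected to have written
REQUIRED_EMBEDDINGS = ("tldr_embeddings_train.pkl", "tldr_embeddings_val.pkl")


def missing_embeddings(reddit_dir="."):
    """Return the names of required embedding files that are absent from reddit_dir."""
    root = Path(reddit_dir)
    return [name for name in REQUIRED_EMBEDDINGS if not (root / name).is_file()]
```

An empty list from missing_embeddings() means both files are in place; otherwise, run prepare.py before vary_fewshot.py.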

Interpreting results

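| Metric | Meaning |
| --- | --- |
| few_shot_train_accuracies_few_shot_means | Accuracy on the sampled shot pairs themselves, i.e. the training accuracy of the few-shot fit; a large gap above the held-out metric indicates overfitting to the few-shot data |
| unseen_user_unseen_prompts_accuracies_few_shot_means | Accuracy of the unseen users' fitted weights on their held-out test prompts; this is the true personalization metric |

Each metric also has a matching _stds list holding the standard deviation across trials.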
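  • Higher K → higher ceiling: a richer basis captures more preference diversity, so accuracy on unseen prompts improves as K grows, provided enough shots are available to fit the extra weights.
  • Fast adaptation with few shots: even with 5–10 shots, a well-trained K≥2 basis outperforms K=1 (BT), because a new user only needs to locate themselves within the pre-learned preference space by fitting K weights, whereas K=1 applies a single shared reward to every user.
  • K=0 and K=1 are flat: for K ≤ 1 the code fixes W_few_shot to a constant [1.0] tensor, so the shot count has no effect on held-out accuracy (accuracy on the shot pairs still fluctuates with which pairs are sampled). K=0 additionally uses the reward model's V_final directly instead of a learned basis.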
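To read the ceiling and the adaptation speed off one curve, a small helper can report the best held-out accuracy, the smallest shot count that comes within a tolerance of it, and the gap between shot-set and held-out accuracy. summarize_curve and its tol parameter are hypothetical; tol = 0.01 is an arbitrary default, not a value taken from the repository.

```python
import numpy as np


def summarize_curve(num_shots, train_means, heldout_means, tol=0.01):
    """Hypothetical summary of one learning curve from run_few_shot_vary_shots."""
    shots = np.asarray(num_shots)
    train = np.asarray(train_means, dtype=float)
    heldout = np.asarray(heldout_means, dtype=float)
    ceiling = heldout.max()  # best held-out accuracy over the sweep
    # First shot count whose held-out accuracy is within `tol` of the ceiling
    # (argmax on a boolean array returns the index of the first True)
    shots_to_ceiling = int(shots[np.argmax(heldout >= ceiling - tol)])
    gap = train - heldout  # a large gap at low shot counts signals overfitting the shots
    return {"ceiling": float(ceiling), "shots_to_ceiling": shots_to_ceiling, "gap": gap}
```

Pass few_shot_train_accuracies_few_shot_means as train_means and unseen_user_unseen_prompts_accuracies_few_shot_means as heldout_means. For the flat K ≤ 1 curves, shots_to_ceiling simply comes out as the first shot count (5 for RedditTLDR), which reflects the constant curve rather than fast adaptation.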
