PRISM is a multi-turn dialogue dataset collected from real users with diverse backgrounds and stated preferences. Unlike Reddit TLDR, where each annotator evaluates standalone summaries, PRISM captures the full conversation history leading up to each preference judgment. This makes it the right benchmark when conversation context and turn-level dynamics matter. LoRe uses the dataset’s own metadata to define a clean seen/unseen user split, and adds cosine regularization to the shared reward basis to prevent overfitting to PRISM-specific dialogue patterns.

Data sources

prepare.py downloads three JSONL files from the HannahRoseKirk/prism-alignment dataset repository on HuggingFace, retrying failed downloads to cope with rate limits:
files_to_download = [
    ("https://huggingface.co/datasets/HannahRoseKirk/prism-alignment/resolve/main/survey.jsonl",
     "data/prism/survey.jsonl"),
    ("https://huggingface.co/datasets/HannahRoseKirk/prism-alignment/resolve/main/conversations.jsonl",
     "data/prism/conversations.jsonl"),
    ("https://huggingface.co/datasets/HannahRoseKirk/prism-alignment/resolve/main/utterances.jsonl",
     "data/prism/utterances.jsonl"),
]
  • survey.jsonl — user-level metadata: demographics, stated preferences, and system_string
  • conversations.jsonl — dialogue-level structure: turn counts, conversation history, open feedback
  • utterances.jsonl — individual utterance content
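All three files are JSON Lines, with one JSON object per line. Below is a minimal sketch of loading them. It reads a JSONL stream and indexes survey rows by user, dropping users with zero completed conversations (the filter described under Data parsing and schema). The helper names iter_jsonl and load_users are hypothetical, and the user_id key is an assumption about the survey.jsonl schema. This is not the prepare.py code.

```python
import io
import json
from typing import IO, Iterator


def iter_jsonl(fh: IO[str]) -> Iterator[dict]:
    """Yield one parsed JSON object per non-empty line of a JSON Lines stream."""
    for line in fh:
        line = line.strip()
        if line:
            yield json.loads(line)


def load_users(fh: IO[str]) -> dict:
    """Index survey rows by (assumed) user_id, skipping users with no completed conversations."""
    return {
        row["user_id"]: row
        for row in iter_jsonl(fh)
        if row.get("num_completed_conversations", 0) > 0
    }


# In-memory stand-in for data/prism/survey.jsonl
sample = io.StringIO(
    '{"user_id": "u1", "num_completed_conversations": 3}\n'
    "\n"
    '{"user_id": "u2", "num_completed_conversations": 0}\n'
)
users = load_users(sample)
print(sorted(users))  # ['u1']
```

With a real file, pass the handle from open("data/prism/survey.jsonl") instead of the StringIO sample.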

Data parsing and schema

prepare.py uses Pydantic models to structure user and dialogue data. Users with num_completed_conversations == 0 are filtered out. Within each turn, every utterance is appended to that turn's user, chosen (accepted), or rejected list:
if utterance["role"] == "user":
    data_dialog[...].turns[turn].user_utterance.append(utterance["content"])
elif utterance["if_chosen"]:
    data_dialog[...].turns[turn].chosen_utterance.append(utterance["content"])
else:
    data_dialog[...].turns[turn].rejected_utterance.append(utterance["content"])
A dialogue is dropped entirely if none of its turns has at least one user utterance, one chosen response, and one rejected response. The cleaned data is saved as:
  • data/prism/prism_data_user.json
  • data/prism/prism_data_dialog.json
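The parsing and filtering rules above can be sketched with plain dataclasses. This is an illustration, not the prepare.py code. It uses dataclasses in place of the Pydantic models, and the class and helper names (Turn, parse_dialog, keep_dialog) and the per-utterance turn key are hypothetical. The role, if_chosen, and content keys mirror the snippet above. The drop rule follows the reading that a dialogue survives as long as at least one of its turns is complete.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Turn:
    # One dialogue turn: the user's message(s) plus the model responses
    # the user accepted (chosen) and did not accept (rejected).
    user_utterance: list[str] = field(default_factory=list)
    chosen_utterance: list[str] = field(default_factory=list)
    rejected_utterance: list[str] = field(default_factory=list)

    def is_complete(self) -> bool:
        # A turn is usable only if it has all three kinds of utterance.
        return bool(self.user_utterance and self.chosen_utterance and self.rejected_utterance)


def parse_dialog(utterances: list[dict]) -> dict[int, Turn]:
    """Group raw utterance records (hypothetical 'turn' key) into Turn objects."""
    turns = defaultdict(Turn)
    for u in utterances:
        turn = turns[u["turn"]]
        if u["role"] == "user":
            turn.user_utterance.append(u["content"])
        elif u["if_chosen"]:
            turn.chosen_utterance.append(u["content"])
        else:
            turn.rejected_utterance.append(u["content"])
    return dict(turns)


def keep_dialog(turns: dict[int, Turn]) -> bool:
    # Assumed reading of the drop rule: discard the dialogue only if no turn is complete.
    return any(t.is_complete() for t in turns.values())


example = [
    {"turn": 0, "role": "user", "if_chosen": False, "content": "Hi"},
    {"turn": 0, "role": "model", "if_chosen": True, "content": "Hello!"},
    {"turn": 0, "role": "model", "if_chosen": False, "content": "Hey."},
]
turns = parse_dialog(example)
print(keep_dialog(turns))  # True
```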

Seen/unseen user split

Users are shuffled with seed=123 and split 80/20 into seen and unseen groups. Both seen and unseen users must have more than 5 dialogues to qualify. Each user’s dialogs are then split 50/50 into train and test subsets:
seen_user_ids_init  = user_ids[:int(len(user_ids) * 0.8)]
unseen_user_ids_init = user_ids[int(len(user_ids) * 0.8):]

for user_id in seen_user_ids_init:
    to_choose_from = np.array(data_user[user_id]["dialog_ids"])
    if len(to_choose_from) > 5:
        seen_user_ids.append(user_id)
        np.random.shuffle(to_choose_from)
        train_dialog_ids = np.concatenate(
            (train_dialog_ids, to_choose_from[:int(len(to_choose_from) * 0.5)])
        )
        test_dialog_ids = np.concatenate(
            (test_dialog_ids, to_choose_from[int(len(to_choose_from) * 0.5):])
        )
The split metadata is saved to data/prism/prism_split_ids_50.json, which stores train_dialog_ids, test_dialog_ids, seen_user_ids (mapped to integer labels), and unseen_user_ids.
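The split can be reproduced end to end in a few lines. This sketch makes several assumptions. It takes a dict that maps each user ID to that user's dialog IDs. It applies the 50/50 dialog split to unseen users as well, although the excerpt above shows only the seen-user loop. It uses Python's random module, so it will not reproduce the exact IDs in prism_split_ids_50.json. The function name split_users is hypothetical, and the integer relabeling of seen users is omitted.

```python
import random


def split_users(dialogs_by_user: dict[str, list[str]], seed: int = 123,
                seen_frac: float = 0.8, min_dialogs: int = 5) -> dict[str, list[str]]:
    """Sketch: 80/20 seen/unseen user split, then a 50/50 train/test split of each user's dialogs."""
    rng = random.Random(seed)
    user_ids = sorted(dialogs_by_user)
    rng.shuffle(user_ids)
    cut = int(len(user_ids) * seen_frac)

    split = {"seen_user_ids": [], "unseen_user_ids": [],
             "train_dialog_ids": [], "test_dialog_ids": []}
    for group, ids in (("seen_user_ids", user_ids[:cut]), ("unseen_user_ids", user_ids[cut:])):
        for user_id in ids:
            dialogs = list(dialogs_by_user[user_id])
            if len(dialogs) <= min_dialogs:  # keep only users with more than min_dialogs dialogs
                continue
            split[group].append(user_id)
            rng.shuffle(dialogs)
            half = int(len(dialogs) * 0.5)
            split["train_dialog_ids"].extend(dialogs[:half])
            split["test_dialog_ids"].extend(dialogs[half:])
    return split


# Ten users with 6 to 8 dialogs each, plus one user too small to qualify.
data = {f"u{i}": [f"u{i}_d{j}" for j in range(6 + i % 3)] for i in range(10)}
data["u_small"] = ["x1", "x2"]
split = split_users(data)
```

Splitting each user's dialogs (rather than pooling all dialogs) keeps every seen user represented in both train and test.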

Chat format conversion

load_prism_comparisons() in prepare.py converts each qualifying turn into a structured entry that accumulates the full dialogue history up to that turn as the prompt:
entry = {
    'data_source': 'prism',
    'prompt': copy.deepcopy(full_dialog),   # all prior turns as chat messages
    'extra_info': {
        'split': 'train' if is_train else 'test',
        'seen': user_id in split_ids["seen_user_ids"],
        'user_id': user_id,
        'dialog_id': dialog_id,
        'turn_nb': turn['turn_nb'],
        'total_turn_nb': data_dialog[dialog_id]["total_turn_nb"],
        'chosen_utterance': chosen_utterance,
        'rejected_utterance': rejected_utterance,
    }
}
The chosen response is appended to full_dialog after each turn so that subsequent turns include the full accepted conversation history. Processed datasets are saved as:
  • data/prism/train.parquet
  • data/prism/test.parquet
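The history-accumulation rule can be illustrated with a simplified builder. It assumes each turn has exactly one user message, one chosen response, and one rejected response. The parsed data holds lists, and how prepare.py pairs multiple rejected responses is not shown here. It also assumes that a turn's prompt ends with that turn's user message. build_entries is a hypothetical name.

```python
import copy


def build_entries(turns, user_id, dialog_id, is_train, seen):
    """Turn (user_msg, chosen, rejected) triples into per-turn comparison entries."""
    full_dialog = []
    entries = []
    for turn_nb, (user_msg, chosen, rejected) in enumerate(turns):
        full_dialog.append({"role": "user", "content": user_msg})
        entries.append({
            "data_source": "prism",
            "prompt": copy.deepcopy(full_dialog),  # history so far, ending with this user turn
            "extra_info": {
                "split": "train" if is_train else "test",
                "seen": seen,
                "user_id": user_id,
                "dialog_id": dialog_id,
                "turn_nb": turn_nb,
                "total_turn_nb": len(turns),
                "chosen_utterance": chosen,
                "rejected_utterance": rejected,
            },
        })
        # Only the accepted response is carried forward into later turns.
        full_dialog.append({"role": "assistant", "content": chosen})
    return entries


entries = build_entries(
    [("Hi", "Hello!", "Hey."), ("How are you?", "Fine, thanks.", "Meh.")],
    user_id="u1", dialog_id="d1", is_train=True, seen=True,
)
```

The deepcopy matters: without it, every entry's prompt would alias the same growing list and end up holding the whole dialogue.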

Embedding generation

generate-prism-embeddings.py loads the parquet files and runs Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 (using attn_implementation="eager" for PRISM) to embed each conversation:
import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda:0"
tokenizer = AutoTokenizer.from_pretrained("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2")
model = AutoModel.from_pretrained(
    "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2",
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="eager",
    num_labels=1,
)

# For each entry, embed chosen and rejected conversations
# (prompt holds the entry's accumulated chat history):
chosen_conv  = prompt + [{"content": entry["extra_info"]["chosen_utterance"],   "role": "assistant"}]
rejected_conv = prompt + [{"content": entry["extra_info"]["rejected_utterance"], "role": "assistant"}]

# Shown for chosen_conv; rejected_conv is embedded the same way.
tokenized = tokenizer.apply_chat_template(chosen_conv, tokenize=True, return_tensors="pt").to(device)
with torch.no_grad():
    output = model(tokenized)
    embedding = output.last_hidden_state[0, -1].cpu()  # last-token hidden state, [hidden_dim]
Both chosen_conv_embedding and rejected_conv_embedding are stored back into each entry’s extra_info. The final output is saved via torch.save:
  • data/prism/train_embeddings.pkl
  • data/prism/test_embeddings.pkl

User grouping for training

train_basis.py provides group_embeddings_by_user(), which takes the loaded train and test embedding datasets and partitions them by user into four groups, based on the seen and split fields in extra_info:
from collections import defaultdict

import torch

def group_embeddings_by_user(train_embeddings, test_embeddings, device):
    def process_dataset(dataset, seen_value, split_name):
        grouped = defaultdict(lambda: {"embeddings": []})
        for example in dataset:
            extra_info = example.get("extra_info", {})
            if extra_info.get("seen") == seen_value and extra_info.get("split") == split_name:
                user_id = extra_info.get("user_id")
                chosen   = torch.tensor(extra_info["chosen_conv_embedding"],   dtype=torch.float32, device=device)
                rejected = torch.tensor(extra_info["rejected_conv_embedding"], dtype=torch.float32, device=device)
                grouped[user_id]["embeddings"].append(chosen - rejected)
        sorted_grouped = []
        for user_id in sorted(grouped.keys()):
            sorted_grouped.append(torch.stack(grouped[user_id]["embeddings"]))
        return sorted_grouped

    train_seen   = process_dataset(train_embeddings, seen_value=True,  split_name="train")
    train_unseen = process_dataset(train_embeddings, seen_value=False, split_name="train")
    test_seen    = process_dataset(test_embeddings,  seen_value=True,  split_name="test")
    test_unseen  = process_dataset(test_embeddings,  seen_value=False, split_name="test")

    return train_seen, train_unseen, test_seen, test_unseen
The preference feature for each turn is the difference between the chosen and rejected embeddings.
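Using the difference as the feature follows from the Bradley-Terry model when the reward is linear in the embedding. Assume a per-user reward r_u(x, y) = w_uᵀ Vᵀ φ(x, y), with a shared basis V of shape [d, K] and user weights w_u of length K. Then the preference logit r_u(chosen) - r_u(rejected) equals w_uᵀ Vᵀ (φ(chosen) - φ(rejected)). The NumPy sketch below checks that identity and computes the resulting loss. NumPy is used for brevity, whereas the repository works with PyTorch tensors. bt_loss is a hypothetical name, and any constraints LoRe places on w_u are not modeled.

```python
import numpy as np


def bt_loss(diffs: np.ndarray, V: np.ndarray, w: np.ndarray) -> float:
    """Mean Bradley-Terry negative log-likelihood for one user (assumed linear reward).

    diffs: [n, d] rows of phi(chosen) - phi(rejected)
    V:     [d, K] shared reward basis
    w:     [K]    this user's mixing weights
    """
    logits = diffs @ V @ w  # r(chosen) - r(rejected) for each pair
    # -log sigmoid(z) computed stably as log(1 + exp(-z))
    return float(np.mean(np.logaddexp(0.0, -logits)))


rng = np.random.default_rng(0)
d, K, n = 16, 3, 8
V = rng.normal(size=(d, K))
w = np.full(K, 1.0 / K)
chosen = rng.normal(size=(n, d))
rejected = rng.normal(size=(n, d))

# The reward gap depends only on the embedding difference.
assert np.allclose(chosen @ V @ w - rejected @ V @ w, (chosen - rejected) @ V @ w)
loss = bt_loss(chosen - rejected, V, w)
```

This is why group_embeddings_by_user can store only chosen - rejected per turn: under a linear reward, nothing else about the pair enters the loss.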

Training configuration

PRISM uses run_regularized() instead of run(), applying cosine similarity regularization to keep the learned basis anchored near the base Skywork reward head:
K_list    = [0, 1, 5, 10, 15, 20, 25, 50]
alpha_list = [1e4]   # cosine regularization strength
V_final   = None     # loaded from Skywork last linear layer
  • K=0 — reference model (base Skywork reward head, no learned basis)
  • K=1 — single Bradley-Terry reward model
  • K=5..50 — LoRe with increasing rank; PRISM benefits from higher ranks than TLDR due to greater user diversity
The regularization penalty is computed inside LoRe_regularized as:
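# Requires: import torch; import torch.nn.functional as F
# Column-wise cosine similarity between the learned basis V and the reference head V_sft.
V_norm     = F.normalize(self.V, dim=0)        # unit-normalize each basis column
V_sft_norm = F.normalize(self.V_sft, dim=0)
cos_sim    = (V_norm * V_sft_norm).sum(dim=0)  # one similarity per basis column
reg        = torch.mean(1 - cos_sim)           # 0 when every column aligns with V_sft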
alpha=1e4 makes the cosine regularization strong. Its weight follows a warmup schedule that runs from 20% to 80% of the training steps.
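The sketch below makes the regularizer and its schedule concrete. It is a NumPy mirror of the penalty above, plus one possible warmup. The schedule is an assumption, because the text states only that warmup spans 20% to 80% of training. Here reg_weight ramps the weight linearly from 0 to alpha over that window and holds it at alpha afterwards. Both function names are hypothetical. The assumed total objective would be the Bradley-Terry loss plus reg_weight(step, total_steps) * cosine_reg(V, V_sft). Unlike F.normalize, this sketch does not clamp tiny column norms.

```python
import numpy as np


def cosine_reg(V: np.ndarray, V_sft: np.ndarray) -> float:
    """Mean (1 - cosine similarity) between the columns of V and the reference head V_sft."""
    V_norm = V / np.linalg.norm(V, axis=0, keepdims=True)
    V_sft_norm = V_sft / np.linalg.norm(V_sft, axis=0, keepdims=True)
    cos_sim = (V_norm * V_sft_norm).sum(axis=0)  # one value per basis column
    return float(np.mean(1.0 - cos_sim))


def reg_weight(step: int, total_steps: int, alpha: float = 1e4,
               start: float = 0.2, end: float = 0.8) -> float:
    """Hypothetical linear warmup: 0 up to 20% of training, alpha from 80% onward."""
    frac = step / total_steps
    if frac <= start:
        return 0.0
    if frac >= end:
        return alpha
    return alpha * (frac - start) / (end - start)


V_sft = np.array([[1.0], [0.0]])                 # reference head as a [d, 1] column
V_aligned = np.array([[2.0, 3.0], [0.0, 0.0]])  # both columns point along V_sft
V_orthogonal = np.array([[0.0], [1.0]])         # perpendicular to V_sft
```

A [d, 1] V_sft broadcasts against a [d, K] V, so every basis column is pulled toward the same base Skywork direction.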

RewardBench 2 evaluation

eval_rb2.py evaluates the learned V matrix on the RewardBench 2 benchmark to check that the basis generalizes beyond PRISM and has not overfit to dialogue-specific patterns:
python eval_rb2.py --rm_head "path/to/saved/V_weights.pt"
PRISM uses attn_implementation="eager" rather than flash_attention_2. This is intentional — use the same attention implementation in both generate-prism-embeddings.py and train_basis.py to ensure embedding consistency.

Run commands

1. Install dependencies

pip install -r requirements.txt

2. Prepare the dataset (one-time)

cd LoRe/PRISM
python prepare.py
Downloads the three PRISM JSONL files, parses and validates dialogues, applies the 80/20 user split, and saves train.parquet and test.parquet.

3. Generate embeddings (one-time)

python generate-prism-embeddings.py
Loads the parquet files and encodes every conversation with Skywork-Reward-Llama-3.1-8B-v0.2, saving data/prism/train_embeddings.pkl and data/prism/test_embeddings.pkl.

4. Train the reward model basis

python train_basis.py
Groups embeddings by user, then runs run_regularized() over K_list = [0, 1, 5, 10, 15, 20, 25, 50] with alpha=1e4 cosine regularization.

5. Evaluate on RewardBench 2

python eval_rb2.py --rm_head "path/to/saved/V_weights.pt"
Evaluates the learned reward basis on RewardBench 2 to verify it has not overfit to PRISM.
The PRISM dataset is hosted on HuggingFace and may be rate-limited. prepare.py includes retry logic with a 30-second delay between attempts (up to 10 retries per file). If all retries fail, the script exits rather than proceeding with a corrupted or empty file.
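A retry loop with these parameters might look like the sketch below. It is not the prepare.py implementation. download_with_retry is a hypothetical name, and it uses urllib.request.urlretrieve as the default fetcher. It also takes injectable fetch and sleep functions, so the demo runs offline without waiting.

```python
import sys
import time
import urllib.request


def download_with_retry(url, dest, max_retries=10, delay=30, fetch=None, sleep=time.sleep):
    """Hypothetical retry loop: up to max_retries attempts, delay seconds apart, exit on failure."""
    fetch = fetch or urllib.request.urlretrieve
    for attempt in range(1, max_retries + 1):
        try:
            fetch(url, dest)
            return dest
        except Exception as exc:  # e.g. HTTP 429 rate limiting or a dropped connection
            print(f"Attempt {attempt}/{max_retries} failed for {url}: {exc}", file=sys.stderr)
            if attempt < max_retries:
                sleep(delay)
    # Stop the script instead of continuing with a missing or partial file.
    sys.exit(f"Giving up on {url} after {max_retries} attempts")


# Offline demo: a fake fetcher that is rate-limited twice, then succeeds.
attempts = []
waits = []

def flaky_fetch(url, dest):
    attempts.append(url)
    if len(attempts) < 3:
        raise OSError("HTTP 429: Too Many Requests")

result = download_with_retry("https://example.invalid/survey.jsonl", "survey.jsonl",
                             fetch=flaky_fetch, sleep=waits.append)
```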
