Documentation Index
Fetch the complete documentation index at: https://mintlify.com/intuit-ai-research/REMem/llms.txt
Use this file to discover all available pages before exploring further.
Overview
RAG strategies control how REMem indexes documents and performs retrieval for question answering. Each extraction method can have a corresponding strategy that defines its indexing and QA logic.
Strategy Factory Pattern
REMem uses a factory pattern (rag_strategies/factory.py:9-59) to manage strategies:
from typing import Dict, Type
from .base_strategy import RAGStrategy
class RAGStrategyFactory:
    """Factory class for creating RAG strategies based on extraction method.

    ``_strategies`` is a class-level mapping from extraction-method name to
    the strategy class instantiated for it, so anything added through
    ``register_strategy`` is visible to every later ``create_strategy`` call.
    """

    _strategies: Dict[str, Type[RAGStrategy]] = {
        "openie": PassageTripleStrategy,
        "episodic_gist": EpisodicGistStrategy,
        "temporal": TemporalStrategy,
    }

    @classmethod
    def create_strategy(cls, extract_method: str, remem_instance) -> RAGStrategy:
        """Create a RAG strategy based on the extraction method.

        Args:
            extract_method: Name the strategy was registered under.
            remem_instance: Object handed to the strategy constructor.

        Returns:
            A new instance of the strategy class registered for
            ``extract_method``.

        Raises:
            ValueError: If ``extract_method`` has no registered strategy; the
                message lists the supported names.
        """
        if extract_method not in cls._strategies:
            raise ValueError(
                f"Unsupported extraction method: {extract_method}. "
                f"Supported methods: {list(cls._strategies.keys())}"
            )
        strategy_class = cls._strategies[extract_method]
        return strategy_class(remem_instance)

    @classmethod
    def register_strategy(cls, extract_method: str, strategy_class: Type[RAGStrategy]):
        """Register a new strategy for an extraction method.

        An existing entry under the same name is replaced.
        """
        cls._strategies[extract_method] = strategy_class

    @classmethod
    def get_supported_methods(cls) -> list[str]:
        """Get list of supported extraction methods."""
        return list(cls._strategies.keys())
Base Strategy Interface
All strategies inherit from RAGStrategy (rag_strategies/base_strategy.py:7-146):
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from remem.utils.misc_utils import QuerySolution
class RAGStrategy(ABC):
    """Abstract base class for RAG strategies.

    Subclasses must implement ``index`` and ``rag_for_qa``. The wrapped
    instance is stored on ``self.remem`` and is what the non-abstract helpers
    delegate to.
    """

    def __init__(self, remem_instance):
        self.remem = remem_instance

    @abstractmethod
    def index(self, docs: List[str]) -> None:
        """Index documents using the specific strategy."""
        pass

    @abstractmethod
    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Perform RAG-based QA using the specific strategy.

        Returns:
            A 5-tuple; going by the annotation: query solutions, response
            strings, per-query metadata dicts, and two metric dicts.
        """
        pass

    def retrieve_each_query(self, query: str, return_chunk: Optional[str] = None):
        """Retrieve documents for a single query. Can be overridden.

        The default forwards both arguments to
        ``self.remem.retrieve_each_query`` and returns its result unchanged.
        """
        return self.remem.retrieve_each_query(query, return_chunk)

    def get_graph_info(self) -> Dict:
        """Get statistics about the graph structure."""
        # Implementation in base class (body elided in this excerpt)
        pass
Built-in Strategies
DefaultRAGStrategy (OpenIE)
From rag_strategies/default_strategy.py:8-60:
class DefaultRAGStrategy(RAGStrategy):
    """Default RAG strategy for standard OpenIE-based extraction."""

    def index(self, docs: List[str]) -> None:
        """Index documents using standard OpenIE approach.

        Delegates to ``self.remem.index_original``.
        """
        self.remem.index_original(docs)

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 5,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        to_save: bool = True,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Perform QA using standard RAG approach.

        Steps: retrieve (skipped when ``queries`` already holds
        ``QuerySolution`` objects), attach metadata, evaluate retrieval, run
        QA, evaluate QA, then optionally save.

        Args:
            queries: Query strings, or ``QuerySolution`` objects to reuse
                as-is. The first element decides which path is taken.
            num_to_retrieve: Accepted for interface compatibility; not
                referenced in this excerpt.
            gold_answers: Gold answers passed to the evaluators.
            gold_docs: Gold documents passed to the retrieval evaluation.
            metrics: Metric names passed to ``get_evaluators``.
            question_metadata: Optional per-query metadata, matched to the
                query solutions by position.
            to_save: When true, results are written via
                ``save_rag_results``.

        Returns:
            ``(query_solutions, all_response_message, all_metadata,
            overall_retrieval_metrics, overall_qa_metrics)``.
        """
        # Retrieve documents
        if not isinstance(queries[0], QuerySolution):
            query_solutions = self.remem.retrieve(queries=queries)
        else:
            query_solutions = queries

        # Set metadata
        if question_metadata is not None:
            for idx, q in enumerate(query_solutions):
                q.question_metadata = question_metadata[idx]

        # Evaluate retrieval
        qa_evaluators, retrieval_evaluators = self.remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )
        overall_retrieval_metrics = self.remem.evaluate_retrieval(
            gold_docs, query_solutions, retrieval_evaluators
        )

        # Perform QA
        query_solutions, all_response_message, all_metadata = self.remem.qa(query_solutions)

        # Evaluate QA
        overall_qa_metrics = self.remem.evaluate_qa(
            gold_answers, qa_evaluators, query_solutions, question_metadata
        )

        # Save results
        if to_save:
            self.remem.save_rag_results(
                gold_answers, gold_docs, query_solutions,
                overall_qa_metrics, overall_retrieval_metrics
            )

        return (
            query_solutions,
            all_response_message,
            all_metadata,
            overall_retrieval_metrics,
            overall_qa_metrics,
        )
EpisodicGistStrategy
Advanced strategy with gist-based retrieval (rag_strategies/episodic_gist_strategy.py:21-1013):
class EpisodicGistStrategy(RAGStrategy):
    """Strategy for episodic gist-based extraction and retrieval.

    Abridged excerpt: parts of ``index`` and all of ``rag_for_qa`` are elided
    and marked with ``...``.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Both flags are copied from the instance's global config.
        self.concatenate_gists_per_chunk = remem_instance.global_config.concatenate_gists_per_chunk
        self.split_verbatim_per_chunk = remem_instance.global_config.split_verbatim_per_chunk

    def index(self, docs: List) -> None:
        """Index with episodic gist extraction.

        ``new_openie_rows`` and ``episode_results_dict`` are produced by code
        that is not shown in this excerpt.
        """
        # Add chunk embeddings
        self.remem.add_chunk_and_embeddings(docs)
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row

        # Load or perform extraction
        all_openie_info, chunk_keys_to_process = self.remem.load_existing_openie(
            chunk_dict.keys()
        )
        if len(chunk_keys_to_process) > 0:
            ie_results = self.remem.openie.batch_openie(new_openie_rows)
            self.merge_gist_extraction_results(all_openie_info, chunk_keys_to_process, ie_results)

        # Build episodic embedding stores
        element_to_encode = defaultdict(list)
        for chunk in episode_results_dict.values():
            # Process verbatim, gists, facts, entities
            # ...
            ...

        # Construct graph with gist->fact and verbatim->gist edges
        self._augment_episodic_graph()
        self.remem.save_igraph()

    def rag_for_qa(self, queries, **kwargs):
        """Perform QA with parallel processing and per-sample evaluation."""
        # Support for:
        # - Parallel query processing
        # - Per-sample saving/loading
        # - Gist-based retrieval
        # - Agent-based reasoning
        # ...
Creating a Custom Strategy
1. Define Your Strategy Class
# src/remem/rag_strategies/my_custom_strategy.py
from typing import Dict, List, Optional, Tuple, Union
from .base_strategy import RAGStrategy
from remem.utils.misc_utils import QuerySolution
class MyCustomStrategy(RAGStrategy):
    """Custom RAG strategy with specialized retrieval logic.

    Template: ``_build_custom_index``, ``_augment_custom_graph`` and
    ``_custom_rerank`` are placeholders to fill in.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Initialize strategy-specific parameters
        self.custom_param = remem_instance.global_config.custom_param

    def index(self, docs: List[str]) -> None:
        """Custom indexing logic."""
        # 1. Add chunks and embeddings
        self.remem.add_chunk_and_embeddings(docs)

        # 2. Run extraction
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row
        ie_results = self.remem.openie.batch_openie(chunk_dict)

        # 3. Build custom data structures
        self._build_custom_index(ie_results)

        # 4. Construct graph
        self._augment_custom_graph()
        self.remem.save_igraph()

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1"),
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Custom QA logic.

        Returns:
            ``(query_solutions, responses, metadata, {}, overall_qa_metrics)``.
            The retrieval-metrics slot is an empty dict because this template
            does not evaluate retrieval.
        """
        query_solutions = []
        for query in queries:
            # 1. Custom retrieval
            docs, scores = self._custom_retrieve(query, num_to_retrieve)

            # 2. Create QuerySolution
            query_solution = QuerySolution(
                question=query,
                docs=docs,
                doc_scores=scores,
            )
            query_solutions.append(query_solution)

        # 3. Generate answers
        query_solutions, responses, metadata = self.remem.qa(query_solutions)

        # 4. Evaluate
        qa_evaluators, retrieval_evaluators = self.remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )
        overall_qa_metrics = self.remem.evaluate_qa(
            gold_answers, qa_evaluators, query_solutions, None
        )

        return query_solutions, responses, metadata, {}, overall_qa_metrics

    def _custom_retrieve(self, query: str, k: int):
        """Implement custom retrieval logic.

        Returns:
            ``(docs, scores)``: two parallel lists of at most ``k`` items,
            which is the shape ``rag_for_qa`` unpacks.
        """
        # Example: Combine semantic search with custom ranking
        # 1. Get initial candidates (over-fetch so reranking has room to reorder)
        candidates = self.remem.chunk_embedding_store.search(query, k=k * 2)

        # 2. Apply custom reranking
        reranked = self._custom_rerank(query, candidates)

        # 3. Return top-k as parallel lists
        # NOTE(review): assumes each candidate is a (doc, score) pair; adapt
        # this to whatever your store's search() actually returns.
        top_k = reranked[:k]
        docs = [doc for doc, _ in top_k]
        scores = [score for _, score in top_k]
        return docs, scores

    def _custom_rerank(self, query: str, candidates):
        """Rerank candidates for a query; placeholder that keeps the given order."""
        return candidates

    def _build_custom_index(self, ie_results):
        """Build strategy-specific indices."""
        # Example: Build entity co-occurrence matrix
        pass

    def _augment_custom_graph(self):
        """Construct custom graph structure."""
        # Example: Add weighted edges based on custom similarity
        pass
2. Register Your Strategy
# In your application code or remem/__init__.py
from remem.rag_strategies.factory import RAGStrategyFactory
from remem.rag_strategies.my_custom_strategy import MyCustomStrategy
# Register the strategy
RAGStrategyFactory.register_strategy("my_custom", MyCustomStrategy)
3. Use Your Strategy
from remem.remem import ReMem
from remem.utils.config_utils import BaseConfig
# Build a config whose extract_method selects the registered custom strategy.
# NOTE(review): custom_param must be a field the config class accepts -
# presumably declared on a BaseConfig subclass; confirm before copying.
config = BaseConfig(
dataset="test",
extract_method="my_custom", # Must match registered name
llm_name="gpt-4o-mini",
custom_param="value", # Strategy-specific params
)
rag = ReMem(global_config=config)
# Index the documents before asking questions.
docs = ["Document 1", "Document 2"]
rag.index(docs)
# rag_for_qa returns five values, unpacked here in order.
queries = ["What is in the documents?"]
solutions, responses, meta, ret_metrics, qa_metrics = rag.rag_for_qa(
queries=queries,
gold_answers=[["Answer"]],
)
Advanced: Multi-Step Retrieval
Example from EpisodicGistStrategy:
def _rag_each_query(self, remem, query, return_chunk="gists", **kwargs):
    """Multi-step retrieval with gist-based exploration.

    Args:
        remem: Object exposing ``episodic_embedding_stores`` and ``graph``.
        query: Query passed to the gist store's ``search``.
        return_chunk: ``"verbatim"`` maps the gist results through
            ``self._map_to_verbatim``; any other value returns the gist
            results directly.
        **kwargs: Accepted but not used in this excerpt.

    Returns:
        The verbatim ids or the gist search results, depending on
        ``return_chunk``.
    """
    # Step 1: Initial gist retrieval
    gist_results = remem.episodic_embedding_stores["gists"].search(
        query, k=20
    )

    # Step 2: Expand via graph
    # NOTE: expanded_facts is collected for illustration only; nothing in this
    # excerpt reads it afterwards.
    expanded_facts = []
    for gist_id in gist_results:
        # Find connected facts in graph
        neighbors = remem.graph.neighbors(gist_id)
        fact_neighbors = [n for n in neighbors if n.startswith("facts-")]
        expanded_facts.extend(fact_neighbors)

    # Step 3: Rerank and return
    if return_chunk == "verbatim":
        # Map back to verbatim chunks
        verbatim_ids = self._map_to_verbatim(gist_results)
        return verbatim_ids
    else:
        return gist_results
Strategy-Specific Configuration
Add custom config fields:
from dataclasses import dataclass
from remem.utils.config_utils import BaseConfig
@dataclass
class MyCustomConfig(BaseConfig):
    """BaseConfig extended with fields for the custom strategy."""

    # Strategy-specific fields
    custom_rerank_weight: float = 0.5
    custom_expansion_hops: int = 2
    custom_threshold: float = 0.7
class MyCustomStrategy(RAGStrategy):
    """Strategy that copies its tuning values from the instance's global config."""

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        config = remem_instance.global_config
        self.rerank_weight = config.custom_rerank_weight
        self.expansion_hops = config.custom_expansion_hops
        # Third strategy-specific field, read alongside the other two.
        self.threshold = config.custom_threshold
Parallel Processing
From EpisodicGistStrategy.rag_for_qa():
def rag_for_qa(self, queries, parallel=True, max_workers=8, **kwargs):
    """Process queries in parallel.

    Abridged excerpt: ``query_args``, ``query_solutions`` and ``tqdm`` come
    from code that is not shown here.

    Args:
        queries: Queries to process; its length sizes the progress bar.
        parallel: Use a thread pool when true, a plain loop otherwise.
        max_workers: Thread-pool size for the parallel path.
    """
    if parallel:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Remember each future's input position: as_completed yields in
            # completion order, not submission order.
            future_to_idx = {
                executor.submit(self._process_single_query, args): idx
                for idx, args in enumerate(query_args)
            }
            for future in tqdm(as_completed(future_to_idx), total=len(queries)):
                idx = future_to_idx[future]
                query_solutions[idx] = future.result()
    else:
        # Sequential processing
        for idx, query in tqdm(enumerate(queries)):
            query_solutions[idx] = self._process_single_query(query)
Helper: Get Graph Info
The base class provides get_graph_info() (base_strategy.py:68-146):
def get_graph_info(self) -> Dict:
    """Get statistics about the graph.

    Returns:
        A dict with four counts:
        ``num_phrase_nodes`` and ``num_passage_nodes`` (distinct ids in the
        phrase and chunk embedding stores), ``num_extracted_edges`` (ids in
        the triple embedding store, not de-duplicated) and
        ``num_total_edges`` (size of ``node_to_node_count``).
    """
    graph_info = {}

    # Count phrase nodes
    phrase_nodes = self.remem.phrase_embedding_store.get_all_ids()
    graph_info["num_phrase_nodes"] = len(set(phrase_nodes))

    # Count passage nodes
    passage_nodes = self.remem.chunk_embedding_store.get_all_ids()
    graph_info["num_passage_nodes"] = len(set(passage_nodes))

    # Count edges
    graph_info["num_extracted_edges"] = len(
        self.remem.triple_embedding_store.get_all_ids()
    )
    graph_info["num_total_edges"] = len(self.remem.node_to_node_count)

    return graph_info
Next Steps
- Custom Extraction - Define what gets extracted
- Custom Prompts - Control LLM behavior in QA
- Custom Metrics - Evaluate strategy performance
- Architecture - Understand the full pipeline