Skip to main content

Documentation Index

Fetch the complete documentation index at: https://mintlify.com/intuit-ai-research/REMem/llms.txt

Use this file to discover all available pages before exploring further.

Overview

RAG strategies control how REMem indexes documents and performs retrieval for question answering. Each extraction method can have a corresponding strategy that defines its indexing and QA logic.

Strategy Factory Pattern

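REMem uses a factory pattern (rag_strategies/factory.py:9-59) to map each extraction method name to its strategy class (the imports of the concrete strategy classes are omitted below):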
from typing import Dict, Type
from .base_strategy import RAGStrategy

class RAGStrategyFactory:
    """Registry that maps extraction-method names to RAGStrategy classes.

    The mapping is a single class-level dict, so ``register_strategy`` changes
    what every later ``create_strategy`` call can resolve.
    """

    _strategies: Dict[str, Type[RAGStrategy]] = {
        "openie": PassageTripleStrategy,
        "episodic_gist": EpisodicGistStrategy,
        "temporal": TemporalStrategy,
    }

    @classmethod
    def create_strategy(cls, extract_method: str, remem_instance) -> RAGStrategy:
        """Instantiate the strategy registered under ``extract_method``.

        Args:
            extract_method: Name the strategy was registered under.
            remem_instance: REMem instance handed to the strategy constructor.

        Raises:
            ValueError: If no strategy is registered for ``extract_method``.
        """
        try:
            chosen_cls = cls._strategies[extract_method]
        except KeyError:
            known_methods = list(cls._strategies.keys())
            raise ValueError(
                f"Unsupported extraction method: {extract_method}. "
                f"Supported methods: {known_methods}"
            ) from None
        return chosen_cls(remem_instance)

    @classmethod
    def register_strategy(cls, extract_method: str, strategy_class: Type[RAGStrategy]):
        """Add (or replace) the strategy class used for ``extract_method``."""
        registry = cls._strategies
        registry[extract_method] = strategy_class

    @classmethod
    def get_supported_methods(cls) -> list[str]:
        """Return the registered extraction-method names in insertion order."""
        return [*cls._strategies]

Base Strategy Interface

All strategies inherit from RAGStrategy (rag_strategies/base_strategy.py:7-146):
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from remem.utils.misc_utils import QuerySolution

class RAGStrategy(ABC):
    """Abstract base class for RAG strategies.

    A strategy wraps a REMem instance (``self.remem``) and supplies two
    operations that subclasses must implement: ``index`` (ingest documents)
    and ``rag_for_qa`` (retrieve and answer questions).
    """

    def __init__(self, remem_instance):
        # REMem instance whose stores, graph and QA helpers the strategy drives.
        self.remem = remem_instance

    @abstractmethod
    def index(self, docs: List[str]) -> None:
        """Index documents using the specific strategy."""
        pass

    @abstractmethod
    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Perform RAG-based QA using the specific strategy.

        Args:
            queries: Raw query strings or QuerySolution objects.
            num_to_retrieve: Number of documents to retrieve per query.
            gold_answers: Optional reference answers, one list per query.
            gold_docs: Optional reference documents, one list per query.
            metrics: Names of the metrics to evaluate.
            question_metadata: Optional per-query metadata dicts.

        Returns:
            ``(query_solutions, response_messages, metadata,
            retrieval_metrics, qa_metrics)``.
        """
        pass

    def retrieve_each_query(self, query: str, return_chunk: Optional[str] = None):
        """Retrieve documents for a single query by delegating to ``self.remem``.

        Subclasses may override this to customise per-query retrieval.
        """
        return self.remem.retrieve_each_query(query, return_chunk)

    def get_graph_info(self) -> Dict:
        """Return node and edge counts for the REMem graph.

        Keys:
            num_phrase_nodes: distinct ids in the phrase embedding store.
            num_passage_nodes: distinct ids in the chunk embedding store.
            num_extracted_edges: ids in the triple embedding store.
            num_total_edges: entries in ``remem.node_to_node_count``.
        """
        remem = self.remem
        return {
            "num_phrase_nodes": len(set(remem.phrase_embedding_store.get_all_ids())),
            "num_passage_nodes": len(set(remem.chunk_embedding_store.get_all_ids())),
            "num_extracted_edges": len(remem.triple_embedding_store.get_all_ids()),
            "num_total_edges": len(remem.node_to_node_count),
        }

Built-in Strategies

DefaultRAGStrategy (OpenIE)

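From rag_strategies/default_strategy.py:8-60 (note that the factory shown above registers `PassageTripleStrategy`, not `DefaultRAGStrategy`, under `"openie"`):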
class DefaultRAGStrategy(RAGStrategy):
    """Default RAG strategy for standard OpenIE-based extraction.

    Every stage is delegated to the wrapped REMem instance:
    retrieve -> evaluate retrieval -> QA -> evaluate QA -> (optionally) save.
    """

    def index(self, docs: List[str]) -> None:
        """Delegate indexing to ``remem.index_original``."""
        self.remem.index_original(docs)

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 5,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        to_save: bool = True,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Answer ``queries`` with the standard retrieve-then-read pipeline.

        Raw strings are retrieved first. A list whose first element is a
        QuerySolution is treated as already retrieved and used as-is.
        ``num_to_retrieve`` and ``**kwargs`` are accepted for interface
        compatibility but are not used by this implementation.

        Returns:
            ``(query_solutions, response_messages, metadata,
            retrieval_metrics, qa_metrics)``.
        """
        remem = self.remem

        already_retrieved = isinstance(queries[0], QuerySolution)
        solutions = queries if already_retrieved else remem.retrieve(queries=queries)

        # Attach per-question metadata positionally.
        if question_metadata is not None:
            for position, solution in enumerate(solutions):
                solution.question_metadata = question_metadata[position]

        qa_evaluators, retrieval_evaluators = remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )
        retrieval_scores = remem.evaluate_retrieval(
            gold_docs, solutions, retrieval_evaluators
        )

        solutions, responses, qa_metadata = remem.qa(solutions)

        qa_scores = remem.evaluate_qa(
            gold_answers, qa_evaluators, solutions, question_metadata
        )

        if to_save:
            remem.save_rag_results(
                gold_answers, gold_docs, solutions, qa_scores, retrieval_scores
            )

        return (solutions, responses, qa_metadata, retrieval_scores, qa_scores)

EpisodicGistStrategy

An advanced strategy with gist-based retrieval. The excerpt below is heavily abridged from rag_strategies/episodic_gist_strategy.py:21-1013; elided code is marked with comments:
class EpisodicGistStrategy(RAGStrategy):
    """Strategy for episodic gist-based extraction and retrieval.

    Abridged excerpt: elided code is marked with ``...``.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Per-chunk layout switches read from the global config.
        self.concatenate_gists_per_chunk = remem_instance.global_config.concatenate_gists_per_chunk
        self.split_verbatim_per_chunk = remem_instance.global_config.split_verbatim_per_chunk

    def index(self, docs: List) -> None:
        """Index ``docs`` with episodic gist extraction, then build and save the graph."""
        from collections import defaultdict

        # Add chunk embeddings
        self.remem.add_chunk_and_embeddings(docs)
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row

        # Reuse previously saved extraction results; only the remaining chunks
        # are sent to the extractor.
        all_openie_info, chunk_keys_to_process = self.remem.load_existing_openie(
            chunk_dict.keys()
        )
        new_openie_rows = {key: chunk_dict[key] for key in chunk_keys_to_process}

        if chunk_keys_to_process:
            ie_results = self.remem.openie.batch_openie(new_openie_rows)
            self.merge_gist_extraction_results(all_openie_info, chunk_keys_to_process, ie_results)

        # Build episodic embedding stores
        element_to_encode = defaultdict(list)
        # NOTE(review): episode_results_dict is produced by code elided from this
        # excerpt (presumably derived from all_openie_info) — confirm.
        for chunk in episode_results_dict.values():
            # Process verbatim, gists, facts, entities into element_to_encode
            ...

        # Construct graph with gist->fact and verbatim->gist edges
        self._augment_episodic_graph()
        self.remem.save_igraph()

    def rag_for_qa(self, queries, **kwargs):
        """Perform QA with parallel processing and per-sample evaluation."""
        # Supports (implementation elided):
        # - Parallel query processing
        # - Per-sample saving/loading
        # - Gist-based retrieval
        # - Agent-based reasoning
        ...

Creating a Custom Strategy

1. Define Your Strategy Class

# src/remem/rag_strategies/my_custom_strategy.py
from typing import Dict, List, Optional, Tuple, Union
from .base_strategy import RAGStrategy
from remem.utils.misc_utils import QuerySolution

class MyCustomStrategy(RAGStrategy):
    """Custom RAG strategy with specialized retrieval logic.

    ``index`` runs the shared chunk-embedding and extraction steps, then calls
    strategy-specific hooks. ``rag_for_qa`` retrieves with ``_custom_retrieve``
    and reuses REMem's QA and QA evaluation.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Strategy-specific parameter; the config must declare a custom_param field.
        self.custom_param = remem_instance.global_config.custom_param

    def index(self, docs: List[str]) -> None:
        """Embed chunks, run extraction, build custom indices and save the graph."""
        # 1. Add chunks and embeddings
        self.remem.add_chunk_and_embeddings(docs)

        # 2. Run extraction over every stored chunk
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row
        ie_results = self.remem.openie.batch_openie(chunk_dict)

        # 3. Build custom data structures
        self._build_custom_index(ie_results)

        # 4. Construct and persist the graph
        self._augment_custom_graph()
        self.remem.save_igraph()

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1"),
        question_metadata: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Answer ``queries`` with custom retrieval followed by REMem QA.

        Args:
            queries: Query strings or QuerySolution objects. For a
                QuerySolution only ``.question`` is used; retrieval is redone.
            num_to_retrieve: Number of documents kept per query.
            gold_answers: Reference answers for QA evaluation, if any.
            gold_docs: Forwarded to ``get_evaluators``. Retrieval itself is not
                evaluated here.
            metrics: Metric names forwarded to ``get_evaluators``.
            question_metadata: Optional per-query metadata. It is attached to
                each solution and forwarded to ``evaluate_qa``.

        Returns:
            ``(query_solutions, responses, metadata, {}, qa_metrics)``. The
            retrieval-metrics slot is empty because retrieval is not evaluated.
        """
        query_solutions = []

        for query in queries:
            # The signature accepts QuerySolution objects too; use their text.
            question = query.question if isinstance(query, QuerySolution) else query

            # 1. Custom retrieval
            docs, scores = self._custom_retrieve(question, num_to_retrieve)

            # 2. Create QuerySolution
            query_solutions.append(
                QuerySolution(question=question, docs=docs, doc_scores=scores)
            )

        # Attach per-question metadata positionally, as DefaultRAGStrategy does.
        if question_metadata is not None:
            for idx, solution in enumerate(query_solutions):
                solution.question_metadata = question_metadata[idx]

        # 3. Generate answers
        query_solutions, responses, metadata = self.remem.qa(query_solutions)

        # 4. Evaluate QA only (no retrieval evaluation in this example)
        qa_evaluators, _retrieval_evaluators = self.remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )
        overall_qa_metrics = self.remem.evaluate_qa(
            gold_answers, qa_evaluators, query_solutions, question_metadata
        )

        return query_solutions, responses, metadata, {}, overall_qa_metrics

    def _custom_retrieve(self, query: str, k: int):
        """Return ``(docs, scores)`` for the top-``k`` reranked candidates."""
        # 1. Over-fetch candidates so the reranker has room to reorder
        candidates = self.remem.chunk_embedding_store.search(query, k=k * 2)

        # 2. Apply custom reranking, then 3. keep the top-k
        top = self._custom_rerank(query, candidates)[:k]

        # Split (doc, score) pairs into the two parallel lists the caller unpacks.
        docs = [doc for doc, _ in top]
        scores = [score for _, score in top]
        return docs, scores

    def _custom_rerank(self, query: str, candidates):
        """Reorder ``candidates`` for ``query``; return a list of ``(doc, score)`` pairs.

        Placeholder: returns candidates unchanged. It assumes ``search`` already
        yields ``(doc, score)`` pairs — TODO adapt to the store's actual return type.
        """
        return list(candidates)

    def _build_custom_index(self, ie_results):
        """Build strategy-specific indices."""
        # Example: Build entity co-occurrence matrix
        pass

    def _augment_custom_graph(self):
        """Construct custom graph structure."""
        # Example: Add weighted edges based on custom similarity
        pass

2. Register Your Strategy

# In your application code or remem/__init__.py
from remem.rag_strategies.factory import RAGStrategyFactory
from remem.rag_strategies.my_custom_strategy import MyCustomStrategy

# Make "my_custom" resolvable by the factory. Use the same string as
# `extract_method` in your config; unregistered names make
# RAGStrategyFactory.create_strategy raise ValueError.
RAGStrategyFactory.register_strategy(
    extract_method="my_custom",
    strategy_class=MyCustomStrategy,
)

3. Use Your Strategy

from dataclasses import dataclass

from remem.remem import ReMem
from remem.utils.config_utils import BaseConfig


# Strategy-specific settings must be declared as config fields. BaseConfig is
# presumably a dataclass (it is extended with @dataclass elsewhere on this
# page), so an undeclared keyword like custom_param passed straight to
# BaseConfig(...) would raise TypeError.
@dataclass
class MyStrategyConfig(BaseConfig):
    # Read by MyCustomStrategy.__init__ as global_config.custom_param.
    custom_param: str = ""


config = MyStrategyConfig(
    dataset="test",
    extract_method="my_custom",  # Must match registered name
    llm_name="gpt-4o-mini",
    custom_param="value",  # Strategy-specific field declared above
)

rag = ReMem(global_config=config)
docs = ["Document 1", "Document 2"]
rag.index(docs)

queries = ["What is in the documents?"]
# Unpacks into: solutions, raw responses, per-query metadata,
# retrieval metrics, QA metrics.
solutions, responses, meta, ret_metrics, qa_metrics = rag.rag_for_qa(
    queries=queries,
    gold_answers=[["Answer"]],
)

Advanced: Multi-Step Retrieval

Example from EpisodicGistStrategy:
def _rag_each_query(self, remem, query, return_chunk="gists", **kwargs):
    """Multi-step retrieval with gist-based exploration."""
    # Step 1: Initial gist retrieval
    gist_results = remem.episodic_embedding_stores["gists"].search(
        query, k=20
    )
    
    # Step 2: Expand via graph
    expanded_facts = []
    for gist_id in gist_results:
        # Find connected facts in graph
        neighbors = remem.graph.neighbors(gist_id)
        fact_neighbors = [n for n in neighbors if n.startswith("facts-")]
        expanded_facts.extend(fact_neighbors)
    
    # Step 3: Rerank and return
    if return_chunk == "verbatim":
        # Map back to verbatim chunks
        verbatim_ids = self._map_to_verbatim(gist_results)
        return verbatim_ids
    else:
        return gist_results

Strategy-Specific Configuration

Add custom config fields:
from dataclasses import dataclass
from remem.utils.config_utils import BaseConfig

@dataclass()
class MyCustomConfig(BaseConfig):
    """BaseConfig extended with the fields a custom strategy reads from global_config."""

    # --- strategy-specific fields (all defaulted, so BaseConfig defaults stay valid) ---
    custom_rerank_weight: float = 0.5
    custom_expansion_hops: int = 2
    custom_threshold: float = 0.7
Then read the fields in your strategy:
class MyCustomStrategy(RAGStrategy):
    """Strategy that copies its tuning knobs from the REMem global config."""

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # The config must declare these fields (e.g. a MyCustomConfig instance).
        cfg = remem_instance.global_config
        self.rerank_weight = cfg.custom_rerank_weight
        self.expansion_hops = cfg.custom_expansion_hops

Parallel Processing

Abridged from EpisodicGistStrategy.rag_for_qa() (the setup of `query_args` and `query_solutions` and the `tqdm` import are omitted):
def rag_for_qa(self, queries, parallel=True, max_workers=8, **kwargs):
    """Process queries on a thread pool, or sequentially when ``parallel`` is False.

    Abridged excerpt: ``query_args``, ``query_solutions`` and ``tqdm`` come from
    code not shown here (presumably ``query_solutions`` is preallocated with one
    slot per query — confirm). No return statement is shown in this excerpt.

    Args:
        queries: Queries to answer.
        parallel: Use a ThreadPoolExecutor when True.
        max_workers: Thread-pool size for the parallel path.
    """
    if parallel:
        from concurrent.futures import ThreadPoolExecutor, as_completed
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Remember each future's input position: as_completed yields futures
            # in completion order, and results are written back by index.
            future_to_idx = {
                executor.submit(self._process_single_query, args): idx 
                for idx, args in enumerate(query_args)
            }
            
            # NOTE(review): total= uses len(queries) while futures come from
            # query_args — confirm both have the same length.
            for future in tqdm(as_completed(future_to_idx), total=len(queries)):
                idx = future_to_idx[future]
                # result() re-raises any exception raised in the worker thread.
                query_solutions[idx] = future.result()
    else:
        # Sequential processing
        # NOTE(review): this path passes raw queries to _process_single_query,
        # whereas the parallel path passes query_args entries — confirm they match.
        for idx, query in tqdm(enumerate(queries)):
            query_solutions[idx] = self._process_single_query(query)

Helper: Get Graph Info

The base class provides get_graph_info() (base_strategy.py:68-146):
def get_graph_info(self) -> Dict:
    """Summarise the REMem graph as node and edge counts.

    Keys:
        num_phrase_nodes: distinct ids in the phrase embedding store.
        num_passage_nodes: distinct ids in the chunk embedding store.
        num_extracted_edges: ids in the triple embedding store.
        num_total_edges: entries in ``remem.node_to_node_count``.
    """
    remem = self.remem
    return {
        "num_phrase_nodes": len(set(remem.phrase_embedding_store.get_all_ids())),
        "num_passage_nodes": len(set(remem.chunk_embedding_store.get_all_ids())),
        "num_extracted_edges": len(remem.triple_embedding_store.get_all_ids()),
        "num_total_edges": len(remem.node_to_node_count),
    }

Next Steps

Build docs developers (and LLMs) love