SlimeRouter is an optional lightweight HTTP router/proxy that sits in front of SGLang worker servers during rollout and data generation. It adds training-oriented capabilities that serving-focused routers do not prioritize.

What is SlimeRouter?

SlimeRouter is a small FastAPI service that provides:
  • Worker registration: Registers SGLang HTTP servers into a local pool
  • Load balancing: Routes requests using simple least-inflight load balancing
  • Request proxying: Proxies arbitrary paths to selected workers (e.g. /generate)
  • Health monitoring: Runs periodic health checks and quarantines unhealthy workers
  • Middleware plugins: Supports middleware via --slime-router-middleware-paths for rollout-specific processing (e.g. caching, request/response transforms)
In slime’s architecture, the router is part of the rollout system (“SGLang + router”) that generates samples and pushes them into the data buffer.

How It Is Launched

In distributed training, slime automatically starts a router when --sglang-router-ip is not provided:
  • If --use-slime-router is set, slime starts SlimeRouter
  • Otherwise, slime starts SGLang Model Gateway
# Using SlimeRouter
python train.py \
  --use-slime-router \
  ${MODEL_ARGS[@]}

# SGLang Model Gateway (default)
python train.py \
  ${MODEL_ARGS[@]}

Why We Need SlimeRouter

Unlike production inference, RL rollout needs to capture additional metadata for training:
  • Token-level log probabilities
  • Loss masks
  • Expert routing decisions (for MoE models)
SlimeRouter provides these capabilities through its middleware system and passthrough proxy design.

Radix-Tree Cache

Use this when your rollout pipeline is text-in/text-out and you cannot reliably persist token IDs. If you already control token-in/token-out (e.g. search r1, multiturn VLM examples), you likely don’t need the radix-tree cache.
Text-in/text-out interfaces can cause retokenization mismatches: re-tokenizing text at training time may produce a different token sequence than the one used during rollout, breaking the per-token alignment needed for PPO/GRPO losses. The radix-tree cache solves this transparently:

How It Works

  1. Intercepts text-based requests and tokenizes them
  2. Stores trajectories (text, token IDs, logprobs, loss masks) keyed by text prefix in a radix tree
  3. Uses longest-prefix matching to reuse cached token sequences
  4. Allows insertion of new text continuations as rollout proceeds (multiple trajectories per prompt for GRPO)
  5. Periodically cleans up stale nodes to control memory usage
  6. After rollout finishes, calling /retrieve_from_text returns exact token sequences with aligned metadata
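The flow above can be sketched with a minimal string-keyed prefix cache. This is a simplified stand-in for the real radix tree (a flat dict instead of a trie), and all names here are illustrative:

```python
# Minimal sketch of the cache flow: insert trajectories keyed by text,
# then reuse the longest cached prefix when a longer text arrives.
# A flat dict stands in for the real radix tree; names are illustrative.

class PrefixCache:
    def __init__(self):
        self._entries = {}  # text prefix -> (token_ids, logprobs, loss_mask)

    def insert(self, text, token_ids, logprobs, loss_mask):
        self._entries[text] = (token_ids, logprobs, loss_mask)

    def find_longest_prefix(self, text):
        """Return cached data for the longest stored prefix of `text`."""
        best = ""
        for prefix in self._entries:
            if text.startswith(prefix) and len(prefix) > len(best):
                best = prefix
        if not best:
            return None, text
        return self._entries[best], text[len(best):]

cache = PrefixCache()
cache.insert("Hello, how", [101, 102, 103], [-0.1, -0.2, -0.3], [0, 0, 0])

cached, remaining = cache.find_longest_prefix("Hello, how are you?")
print(cached[0])   # cached token IDs reused for the matched prefix
print(remaining)   # only this suffix still needs tokenization
```

The real radix tree makes the lookup proportional to the match length rather than the number of cached entries, but the reuse contract is the same.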

Implementation

The radix tree is a string-based trie structure optimized for prefix matching:
radix_tree.py
class StringRadixTrie:
    """
    String-based Radix Trie for efficient prefix matching and token caching.
    Features:
    - Efficient string prefix matching
    - Token ID caching for matched prefixes
    - Thread-safe operations
    - Weight version tracking
    - Automatic garbage collection based on weight version thresholds
    """
    
    def __init__(self, max_cache_size: int = 10000, gc_threshold_k: int = 5, 
                 tokenizer=None, verbose: bool = False):
        self.max_cache_size = max_cache_size
        self.gc_threshold_k = gc_threshold_k
        self.tokenizer = tokenizer
        self.root = StringTreeNode()
Prefix matching finds the longest cached prefix:
radix_tree.py
def find_longest_prefix(self, text: str) -> MatchResult:
    """Find the longest cached prefix for the given text."""
    matched_prefix = ""
    matched_tokens = []
    matched_logp = []
    matched_loss_mask = []
    current_node = self.root
    remaining_text = text
    
    while remaining_text:
        # Find the child whose key extends the current match
        for child_node in current_node.children:
            if remaining_text.startswith(child_node.string_key):
                current_node = child_node
                matched_prefix += child_node.string_key
                remaining_text = remaining_text[len(child_node.string_key):]
                matched_tokens.extend(child_node.token_ids)
                matched_logp.extend(child_node.logp)
                matched_loss_mask.extend(child_node.loss_mask)
                break
        else:
            # No child matches: the longest cached prefix ends here
            break
    
    return MatchResult(matched_prefix, matched_tokens, matched_logp, 
                      matched_loss_mask, remaining_text, current_node)

Middleware Integration

The radix tree is integrated via middleware that intercepts /generate requests:
radix_tree_middleware.py
class RadixTreeMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.url.path != "/generate":
            return await call_next(request)
        
        # Retrieve cached tokens for input text
        input_tokens, input_logprobs, input_loss_mask = \
            self.radix_tree.retrieve_from_text(input_text, return_logprob=True)
        
        # Get response from SGLang
        response = await call_next(request)
        
        # Cache the full trajectory (input + output)
        if "output_token_logprobs" in response_data.get("meta_info", {}):
            full_token_ids = input_tokens + generated_token_ids
            full_logprobs = input_logprobs + generated_token_logprobs
            self.radix_tree.insert(full_text, full_token_ids, full_logprobs, 
                                  full_loss_mask, weight_version)

Use Cases

GRPO with Multiple Trajectories

Multiple samples sharing the same prompt prefix can reuse cached tokens, reducing tokenization overhead and ensuring consistency.
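Conceptually, the prompt is tokenized on the first miss and reused for every subsequent sample. A tiny illustration (with a stand-in `fake_tokenize` instead of a real tokenizer):

```python
# Illustrative: with GRPO generating several samples per prompt, the
# shared prompt prefix is tokenized once and reused from the cache.
# `fake_tokenize` is a stand-in for a real tokenizer.

def fake_tokenize(text):
    return [ord(c) for c in text]

cache = {}

def tokenize_with_cache(prompt):
    if prompt not in cache:
        cache[prompt] = fake_tokenize(prompt)  # computed only on the first miss
    return cache[prompt]

prompt = "Solve: 2 + 2 ="
trajectories = [tokenize_with_cache(prompt) for _ in range(4)]
# All four samples reuse identical prompt tokens, guaranteeing consistency
assert all(t == trajectories[0] for t in trajectories)
```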

Text-Based Rollout Code

When you have text-based rollout code and want token-level precision without rewriting your pipeline.

Rollout Routing Replay (R3) for MoE

For MoE models, slime supports Rollout Routing Replay (R3): record expert routing decisions during rollout and replay them during training to improve stability.

SGLang Side

SGLang provides expert routing capture via:
  • --enable-return-routed-experts: Server argument to enable routing capture
  • RoutedExpertsCapturer: Captures topk_ids (selected expert IDs) at each MoE layer during forward pass
  • return_routed_experts: Request parameter to retrieve routing data
  • Returns routed_experts in response meta_info - a [seq_len - 1, num_layers, top_k] tensor of expert IDs

Slime Side

Slime consumes the routing data and replays it during training:
# Both flags required to enable R3
--use-slime-router \
--use-rollout-routing-replay
The workflow:
  1. Rollout sends return_routed_experts=True and stores results in sample.rollout_routed_experts
  2. Training calls fill_routing_replay() to load routing data into RoutingReplay objects
  3. During forward pass, recorded routing decisions are replayed instead of recomputed
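The record-and-replay idea can be sketched as follows. All names here are illustrative and not slime's actual `RoutingReplay` API:

```python
# Sketch of rollout routing replay: record the expert IDs chosen at each
# MoE layer during rollout, then replay them during the training forward
# pass instead of recomputing the router. Names are illustrative.

class RoutingReplay:
    def __init__(self):
        self.recorded = []   # one entry per MoE layer: top-k expert IDs per token
        self._replay_idx = 0

    def record(self, topk_ids):
        """Rollout: store the expert IDs selected at one MoE layer."""
        self.recorded.append(topk_ids)

    def replay(self):
        """Training: return the recorded IDs for the next MoE layer."""
        ids = self.recorded[self._replay_idx]
        self._replay_idx += 1
        return ids

replay = RoutingReplay()
# Rollout pass: two MoE layers, top-2 experts per token
replay.record([[3, 7], [1, 4]])
replay.record([[0, 5], [2, 6]])

# Training pass: reuse the rollout's routing decisions layer by layer
assert replay.replay() == [[3, 7], [1, 4]]
assert replay.replay() == [[0, 5], [2, 6]]
```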

Why SlimeRouter Is Required

SlimeRouter is needed because the SGLang worker returns routed experts in the response (meta_info.routed_experts), and SlimeRouter preserves this field end-to-end. SGLang Model Gateway may drop this extra metadata when it reconstructs responses with a fixed schema.

Architecture

router.py
class SlimeRouter:
    def __init__(self, args, verbose=False):
        self.app = FastAPI()
        
        # URL -> Active Request Count (load state)
        self.worker_request_counts: dict[str, int] = {}
        # URL -> Consecutive Failures
        self.worker_failure_counts: dict[str, int] = {}
        # Quarantined workers excluded from routing pool
        self.dead_workers: set[str] = set()
        
        # Setup HTTP client with connection pooling
        self.client = httpx.AsyncClient(
            limits=httpx.Limits(max_connections=max_connections),
            timeout=httpx.Timeout(timeout),
        )
        
        # Load middleware plugins
        for middleware_path in args.slime_router_middleware_paths or []:
            middleware = load_function(middleware_path)
            self.app.add_middleware(middleware, router=self)

Load Balancing

SlimeRouter uses least-inflight load balancing:
router.py
def _use_url(self):
    """Select worker URL with minimal active requests."""
    if not self.dead_workers:
        url = min(self.worker_request_counts, key=self.worker_request_counts.get)
    else:
        valid_workers = (w for w in self.worker_request_counts 
                        if w not in self.dead_workers)
        url = min(valid_workers, key=self.worker_request_counts.get)
    
    self.worker_request_counts[url] += 1
    return url

Health Monitoring

Background health check loop monitors all workers:
router.py
async def _health_check_loop(self):
    while True:
        await asyncio.sleep(interval)
        
        urls = [u for u in self.worker_request_counts if u not in self.dead_workers]
        results = await asyncio.gather(*(
            self._check_worker_health(url) for url in urls
        ))
        
        for url, is_healthy in results:
            if not is_healthy:
                failures = self.worker_failure_counts.get(url, 0) + 1
                self.worker_failure_counts[url] = failures
                
                if failures >= threshold:
                    logger.warning(f"Worker {url} failed {threshold} checks. Marking DEAD.")
                    self.dead_workers.add(url)
            else:
                self.worker_failure_counts[url] = 0
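The quarantine rule above (mark a worker dead after a threshold of consecutive failures, reset on any healthy check) can be exercised in isolation with a simplified synchronous sketch; this is not the router's actual code:

```python
# Simplified, synchronous version of the quarantine rule: a worker is
# marked dead after `threshold` consecutive failed checks, and a single
# healthy check resets its failure count.

def apply_health_results(results, failure_counts, dead_workers, threshold=3):
    for url, is_healthy in results:
        if not is_healthy:
            failure_counts[url] = failure_counts.get(url, 0) + 1
            if failure_counts[url] >= threshold:
                dead_workers.add(url)
        else:
            failure_counts[url] = 0

failures, dead = {}, set()
for _ in range(3):
    apply_health_results([("http://w1:10090", False)], failures, dead)
print(dead)  # {'http://w1:10090'} after three consecutive failures
```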

SlimeRouter vs SGLang Model Gateway

SlimeRouter and SGLang Model Gateway can both route requests to workers, but they are optimized for different goals.

Key Differences

# Lightweight Python/FastAPI proxy
# Acts as passthrough to SGLang workers
# Preserves all response metadata
# Supports custom middleware for RL features

class SlimeRouter:
    async def proxy(self, request: Request, path: str):
        worker_url = self._use_url()
        response = await self.client.request(
            request.method, f"{worker_url}/{path}", 
            content=body, headers=headers
        )
        # Return response as-is (passthrough)
        return JSONResponse(content=data, status_code=response.status_code)

When to Use Which

Use SlimeRouter

  • You need R3 (rollout routing replay)
  • You need radix-tree caching
  • You need custom middleware for RL metadata

Use SGLang Model Gateway

  • Everything else (recommended default)
  • Maximum throughput and scalability
  • Advanced fault tolerance
  • Cache-aware routing
For more details on SGLang Model Gateway, see the official documentation.

Session-Affinity Routing for Multi-Turn Agents

When using SGLang Model Gateway with consistent hashing routing policy, Slime automatically assigns each rollout session a unique session ID and uses it as the routing key to enable session affinity.

What Is Session Affinity?

Session affinity (also called sticky sessions) ensures that all requests belonging to the same conversation or agent session are routed to the same backend worker. This is beneficial for:
  • Multi-turn dialogues: Keeping the same worker improves prefix cache hit rates
  • Multi-agent systems: Ensures agent state consistency and better resource locality
  • Debugging: Makes it easier to trace and debug specific sessions

How It Works

When the rollout system generates samples, each sample is assigned a unique session_id:
  1. Automatically generated using UUID for each sample
  2. Stored in sample.session_id field
  3. Passed as X-SMG-Routing-Key header when the router policy is consistent_hashing
The SGLang Model Gateway’s consistent hashing policy then uses this routing key to deterministically select the same worker for all requests with the same session ID.
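The affinity property can be illustrated with a toy hash-based selection. The real gateway uses a consistent-hash ring (which also minimizes reshuffling when workers join or leave); this modulo version only shows that the same routing key always lands on the same worker:

```python
import hashlib

# Toy illustration of session affinity: hashing the routing key
# deterministically picks one worker, so every request carrying the
# same session_id is routed to the same backend.

def pick_worker(session_id, workers):
    digest = hashlib.sha256(session_id.encode()).hexdigest()
    return workers[int(digest, 16) % len(workers)]

workers = ["http://127.0.0.1:10090", "http://127.0.0.1:10091"]
first = pick_worker("session-abc", workers)
# Every turn of the same session routes to the same worker
assert all(pick_worker("session-abc", workers) == first for _ in range(5))
```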

Configuration

To enable session-affinity routing:
--sglang-router-policy consistent_hashing
Slime will automatically start SGLang Model Gateway with the consistent hashing policy.
If you encounter an error about the consistent_hashing policy not being available, upgrade sglang-router:
pip install -U sglang-router

Notes

  • Each sample gets its own unique session ID
  • Different samples in the same group may be routed to different workers
  • The same sample’s subsequent turns will maintain the same session ID
  • Currently, this feature is only available for SGLang Model Gateway

API Reference

POST /add_worker

Add a new worker to the router. Request:
# Via query parameter
curl -X POST "http://localhost:30000/add_worker?url=http://127.0.0.1:10090"

# Via JSON body
curl -X POST "http://localhost:30000/add_worker" \
  -H "Content-Type: application/json" \
  -d '{"url": "http://127.0.0.1:10090"}'
Response:
{
  "status": "success",
  "worker_urls": {
    "http://127.0.0.1:10090": 0,
    "http://127.0.0.1:10091": 0
  }
}

GET /list_workers

List all registered workers. Request:
curl "http://localhost:30000/list_workers"
Response:
{
  "urls": [
    "http://127.0.0.1:10090",
    "http://127.0.0.1:10091"
  ]
}

POST /retrieve_from_text

Get token information from text input (requires RadixTreeMiddleware). Request:
curl -X POST "http://localhost:30000/retrieve_from_text" \
  -H "Content-Type: application/json" \
  -d '{"text": "Hello, how are you?", "return_logp": true}'
Response:
{
  "tokens": [9906, 11, 703, 527, 499, 30],
  "response": "Hello, how are you?",
  "loss_mask": [0, 0, 0, 0, 0, 0],
  "token_length": 6,
  "loss_mask_length": 6,
  "rollout_logp": [-0.1, -0.2, -0.15, -0.18, -0.12, -0.25]
}

Example: Using RadixTreeMiddleware

SGLANG_ARGS=(
  --use-slime-router
  --slime-router-middleware-paths slime.router.middleware_hub.radix_tree_middleware.RadixTreeMiddleware
)

ray job submit --address="http://127.0.0.1:8265" \
  -- python3 train.py \
  ${MODEL_ARGS[@]} \
  ${SGLANG_ARGS[@]}

Best Practices

Choose the Right Router

Use SlimeRouter only when you need its specialized features (R3, radix-tree). Otherwise, use SGLang Model Gateway for better performance.

Configure Connection Pooling

Set --slime-router-max-connections based on your concurrency needs. The default is sglang-server-concurrency * rollout-num-gpus / rollout-num-gpus-per-engine.
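For example, with the default formula (the concrete values below are illustrative, not defaults):

```python
# Illustrative evaluation of the default connection-pool formula:
# max connections scale with server concurrency and the number of engines.
sglang_server_concurrency = 512
rollout_num_gpus = 16
rollout_num_gpus_per_engine = 8

max_connections = sglang_server_concurrency * rollout_num_gpus // rollout_num_gpus_per_engine
print(max_connections)  # 1024
```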

Monitor Cache Hit Rates

When using radix-tree caching, monitor cache hit rates to ensure the cache is effective. Low hit rates may indicate the cache size is too small.
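A simple way to track the metric (illustrative; SlimeRouter does not expose such a counter itself):

```python
# Illustrative hit-rate tracker for radix-tree lookups: count a hit when
# a lookup returns cached tokens, a miss otherwise.

class CacheStats:
    def __init__(self):
        self.hits = 0
        self.misses = 0

    def record(self, matched_tokens):
        if matched_tokens:
            self.hits += 1
        else:
            self.misses += 1

    @property
    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total else 0.0

stats = CacheStats()
stats.record([101, 102])  # prefix found in cache
stats.record([])          # cache miss
print(stats.hit_rate)     # 0.5
```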

Test Middleware Plugins

Test custom middleware plugins thoroughly in development before deploying to production. Middleware errors can break the entire rollout pipeline.
