1. The Big Picture

Every LLM serving system must turn HTTP requests into GPU-generated tokens as fast as possible.

The Problem

You have a large language model. Users send prompts and want completions back with sub-second latency. The naive approach — tokenize, run forward passes, detokenize, all in one thread — collapses under concurrent load. Tokenization blocks the GPU. Detokenization blocks the scheduler. A single slow request holds up the entire pipeline.

The Idea

SGLang decomposes LLM serving into three asynchronous processes connected by ZMQ message-passing:

TokenizerManager receives HTTP requests, tokenizes text, and handles multimodal inputs. Scheduler manages batching, KV cache allocation, and GPU execution. DetokenizerManager converts output tokens back to text and streams results.

This means tokenization of new requests overlaps with GPU computation on current requests, which overlaps with detokenization of completed requests. No stage blocks another.

Why Separate Processes, Not Async Tasks?

| Approach | GIL Contention | Crash Isolation | CPU/GPU Separation | Verdict |
|---|---|---|---|---|
| Separate processes + ZMQ | None (each has own GIL) | One crash doesn't kill others | Clean boundary | Chosen |
| Async tasks in one process | GIL serializes CPU work | One crash kills everything | Shared memory contention | Too fragile |
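The overlap claim can be quantified with a toy timeline model. All stage costs below are made-up illustrative numbers, not SGLang measurements; the point is only that once the pipeline fills, one batch completes per tick of the slowest stage instead of per sum of all three.

```python
# Toy model of a 3-stage pipeline (tokenize, GPU forward, detokenize).
# Stage costs are hypothetical milliseconds, chosen for illustration.

def serial_time(costs, n_batches):
    """One thread: every batch runs all stages back to back."""
    return n_batches * sum(costs)

def pipelined_time(costs, n_batches):
    """Three processes: stages overlap, so steady-state throughput is set
    by the slowest stage. Fill the pipeline once (sum of costs), then one
    batch completes per max-stage tick."""
    return sum(costs) + (n_batches - 1) * max(costs)

costs = [2.0, 10.0, 1.0]           # tokenize, forward, detokenize (ms)
print(serial_time(costs, 100))     # 1300.0 ms
print(pipelined_time(costs, 100))  # 1003.0 ms: dominated by the GPU stage
```

With a GPU-bound middle stage, the pipelined total approaches pure GPU time; the CPU stages effectively vanish from the critical path.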

How It Works

A request's journey through SGLang follows this pipeline. Each arrow represents a boundary where work can be pipelined — while the scheduler runs a GPU forward pass on batch N, the tokenizer is already preparing batch N+1, and the detokenizer is streaming results from batch N-1.

HTTP POST /generate
  → GenerateReqInput           (raw user input: text, params, images)
  → TokenizerManager           (tokenize, validate, normalize)
  → TokenizedGenerateReqInput  (token IDs, sampling params, MM embeddings)
  → [ZMQ send to Scheduler]
  → Req                        (internal request object with KV cache tracking)
  → ScheduleBatch              (grouped requests for one forward pass)
  → ForwardBatch               (GPU tensors for the model)
  → output tokens              (from GPU forward pass)
  → [ZMQ send to Detokenizer]
  → text response              (streamed back to HTTP client)

The scheduler's event loop is the heart of the system. On each iteration it: (1) polls for new tokenized requests from ZMQ, (2) validates and queues them, (3) decides what to run next — a prefill batch for new requests, a decode batch for in-progress generation, or both, (4) executes the GPU forward pass, and (5) processes results, checking stop conditions and routing finished tokens downstream.

The loop below is deliberately simple. No thread pools, no async/await, no callbacks: a pure synchronous loop that runs as fast as the GPU allows. recv_requests() is non-blocking (a ZMQ poll with zero timeout), so the loop never stalls waiting for input. When there is no batch to run, on_idle() performs maintenance such as cache eviction and defragmentation.

Source Code

SOURCE CODE python/sglang/srt/entrypoints/http_server.py:L684-L718
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request):
    if obj.stream:
        async def stream_results():
            async for out in _global_state.tokenizer_manager.generate_request(
                obj, request
            ):                                          # ← async generator yields partial results
                yield b"data: " + dumps_json(out) + b"\n\n"
            yield b"data: [DONE]\n\n"
        return StreamingResponse(stream_results(), media_type="text/event-stream")
    else:
        ret = await _global_state.tokenizer_manager.generate_request(
            obj, request
        ).__anext__()                                   # ← single-shot: take first result
        return orjson_response(ret)
SOURCE CODE python/sglang/srt/managers/scheduler.py:L1363-L1389
@DynamicGradMode()
def event_loop_normal(self):
    while True:
        recv_reqs = self.recv_requests()              # ← poll ZMQ for new requests
        self.process_input_requests(recv_reqs)        # ← validate and enqueue
        batch = self.get_next_batch_to_run()          # ← scheduling decision
        self.cur_batch = batch
        if batch:
            result = self.run_batch(batch)            # ← GPU forward pass
            self.process_batch_result(batch, result)  # ← check stop conditions
        else:
            self.on_idle()
        self.last_batch = batch
TRY IT: Request Flow Animation. Click "Send Request" to watch a request flow through the pipeline (HTTP Server → Tokenizer → Scheduler → Detokenizer), with live counters for requests sent, tokens generated, and average latency.

Key Takeaway: SGLang's architecture is a three-stage async pipeline. The TokenizerManager, Scheduler, and DetokenizerManager run as independent processes via ZMQ. CPU-bound tokenization never blocks GPU-bound inference.

2. Tokenization & the Request Lifecycle

Three input paths converge to one Req object — the single source of truth for every in-flight request.

The Problem

LLM APIs accept wildly diverse inputs: raw text, pre-tokenized IDs, raw embeddings, images, audio. Before the scheduler can do anything, all of this must be normalized into a uniform internal representation: the Req object.

The Idea

Path 1: Text. The most common case — send a string, get token IDs back from the tokenizer. Path 2: Pre-tokenized IDs. Power users send input_ids directly. Path 3: Raw embeddings. For retrieval or adapter models. For multimodal inputs, a fourth layer activates on top of any path: the MM processor converts images/audio into embedding tensors.
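The branch order above can be sketched in a few lines. The `tokenize()` helper is a toy stand-in for the real HuggingFace tokenizer, and `normalize_input` is a hypothetical name, not SGLang's API:

```python
def tokenize(text):
    """Toy stand-in for the real tokenizer: one 'token' per character."""
    return [ord(c) for c in text]

def normalize_input(text=None, input_ids=None, input_embeds=None):
    """Collapse the three input paths into one (input_ids, input_embeds)
    pair, mirroring the branch order in _tokenize_one_request."""
    if input_embeds is not None:          # Path 3: raw embeddings
        return input_ids, input_embeds
    if input_ids is not None:             # Path 2: pre-tokenized IDs
        return input_ids, None
    return tokenize(text), None           # Path 1: text -> tokenize

print(normalize_input(text="hi"))            # ([104, 105], None)
print(normalize_input(input_ids=[1, 2, 3]))  # ([1, 2, 3], None)
```

The embeddings check comes first because embeddings may arrive together with IDs; the text path is the catch-all.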

Source Code

SOURCE CODE python/sglang/srt/managers/schedule_batch.py:L558-L693
class Req:
    def __init__(self, rid, origin_input_text, origin_input_ids, sampling_params, ...):
        self.rid = rid                          # ← unique request ID
        self.origin_input_ids = origin_input_ids # ← tokenized prompt (frozen)
        self.output_ids = []                    # ← generated tokens (grows)
        self.fill_ids = []                      # ← input + output combined
        self.kv_committed_len = 0               # ← tokens in radix cache
        self.kv_allocated_len = 0               # ← tokens with GPU allocation
        self.finished_reason = None             # ← stop condition
        self.stream = stream
SOURCE CODE python/sglang/srt/managers/tokenizer_manager.py:L700-L750
async def _tokenize_one_request(self, obj):
    if obj.input_embeds is not None:        # ← Path 3: raw embeddings
        input_embeds = obj.input_embeds
        input_ids = obj.input_ids
    elif obj.input_ids is not None:         # ← Path 2: pre-tokenized
        input_ids = obj.input_ids
    else:                                   # ← Path 1: text → tokenize
        input_ids, token_type_ids = await self._tokenize_texts(input_text)  # input_text is set from obj earlier (elided)
    if self.mm_processor and obj.contains_mm_input():
        mm_inputs = await self.mm_processor.process_mm_data_async(
            image_data=obj.image_data, audio_data=obj.audio_data,
            input_text=(input_text or input_ids),
        )                                       # ← produces image/audio embeddings
Key Takeaway: Three input paths, one output format. The Req object starts with immutable input data and accumulates mutable state (output_ids, kv_committed_len) as generation proceeds.

3. KV Cache & GPU Memory Management

The KV cache is the largest memory consumer. Two pools and zero fragmentation make it manageable.

The Problem

To produce token N+1, the model needs KV pairs from all tokens 1 through N. A single Llama 70B request with 4K context consumes ~1.3 GB of KV cache. Naive allocation wastes memory on padding since most requests are shorter than the maximum.

The Idea

ReqToTokenPool is a 2D table mapping each request to its KV slot indices. TokenToKVPoolAllocator manages a free list of physical KV indices. This two-level design decouples logical structure from physical storage, enabling non-contiguous allocation with zero fragmentation.

What is a KV Cache? In transformer inference, each token's attention computation produces Key and Value vectors. To avoid recomputing them on every generation step, they're stored in a cache. For Llama 70B at 4K context, a single request's KV cache consumes ~1.3 GB of GPU memory. The KV cache is the dominant memory consumer in LLM serving — managing it efficiently is the difference between serving 10 requests and 1,000.
What is Paged Memory? Like virtual memory in an OS — logical addresses map to physical pages via a page table. SGLang's req_to_token tensor IS the page table: req_to_token[req_slot, token_pos] = physical_kv_index. This indirection means requests can use non-contiguous physical memory without knowing or caring.
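The two-level indirection can be sketched with plain Python lists. The real pools are preallocated torch tensors on the GPU; the class and method names below are simplified stand-ins, not SGLang's API:

```python
class TokenToKVAllocator:
    """Free list of physical KV slot indices (toy version)."""
    def __init__(self, num_slots):
        self.free = list(range(num_slots))

    def alloc(self, n):
        if n > len(self.free):
            return None                       # OOM: caller must wait or evict
        out, self.free = self.free[:n], self.free[n:]
        return out

    def release(self, idxs):
        self.free.extend(idxs)                # reclaim physical slots

class ReqToTokenTable:
    """The 'page table': [req_slot][token_pos] -> physical KV index."""
    def __init__(self, max_reqs, max_len):
        self.table = [[-1] * max_len for _ in range(max_reqs)]

    def write(self, req_slot, positions, kv_indices):
        for pos, idx in zip(positions, kv_indices):
            self.table[req_slot][pos] = idx

kv = TokenToKVAllocator(num_slots=8)
page_table = ReqToTokenTable(max_reqs=2, max_len=4)
idxs = kv.alloc(3)                 # three prompt tokens for request 0
page_table.write(0, [0, 1, 2], idxs)
print(page_table.table[0])         # [0, 1, 2, -1]: logical -> physical
```

Attention kernels read through the table, so the physical indices never need to be contiguous.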

Why Paged Pools, Not Contiguous Allocation?

| Approach | Fragmentation | Prefix Sharing | Memory Waste | Verdict |
|---|---|---|---|---|
| Two-level paged pool | Zero | Natural (shared indices) | None (per-token alloc) | Chosen |
| Contiguous pre-allocation | Severe | Impossible | High (max_len padding) | Wastes memory |

Source Code

SOURCE CODE python/sglang/srt/mem_cache/memory_pool.py:L126-L186
class ReqToTokenPool:
    def __init__(self, size, max_context_len, device, enable_memory_saver):
        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int32, device=device
        )                                       # ← [req_slot, token_pos] → kv_index
        self.free_slots = list(range(size))

    def alloc(self, reqs):
        need_size = len(reqs) - len([r for r in reqs if r.req_pool_idx is not None])
        if need_size > len(self.free_slots):
            return None                         # ← OOM: no free slots
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, req):
        self.free_slots.append(req.req_pool_idx) # ← reclaim slot
TRY IT: Memory Grid Visualization. Each cell is one KV slot; add requests to allocate, free them to release, with live counters for used slots, free slots, utilization, and active requests.

Key Takeaway: Two pools, zero fragmentation. ReqToTokenPool maps requests to tokens; TokenToKVPoolAllocator manages physical indices. Non-contiguous allocation is invisible to attention kernels.

4. RadixAttention — Prefix Caching

When users send similar prompts, recomputing the same KV from scratch is pure waste. RadixAttention eliminates it.

The Problem

In real deployments, prompts share enormous prefixes — system prompts, RAG documents, few-shot examples. Without prefix caching, each request recomputes identical KV pairs. A hash table needs O(N) work to check a prefix of length N; a radix tree matches incrementally, paying O(1) per additional token.

The Idea

The RadixCache stores KV data in a radix tree. Each node holds a token segment and KV indices. Prefix matching walks the tree comparing tokens. Node splitting handles partial matches. Only leaf nodes are evicted (LRU/LFU/priority).

What is a Radix Tree? A radix tree (also called a compressed trie) stores sequences by sharing common prefixes. Instead of one node per token, each node holds a segment of tokens. The path from root to any node represents a complete prefix. The key property: prefix matching is incremental — once you've matched N tokens, checking token N+1 costs O(1), just follow one child edge. This is fundamentally different from a hash table, where checking a prefix of length N requires hashing all N tokens every time.
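A minimal radix tree over token lists makes the incremental-match property concrete. This is a sketch of the data structure only, not SGLang's RadixCache (no KV indices, no eviction):

```python
class Node:
    def __init__(self, key=()):
        self.key = list(key)       # token segment on the edge into this node
        self.children = {}         # first token of child segment -> child

def insert(root, tokens):
    node = root
    while tokens:
        child = node.children.get(tokens[0])
        if child is None:
            node.children[tokens[0]] = Node(tokens)
            return
        n = 0                      # shared length between edge and input
        while n < len(child.key) and n < len(tokens) and child.key[n] == tokens[n]:
            n += 1
        if n < len(child.key):     # partial match: split the edge
            mid = Node(child.key[:n])
            child.key = child.key[n:]
            mid.children[child.key[0]] = child
            node.children[mid.key[0]] = mid
            child = mid
        node, tokens = child, tokens[n:]

def match_prefix(root, tokens):
    """Return how many leading tokens are already stored in the tree."""
    node, matched = root, 0
    while tokens and tokens[0] in node.children:
        child = node.children[tokens[0]]
        n = 0
        while n < len(child.key) and n < len(tokens) and child.key[n] == tokens[n]:
            n += 1
        matched += n
        if n < len(child.key):
            break                  # diverged mid-edge
        node, tokens = child, tokens[n:]
    return matched

root = Node()
insert(root, [1, 2, 3, 4])
insert(root, [1, 2, 9])                    # splits the first edge at [1, 2]
print(match_prefix(root, [1, 2, 3, 7]))    # 3: reuse KV for three tokens
```

The split on partial match is what keeps every stored prefix reachable by exactly one root-to-node path.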

Why Radix Tree, Not Hash Table?

| Approach | Prefix Lookup | Partial Matching | Memory Sharing | Verdict |
|---|---|---|---|---|
| Radix tree | O(1) incremental per token | Natural (split nodes) | Automatic (shared path) | Chosen |
| Hash table | O(N) per lookup (hash entire prefix) | None (exact match only) | Manual dedup needed | Too slow for long prefixes |

Source Code

SOURCE CODE python/sglang/srt/mem_cache/radix_cache.py:L420-L444
def _match_prefix_helper(self, node, key):
    access_time = time.monotonic()
    node.last_access_time = access_time       # ← update LRU
    child_key = self.get_child_key_fn(key)
    value = []
    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]
        prefix_len = self.key_match_fn(child.key, key)  # ← compare tokens
        if prefix_len < len(child.key):
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)      # ← partial match: split
            node = new_node; break
        else:
            value.append(child.value)         # ← full match: collect KV
            node = child
            key = key[prefix_len:]
            if len(key): child_key = self.get_child_key_fn(key)
    return value, node
TRY IT: Radix Tree Visualizer. Insert token sequences to build the tree; similar prefixes share nodes, with live counters for nodes, cached tokens, hit tokens, and hit rate.

Key Takeaway: Radix trees give you the longest prefix match, not just an exact match. For shared system prompts, this eliminates 70-90% of prefill computation.

5. Scheduling — Who Gets the GPU Next?

The scheduler decides who goes first — and the wrong answer wastes everything the cache built.

The Problem

A naive FIFO queue ignores cache affinity. Request #47 might share a 90% prefix with cached data while Request #12 has zero overlap. The scheduler must make these decisions thousands of times per second.

The Idea

SGLang defaults to LPM (Longest Prefix Match), which sorts the waiting queue so requests with the longest RadixCache prefix hit go first. Instead of computing 4096 tokens of prefill, a request might only need 512 new tokens because the other 3584 are already cached. Alternative policies include FCFS (first-come-first-served), LOF (longest output first), and DFS-Weight.

Why LPM, Not FCFS?

| Policy | Cache Utilization | Fairness | Overhead | Verdict |
|---|---|---|---|---|
| LPM (Longest Prefix Match) | Maximizes reuse | May delay new prompts | O(N) prefix scan | Default |
| FCFS (First Come First Served) | Ignores cache | Perfect fairness | O(1) | Fallback when queue > 128 |

Two queues govern the flow. The waiting queue holds requests that have arrived but not yet begun prefill. The running batch holds requests in the decode phase. Every iteration, the scheduler decides how many new requests to promote from waiting into a prefill batch, without starving decode.
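LPM ordering itself is a one-line sort once prefix lengths are known. The sketch below uses a hypothetical `prefix_len()` in place of the RadixCache walk; 128 mirrors the FCFS fallback threshold described above:

```python
FCFS_THRESHOLD = 128               # beyond this, prefix scans cost too much

def prefix_len(cached, tokens):
    """Stand-in for the RadixCache lookup: count shared leading tokens."""
    n = 0
    while n < min(len(cached), len(tokens)) and cached[n] == tokens[n]:
        n += 1
    return n

def order_queue(waiting, cached_prefix):
    if len(waiting) > FCFS_THRESHOLD:
        return waiting             # FCFS fallback: keep arrival order
    return sorted(waiting, key=lambda r: -prefix_len(cached_prefix, r))

cached = [1, 2, 3, 4, 5]
queue = [[9, 9], [1, 2, 3, 7], [1, 2, 3, 4, 5, 6]]
print(order_queue(queue, cached))  # best cache hit scheduled first
```

`sorted` is stable, so requests with equal prefix hits keep their arrival order: FCFS remains the tiebreak.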

How It Works

The scheduler loop executes these steps on every iteration:

1. Receive requests. New tokenized requests arrive via ZMQ from the tokenizer. Each carries input token IDs, sampling params, and max output length.

2. Apply scheduling policy. calc_priority reorders the waiting queue. For LPM, this computes the prefix match length for each request against the RadixCache. When the queue exceeds 128 requests, it falls back to FCFS since computing prefix matches becomes expensive.

3. Build a batch with PrefillAdder. Walks the sorted queue, adding requests one at a time. Tracks three budgets: total input tokens (rem_input_tokens), chunk tokens (rem_chunk_tokens), and available KV cache pages. When any budget is exhausted, it stops.

4. Execute the forward pass. The assembled batch is sent to the model runner. Prefill requests run in extend mode; decode requests run one token at a time.

5. Process results. After the forward pass returns logits, the scheduler samples the next token, checks finish conditions (EOS, max length), and routes finished requests to the detokenizer.

Pseudocode: One scheduler iteration

new_reqs = recv_from_tokenizer()
waiting_queue.extend(new_reqs)

# Reorder by policy (LPM puts best cache hits first)
prefix_computed = policy.calc_priority(waiting_queue)

# Build batch within budget
prefill_batch = PrefillAdder.build(waiting_queue, token_budget, memory_budget)

# Merge with ongoing decode batch
batch = merge(running_batch, prefill_batch)

# GPU work
logits = model_runner.forward(batch)

# Post-process
for req in batch:
    next_token = sample(logits[req])
    if is_finished(req, next_token):
        send_result(req)
        running_batch.remove(req)

Source Code

SOURCE CODE python/sglang/srt/managers/schedule_policy.py:117-159
def calc_priority(self, waiting_queue, running_batch=None):
    if self.policy == CacheAgnosticPolicy.FCFS:
        if self.enable_priority_scheduling:
            SchedulePolicy._sort_by_priority_and_fcfs(waiting_queue, self.priority_sign)
        return False
    policy = self._determine_active_policy(waiting_queue) # ← FCFS fallback if queue > 128
    prefix_computed = isinstance(policy, CacheAwarePolicy)
    if prefix_computed:
        self._compute_prefix_matches(waiting_queue, policy) # ← find cache hits
        if policy == CacheAwarePolicy.LPM:
            SchedulePolicy._sort_by_longest_prefix(waiting_queue) # ← best hit first
    return prefix_computed

Key Takeaway: LPM scheduling is what makes RadixCache pay off. Without it, cached prefixes go unused because unrelated requests get scheduled first.

6. Continuous Batching & Chunked Prefill

Don't make the GPU wait for stragglers — mix prefill and decode in the same breath.

The Problem

A long prefill (100K context) blocks all decode requests, spiking TTFB for everyone. Traditional batching waits for all requests to finish before starting the next batch.

The Idea

What is Continuous Batching? In static batching, the GPU processes a fixed batch and waits for ALL requests to finish before starting the next. In continuous batching, the scheduler checks after every forward pass: which requests finished? Remove them. Which new requests are waiting? Add them. The batch composition changes fluidly, iteration to iteration. No synchronization barrier.
What is Chunked Prefill? A 100K-token prompt would monopolize the GPU for hundreds of milliseconds if prefilled all at once. Chunked prefill slices it into pieces (e.g., 8192 tokens per chunk). Between chunks, decode requests get their turn. The long prefill resumes next iteration, picking up exactly where it left off. TTFB becomes bounded by chunk size, not input length.

Why Continuous + Chunked, Not Static Batching?

| Approach | GPU Utilization | TTFB | Complexity | Verdict |
|---|---|---|---|---|
| Continuous + chunked | Near 100% | Bounded by chunk size | Medium | Chosen |
| Static batching | Low (waits for slowest) | Unbounded | Simple | Too wasteful |

How It Works

Step 1: Prefill completes, merge into decode. After a prefill forward pass finishes, completed requests join the running decode batch. This is the "continuous" part — new decode requests join mid-stream, no synchronization needed.

Step 2: Chunked requests saved for later. If a prompt was too long for one chunk, the PrefillAdder records it as chunked_req. Next iteration, this request gets priority and resumes from where it left off.

Step 3: Mixed batches. In a single forward call, SGLang processes both prefill tokens (new requests) and decode tokens (ongoing generation). The attention kernel handles variable-length inputs via the backend metadata.

Step 4: Budget interplay. The chunk budget (rem_chunk_tokens) limits per-request prefill. The input budget (rem_input_tokens) limits total prefill. Decode tokens from the running batch eat into the total, ensuring the GPU is not overloaded.
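The chunking in steps 2 and 4 can be sketched for a single long prompt. The 8192-token chunk size is illustrative, matching the example above rather than a fixed SGLang default:

```python
def chunk_schedule(prompt_len, chunk_tokens):
    """Per-iteration prefill sizes for one long prompt."""
    chunks, done = [], 0
    while done < prompt_len:
        step = min(chunk_tokens, prompt_len - done)  # respect chunk budget
        chunks.append(step)
        done += step                                 # resume point next iter
    return chunks

chunks = chunk_schedule(100_000, 8192)
print(len(chunks), chunks[-1])   # 13 iterations; final chunk is 1696 tokens
```

Between any two of these chunks, the running decode batch gets a forward pass, which is why TTFB for other requests is bounded by the chunk size rather than the 100K input.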

Pseudocode: Continuous batching iteration

# After last forward pass completed:
if last_batch was prefill:
    for req in last_batch:
        if req fully prefilled:
            running_batch.add(req)        # Join decode batch
        else:
            save as chunked_req           # Resume next iteration

# Build next batch:
new_prefill = PrefillAdder.build(
    waiting_queue,
    rem_input_tokens = budget - running_batch.decode_tokens,
    rem_chunk_tokens = 8192
)
combined = merge(running_batch, new_prefill)
logits = model.forward(combined)

Source Code

SOURCE CODE python/sglang/srt/managers/scheduler.py:2290-2318
if self.last_batch and self.last_batch.forward_mode.is_extend():  # ← prefill just finished
    if not self.last_batch.is_empty():
        if self.running_batch.is_empty():
            self.running_batch = self.last_batch       # ← move to decode
        else:
            self.running_batch.merge_batch(self.last_batch)  # ← merge into running
Key Takeaway: Continuous batching keeps the GPU full; chunked prefill keeps latency fair. The chunk size is the tuning knob.

7. Model Execution — The Forward Pass

All roads lead to model.forward() — but CUDA graphs pave them with one kernel launch.

The Problem

Hundreds of small CUDA kernels per forward pass mean significant launch overhead, especially during decode when each request contributes just one token. The model runner minimizes this through CUDA graph caching.

The Idea

SGLang's ModelRunner sits between the scheduler and GPU. It receives a ForwardBatch and dispatches to one of three paths: Extend (prefill, many tokens), Decode (one token), or Idle (pipeline sync).

What is a CUDA Graph? A recorded sequence of GPU operations that can be replayed as a single unit. Instead of the CPU telling the GPU "do matmul, then layernorm, then attention..." for each of 80 layers, the entire sequence is captured once and replayed in one launch. The GPU never waits for CPU instructions between operations. At small batch sizes where kernel launch overhead dominates, this reduces decode latency by 2-3x.
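The latency claim can be illustrated with a back-of-envelope launch-overhead model. Every number below is hypothetical, chosen only to show the shape of the effect, not measured from SGLang:

```python
def decode_step_us(n_layers, kernels_per_layer, kernel_us, launch_us, use_graph):
    """Toy cost model: GPU compute time plus per-launch CPU overhead."""
    kernels = n_layers * kernels_per_layer
    compute = kernels * kernel_us
    launches = 1 if use_graph else kernels   # graph replay = one launch
    return compute + launches * launch_us

eager = decode_step_us(80, 10, 2.0, 5.0, use_graph=False)
graph = decode_step_us(80, 10, 2.0, 5.0, use_graph=True)
print(eager, graph)        # 5600.0 vs 1605.0: ~3.5x in this toy model
```

The effect is largest exactly when kernels are short (small decode batches), which is why the speedup shows up in decode rather than prefill.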

Why CUDA Graphs, Not torch.compile?

| Approach | Latency Reduction | Flexibility | Warmup Cost | Verdict |
|---|---|---|---|---|
| CUDA graph replay | 2-3x for decode | Fixed shapes only | Per-shape capture | Default for decode |
| torch.compile | Variable | Dynamic shapes | Compilation time | Used for extend/prefill |

How It Works

Stage 1: Dispatch. _forward_raw checks three conditions for CUDA graph eligibility: decode mode, GraphRunner initialized, and batch shape matches a captured graph. If all three, replay.

Stage 2: Metadata init. For non-graph paths, the attention backend's init_forward_metadata builds KV index structures. For FlashInfer, this means ragged tensor indices mapping requests to KV pages. Precomputed once per batch, reused across all 32+ attention layers.

Stage 3: Model execution. Embedding lookup, N transformer layers (attention + MLP each), final vocabulary projection. In graph mode, all of this is a single graph.replay() call.

Pseudocode: Forward pass dispatch

def forward(forward_batch):
    # Stage 1: Can we replay a CUDA graph?
    if mode == DECODE and graph_runner.has_graph(batch_shape):
        return graph_runner.replay(forward_batch)   # Single launch

    # Stage 2: Prepare attention metadata
    attn_backend.init_forward_metadata(forward_batch)  # Build KV indices

    # Stage 3: Run the model. The call is identical for DECODE and EXTEND;
    # the forward mode carried in forward_batch selects the kernels downstream.
    logits = model.forward(input_ids, positions, forward_batch)
    return logits

Source Code

SOURCE CODE python/sglang/srt/model_executor/model_runner.py:2940-2990
def _forward_raw(self, forward_batch, ...):
    can_run_graph = (
        forward_batch.forward_mode.is_cuda_graph()
        and self.graph_runner
        and self.graph_runner.can_run(forward_batch) # ← check CUDA graph eligibility
    )
    if can_run_graph:
        ret = self.graph_runner.replay(forward_batch)  # ← replay recorded ops
        return ModelRunnerOutput(logits_output=ret, can_run_graph=True)
    if forward_batch.forward_mode.is_decode():
        ret = self.forward_decode(forward_batch)       # ← single-token decode
    elif forward_batch.forward_mode.is_extend():
        ret, _ = self.forward_extend(forward_batch)    # ← multi-token prefill
    return ModelRunnerOutput(logits_output=ret, can_run_graph=can_run_graph)
Key Takeaway: CUDA graphs turn hundreds of kernel launches into one. For decode-heavy workloads, this is worth 2-3x in latency.

8. Attention Backends

FlashInfer, Triton, and Friends — the same math, four ways to run it fast.

The Problem

Attention is up to 70% of inference compute. Different hardware needs different kernels. A serving system needs a pluggable abstraction so new backends can be integrated without touching model code.

The Idea

What is Flash Attention? An IO-aware attention algorithm that tiles the Q/K/V computation to keep data in fast GPU SRAM instead of slow HBM. Standard attention reads/writes O(N²) to HBM; Flash Attention reduces this to O(N) while computing exact (not approximate) attention. The key insight: recompute softmax statistics during the backward pass instead of storing them — trade cheap FLOPs for expensive memory bandwidth.

Why Pluggable Backends?

| Backend | Hardware | Speed | Portability | Verdict |
|---|---|---|---|---|
| FlashInfer | NVIDIA | Fastest | NVIDIA only | Default |
| Triton | Any GPU | 10-20% slower | Portable | AMD/Intel fallback |
| TorchNative | Any | Slowest | Universal | Debug only |

SGLang uses a two-layer design. The model-facing layer is RadixAttention, a PyTorch module that every transformer layer calls. It reshapes Q, K, V tensors, optionally saves KV to cache, and delegates computation to whatever backend is active. The model code never imports FlashInfer or Triton directly.

The backend layer is the AttentionBackend abstract class. Each concrete backend implements three methods: init_forward_metadata (precompute KV index structures), forward_decode (single-token against full KV cache), and forward_extend (multi-token for prefill with partially populated KV).

How It Works

RadixAttention's forward: When torch.compile is active, it routes to a custom op (unified_attention_with_output) for compiler compatibility. Otherwise, it calls forward_batch.attn_backend.forward(), dispatching to decode or extend based on forward mode.

FlashInfer's decode path (most common production path on NVIDIA): init_forward_metadata builds a ragged index — for each request, a list of KV cache page IDs and valid tokens in the last page. The FlashInfer kernel reads this index, gathers scattered KV vectors, computes fused softmax attention, and writes output in a single kernel. Because the index is precomputed once per batch, all 32+ layers reuse the same metadata.

Triton's path uses similar logic expressed in Triton's Python DSL, making it portable to AMD and Intel GPUs. Tradeoff: typically 10-20% lower throughput on NVIDIA vs FlashInfer's hand-tuned CUDA.
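The per-batch ragged index can be sketched as CSR-style arrays. The layout below illustrates the idea; FlashInfer's real metadata format differs in detail, and the page size is a toy value:

```python
PAGE_SIZE = 16                     # KV tokens per physical page (toy value)

def build_ragged_index(seq_lens, req_pages):
    """seq_lens[i]: token count for request i; req_pages[i]: its page IDs.
    Returns CSR offsets into the flat page list, the page list itself,
    and the number of valid tokens in each request's last page."""
    indptr, pages, last_page_lens = [0], [], []
    for n, pg in zip(seq_lens, req_pages):
        pages.extend(pg)
        indptr.append(len(pages))
        last_page_lens.append(n - (len(pg) - 1) * PAGE_SIZE)
    return indptr, pages, last_page_lens

out = build_ragged_index(seq_lens=[20, 16], req_pages=[[3, 7], [5]])
print(out)   # ([0, 2, 3], [3, 7, 5], [4, 16])
```

Because this is built once per batch in init_forward_metadata, every attention layer reuses the same arrays.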

Pseudocode: Attention dispatch flow

# In transformer layer N:
output = radix_attention.forward(q, k, v, forward_batch)

# Inside RadixAttention:
def forward(q, k, v, forward_batch):
    reshape q, k, v to [tokens, heads, head_dim]
    if torch_compile_active:
        return unified_attention_custom_op(q, k, v)
    else:
        return forward_batch.attn_backend.forward(q, k, v)

# Inside AttentionBackend:
def forward(q, k, v):
    if mode == DECODE:
        return self.forward_decode(q, k, v)   # 1 query vs full KV
    else:
        return self.forward_extend(q, k, v)   # N queries vs partial KV

Source Code

SOURCE CODE python/sglang/srt/layers/radix_attention.py:86-115
def forward(self, q, k, v, forward_batch, save_kv_cache=True, **kwargs):
    if k is not None:
        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
    if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
        output = torch.empty_like(q)
        unified_attention_with_output(q, k, v, output, save_kv_cache, self.layer_id)
        return output                      # ← torch.compile custom op path
    else:
        return forward_batch.attn_backend.forward(  # ← dispatch to active backend
            q, k, v, self, forward_batch, save_kv_cache, **kwargs)
SOURCE CODE python/sglang/srt/layers/attention/base_attn_backend.py:50-100
class AttentionBackend(ABC):
    @abstractmethod
    def init_forward_metadata(self, forward_batch): ...   # ← precompute KV indices

    def forward(self, q, k, v, layer, forward_batch, ...):
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        else:
            return self.forward_extend(q, k, v, layer, forward_batch)

    @abstractmethod
    def forward_decode(self, q, k, v, layer, forward_batch): ...  # ← 1 token vs full KV
    @abstractmethod
    def forward_extend(self, q, k, v, layer, forward_batch): ...  # ← N tokens vs partial KV

Key Takeaway: The AttentionBackend abstraction decouples kernel innovation from model logic. SGLang can adopt new kernels in days by implementing three methods.

9. Sampling — From Logits to Tokens

The model outputs 100,000 probabilities. How does SGLang pick exactly one token?

The Problem

After the forward pass, the model produces a logit vector over the full vocabulary. Greedy picks are repetitive; pure random is incoherent. Each request in a batch may have different sampling preferences.

The Idea

What is Nucleus (Top-P) Sampling? Sort tokens by descending probability. Keep the smallest set whose cumulative probability exceeds p (e.g., 0.95). When the model is confident, few tokens survive; when uncertain, many do. This adapts the candidate pool to the model's certainty — unlike top-k which always keeps exactly k tokens regardless of the probability landscape.

Sampling is a three-stage pipeline: scale, filter, pick. Temperature scaling controls distribution spread (0.1 = very confident, 2.0 = nearly uniform). Filtering — top-k, top-p (nucleus), and min-p — prunes away low-probability tokens. Finally, the surviving probabilities are normalized and a token is drawn via torch.multinomial.

The catch: SGLang serves a batch of requests, each with different parameters. SamplingBatchInfo packs per-request params into GPU tensors for vectorized execution. When all requests are greedy, a boolean flag is_all_greedy skips the entire softmax-filter-sample pipeline, jumping straight to argmax.

How It Works

1. Reshape and bias. Raw logits arrive as [batch_size, vocab_size]. Per-token logit biases (repetition penalty) are applied additively.

2. Greedy fast path. When is_all_greedy is true, call argmax and return. No softmax, no filtering, no random numbers. This is the hot path for many API workloads.

3. Temperature scaling. Divide each request's logits by its temperature: logits / temperatures. Softmax produces the probability distribution.

4. Min-p filtering. Mask out tokens with probability less than min_p * max_prob. Adaptive: when the model is confident, few tokens survive; when uncertain, many do.

5. Top-p (nucleus). Sort by descending probability, compute cumulative sum. Zero out tokens beyond the top_p threshold. A top-p of 0.95 keeps the smallest set with 95%+ combined probability.

6. Top-k. Keep only the top K tokens; zero the rest. Hard cutoff regardless of mass.

7. Multinomial draw. torch.multinomial(probs, 1) draws one token per request.

Each filtering stage is guarded by a batch-level flag (need_min_p_sampling, need_top_p_sampling, need_top_k_sampling). If no request in the batch uses top-k, the entire top-k kernel is skipped.

logits [batch, vocab]
    ↓
apply_logit_bias (additive)
    ↓
is_all_greedy? ──YES──> argmax → done
    ↓ NO
softmax(logits / temperature)
    ↓
need_min_p? ──YES──> zero out tokens < min_p * max_prob
    ↓
need_top_p? ──YES──> sort, cumsum, zero beyond threshold
    ↓
need_top_k? ──YES──> keep only top-k
    ↓
multinomial(probs, 1) → token_ids
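The pipeline above can be run end to end in pure Python for a single request. This is a sketch of the math, not the batched torch kernels; `sample_one` is an illustrative name:

```python
import math
import random

def sample_one(logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0, seed=0):
    if temperature == 0.0:                     # greedy fast path: argmax
        return max(range(len(logits)), key=logits.__getitem__)
    m = max(logits)                            # shift for numerical stability
    exps = [math.exp((x - m) / temperature) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    if min_p > 0.0:                            # min-p: adaptive floor
        floor = min_p * max(probs)
        probs = [p if p >= floor else 0.0 for p in probs]
    if top_p < 1.0:                            # nucleus: smallest set >= top_p
        order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
        keep, cum = set(), 0.0
        for i in order:
            keep.add(i)
            cum += probs[i]
            if cum >= top_p:
                break
        probs = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    if top_k > 0:                              # top-k: hard cutoff
        order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
        keep = set(order[:top_k])
        probs = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    z = sum(probs)                             # renormalize survivors
    probs = [p / z for p in probs]
    return random.Random(seed).choices(range(len(probs)), weights=probs)[0]

print(sample_one([5.0, 1.0, 0.5], temperature=0.0))  # 0 (greedy)
print(sample_one([5.0, 1.0, 0.5], top_k=1))          # 0 (one survivor)
```

Each `if` guard mirrors the batch-level flags in the real sampler: when a filter is unused, its kernel never runs.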

Source Code

SOURCE CODE python/sglang/srt/layers/sampler.py
def forward(self, logits, forward_batch):
    logits = logits.contiguous().view(-1, self.vocab_size)
    if forward_batch.sampling_info.is_all_greedy:
        batch_next_token_ids = torch.argmax(logits, dim=-1)  # ← fast path
    else:
        # temperatures / top_ps / top_ks / min_ps are per-request tensors
        # packed in forward_batch.sampling_info (elided from this excerpt)
        probs = torch.softmax(logits / temperatures, dim=-1)
        if forward_batch.sampling_info.need_min_p_sampling:
            probs = self._apply_min_p(probs, min_ps)     # ← min-p filter
        if forward_batch.sampling_info.need_top_p_sampling:
            probs = self._apply_top_p(probs, top_ps)     # ← nucleus sampling
        if forward_batch.sampling_info.need_top_k_sampling:
            probs = self._apply_top_k(probs, top_ks)     # ← top-K filter
        batch_next_token_ids = torch.multinomial(probs, 1).squeeze(1)
    return batch_next_token_ids
TRY IT

Sampling Parameter Explorer

Adjust temperature, top-k, top-p, and min-p to see how they reshape the probability distribution — the explorer reports surviving tokens, entropy (bits), and top-token probability.
Key Takeaway: Batch-level flags (is_all_greedy, need_top_p_sampling) skip entire GPU kernels. The greedy path is nearly free.

10. Constrained Decoding

How SGLang guarantees valid JSON from a 70-billion-parameter model, every single time.

The Problem

Models usually produce valid JSON, but "usually" is not enough for production. Retry-on-failure wastes GPU time. Constrained decoding makes invalid output structurally impossible.

The Idea

What is an FSM (Finite State Machine)? A model with discrete states and transitions triggered by inputs. For JSON decoding, states include EXPECT_KEY, IN_STRING, EXPECT_COLON, EXPECT_VALUE. At each state, only certain tokens are valid. The FSM computes a vocabulary mask — a binary vector over all 128K tokens — that zeros out invalid choices before sampling.

Constrained decoding runs an FSM alongside the language model. At each step, it computes allowed tokens as a vocabulary mask (1 = allowed, 0 = forbidden), applied before sampling by setting forbidden logits to negative infinity.

Why FSM Masking, Not Retry-on-Invalid?

| Approach                | GPU Waste | Guarantee                          | Latency               | Verdict    |
|-------------------------|-----------|------------------------------------|-----------------------|------------|
| FSM vocabulary masking  | Zero      | Structural (impossible to fail)    | Same as unconstrained | Chosen     |
| Retry on invalid output | 2-10x     | Probabilistic (may never converge) | Unbounded             | Unreliable |

XGrammar is the fastest backend: compiles grammars into optimized FSMs with bit-level mask operations. Outlines provides a Python regex-to-FSM fallback. Compilation is async + cached — the first request pays the cost; subsequent requests reuse instantly.

Jump-forward decoding: When the FSM reaches a state where only one continuation is valid (e.g., after {"name": the next character must be "), the system skips sampling entirely and emits deterministic tokens directly, sometimes skipping 5-10 tokens in one jump.

How It Works

The constrained pipeline integrates at three points in the decode loop:

1. Grammar compilation (once per schema). GrammarManager checks its cache. Hit: clone the cached FSM (cheap pointer copy). Miss: dispatch compilation to ThreadPoolExecutor so the serving loop is not blocked.

2. Vocabulary masking (every decode step). Each request's grammar calls fill_next_token_bitmask, which consults the FSM state and sets bits for allowed tokens. Applied to logits: logits[~mask] = -inf. Sampling then proceeds normally.

3. State advancement (after sampling). The sampled token is fed to accept_token. The FSM transitions to a new state. Then try_jump_forward checks for unique continuations.
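Step 2 in isolation: applying a bitmask to logits before softmax. The mask here is hand-written for illustration (a real one comes from `fill_next_token_bitmask`); `-inf` logits become exactly zero probability after softmax:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
mask = torch.tensor([[True, False, True, False]])  # hypothetical allowed-token bitmask

masked = logits.masked_fill(~mask, float("-inf"))  # forbid tokens outside the grammar
probs = torch.softmax(masked, dim=-1)              # forbidden tokens get probability 0
```

Sampling from `probs` can now only ever pick tokens the FSM allows.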

Request arrives with json_schema={"type": "object", "properties": ...}
    ↓
GrammarManager.get_or_create_grammar()
    |— cache hit?  → clone cached FSM
    |— cache miss? → compile in ThreadPoolExecutor
    ↓
For each decode step:
    FSM.current_state → fill_vocab_mask() → bitmask [vocab_size]
    logits[~bitmask] = -inf
    sample from masked logits → token_id
    FSM.accept_token(token_id)
    FSM.try_jump_forward() → skip deterministic tokens
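The decode loop above can be made concrete with a toy five-state FSM over a seven-token vocabulary. This is illustrative only — XGrammar operates on real tokenizer vocabularies with bit-packed masks — but it shows both mask filling and jump-forward:

```python
VOCAB = ["{", "}", '"a"', ":", "0", "1", "2"]

# Each state maps its allowed tokens to the successor state.
FSM = {
    "START":        {"{": "EXPECT_KEY"},
    "EXPECT_KEY":   {'"a"': "EXPECT_COLON"},
    "EXPECT_COLON": {":": "EXPECT_VALUE"},
    "EXPECT_VALUE": {"0": "EXPECT_END", "1": "EXPECT_END", "2": "EXPECT_END"},
    "EXPECT_END":   {"}": "DONE"},
}

def fill_vocab_mask(state: str) -> list[int]:
    # 1 = allowed, 0 = forbidden; applied as logits[mask == 0] = -inf.
    return [int(tok in FSM[state]) for tok in VOCAB]

def try_jump_forward(state: str):
    # While exactly one continuation is valid, emit it without sampling.
    emitted = []
    while state in FSM and len(FSM[state]) == 1:
        tok, state = next(iter(FSM[state].items()))
        emitted.append(tok)
    return emitted, state
```

From `START`, jump-forward emits `{`, `"a"`, and `:` in one step with no model calls, stopping at `EXPECT_VALUE`, where three digits are valid and sampling resumes under the mask `[0, 0, 0, 0, 1, 1, 1]`.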

Source Code

SOURCE CODE python/sglang/srt/constrained/grammar_manager.py
class GrammarManager:
    def __init__(self, tokenizer, backend_type):
        self.grammar_cache = {}               # ← cache compiled grammars
        self.executor = ThreadPoolExecutor()   # ← background compilation

    async def get_or_create_grammar(self, constraint_type, constraint_value):
        cache_key = hash((constraint_type, constraint_value))
        if cache_key in self.grammar_cache:
            return self.grammar_cache[cache_key].clone()  # ← reuse compiled
        grammar = await asyncio.get_event_loop().run_in_executor(
            self.executor, self._compile_grammar, constraint_type, constraint_value
        )                                     # ← compile in background
        self.grammar_cache[cache_key] = grammar
        return grammar.clone()


Key Takeaway: Constrained decoding makes invalid output impossible via FSM vocabulary masking. Jump-forward turns constraints into an optimization, not a tax.

11. Scaling Up — Data & Tensor Parallelism

One GPU is not enough. Here is how SGLang splits work across 2, 4, 8, or 64 GPUs.

The Problem

A 70B model in FP16 needs 140 GB — far more than one GPU. Even if the model fits, one GPU cannot serve production traffic. You need to either split the model or replicate it.

The Idea

What is Tensor Parallelism? Split each layer's weight matrices across N GPUs. With 64 attention heads and TP=4, each GPU handles 16 heads and computes its shard independently. After each layer, GPUs synchronize via AllReduce — a collective operation that sums partial results across GPUs and distributes the total back. The communication cost per layer: one [batch, seq_len, hidden_dim] tensor. Over NVLink (900 GB/s) this takes microseconds; over Ethernet it becomes the bottleneck.

Why TP + DP, Not Just One?

| Strategy                | Fits Larger Models  | Throughput Scaling   | Communication           | Verdict            |
|-------------------------|---------------------|----------------------|-------------------------|--------------------|
| Tensor Parallelism only | Yes (split weights) | Limited by AllReduce | Per-layer, needs NVLink | For models > 1 GPU |
| Data Parallelism only   | No (full replica)   | Linear               | None during inference   | For throughput     |
| Hybrid DP + TP          | Yes                 | Linear × TP groups   | Within TP group only    | Production default |

Tensor Parallelism (TP) splits a single model across N GPUs: each GPU holds 1/N of the attention heads and 1/N of the MLP width. TP is the standard for models exceeding single-GPU memory. The tradeoff: AllReduce runs after every layer — twice, in fact, once after attention and once after the MLP — so NVLink bandwidth is essential.

Data Parallelism (DP) replicates the entire model on N independent GPU groups. A DataParallelController routes requests to the least-loaded worker. Throughput scales linearly — no inter-replica communication during inference. Cost: each replica needs full model weights + KV cache.

Hybrid DP+TP: With 8 GPUs, run 2 DP replicas each using 4-way TP. DP routes requests to replicas; within each replica, TP splits the model. This serves models too large for 4 GPUs while doubling throughput.

How It Works

TP internals: At init, SGLang creates TP groups via PyTorch's process group API. With 8 GPUs and TP=4: groups [0,1,2,3] and [4,5,6,7]. Within each group:

Attention is split by heads. With 64 heads and TP=4, each GPU handles 16. QKV weights are partitioned column-wise; the output projection row-wise, followed by an AllReduce.

The MLP is split by hidden dimension. Gate/up projections are partitioned column-wise; the down projection row-wise, again followed by an AllReduce.

Each AllReduce communicates a [batch, seq_len, hidden_dim] tensor. With a hidden dim of 4096 in FP16, that is 8 KB per token per layer; over 80 layers and 1000 tokens, roughly 640 MB per iteration per reduction (and each layer reduces twice).
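Both the group layout and the communication volume can be checked with a few lines of plain arithmetic (not SGLang's actual process-group code):

```python
def tp_groups(world_size: int, tp_size: int) -> list[list[int]]:
    # Consecutive ranks form a TP group: 8 GPUs, TP=4 → [0..3] and [4..7].
    return [list(range(s, s + tp_size)) for s in range(0, world_size, tp_size)]

def allreduce_bytes(num_tokens: int, hidden_dim: int, dtype_bytes: int = 2) -> int:
    # One AllReduce moves a [num_tokens, hidden_dim] activation tensor (FP16 = 2 B).
    return num_tokens * hidden_dim * dtype_bytes
```

For 1000 tokens at hidden dim 4096, one AllReduce moves 8,192,000 bytes (8 KB per token); across 80 layers that is roughly 640 MB per iteration, matching the estimate above.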

DP internals: DataParallelController receives requests from the API server and dispatches. Two policies: round-robin (simple, fair) and shortest-queue (better for variable-length prompts). Each worker runs its own scheduler, KV cache, and RadixCache independently.

8 GPUs, TP=4, DP=2:

Request stream
    ↓
DataParallelController
    |                    |
    ↓                    ↓
  Replica 0            Replica 1
  [GPU 0,1,2,3]       [GPU 4,5,6,7]
   TP group 0           TP group 1
   AllReduce             AllReduce
   within group          within group

Source Code

SOURCE CODE python/sglang/srt/managers/data_parallel_controller.py
class DataParallelController:
    def event_loop(self):
        while True:
            recv_reqs = self.recv_requests()
            for req in recv_reqs:
                worker_idx = self._select_worker(req) # ← load balance
                self.workers[worker_idx].send(req)

    def _select_worker(self, req):
        if self.policy == "round_robin":
            idx = self.next_worker
            self.next_worker = (self.next_worker + 1) % self.dp_size
            return idx
        else:  # shortest_queue
            return min(range(self.dp_size), key=lambda i: self.workers[i].queue_size)


Key Takeaway: TP lets you serve models that don't fit on one GPU (AllReduce per layer). DP gives linear throughput scaling (independent replicas). Most deployments use both.

12. Speculative Decoding

A small model guesses the next 5 tokens. The large model checks them all at once.

The Problem

Autoregressive decoding is serial: each token requires a full forward pass. For large models, this is memory-bandwidth bound — the GPU waits for 140 GB of weights to stream from HBM. You pay the full cost to generate one token.

The Idea

What is Speculative Decoding? Large-model inference is memory-bandwidth bound: generating a single token means streaming all 140 GB of weights through the memory bus while the GPU's compute cores sit mostly idle. But that streaming cost is the same whether the forward pass scores 1 token or K — the bottleneck is memory bandwidth, not compute. Speculative decoding exploits this by verifying K candidate tokens for the price of ~1.

Why Speculate, Not Just Decode Normally?

| Approach                      | Tokens per Target Forward | Cost                    | Net Speedup   | Verdict                 |
|-------------------------------|---------------------------|-------------------------|---------------|-------------------------|
| Speculative (K=5, 80% accept) | ~4 tokens                 | 1 draft + 1 target pass | 2-4x          | Chosen for large models |
| Standard autoregressive       | 1 token                   | 1 target pass           | 1x (baseline) | Simple, always correct  |

Speculative decoding exploits the fact that a small draft model (1B params) can predict what a large target model (70B) would generate with 70-80% accuracy. Let the draft model sprint ahead and generate K candidates quickly. Feed all K to the target model in a single forward pass — this works because the target can process K tokens in parallel, like processing a prompt. Accept consecutive matches, reject the first mismatch, sample a bonus token at the rejection point. Net speedup: 2-4x.

SGLang supports EAGLE (tree-structured drafting from target hidden states — explores multiple branches simultaneously), N-gram (frequency-based, zero-cost, no extra model), and DFlash (block-structured with hidden state reuse).

How It Works

1. Draft phase. The draft model generates K candidate tokens autoregressively. Fast because the model is small. EAGLE generates a tree of candidates rather than a single chain, so if the first choice is rejected, alternatives are available.

2. Verify phase. All K draft tokens are fed to the target model in one forward pass. The target processes positions 1..K in parallel, producing K sets of logits.

3. Accept/reject. Compare each draft token against the target's top choice. Accept all consecutive matches. At the first rejection, discard it and all subsequent drafts.

4. Bonus token. At the rejection point (or after all K accepted), sample one token from the target's distribution. This guarantees every verify pass produces at least one new token — speculative decoding is never slower than standard decoding.

5. Update state. KV cache updated with accepted + bonus tokens. Draft model resets to last accepted position. Loop repeats.

Again, the target's forward pass costs roughly the same for 1 token or K, because the bottleneck is streaming 140 GB of weights from HBM, not the matrix multiplications.

Standard decoding:    [target] → t1 → [target] → t2 → [target] → t3
                       30ms          30ms          30ms
                       Total: 90ms for 3 tokens

Speculative decoding: [draft ×5] → [target verifies 5] → accept 4, bonus 1
                        5ms              30ms
                       Total: 35ms for 5 tokens (7ms/token vs 30ms/token, ~4x speedup)
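The timeline above generalizes. Assuming each draft token is accepted independently with probability p (a simplification), the expected tokens per round and the resulting speedup are:

```python
def expected_tokens_per_round(p: float, k: int) -> float:
    # Leading-run model: E[accepted] = p + p^2 + ... + p^k, plus one bonus token.
    return sum(p**i for i in range(1, k + 1)) + 1.0

def speculative_speedup(p: float, k: int, t_draft: float, t_target: float) -> float:
    # Tokens per unit time in one speculative round vs. one token per target pass.
    round_time = k * t_draft + t_target
    return expected_tokens_per_round(p, k) * t_target / round_time
```

With p=0.8, K=5, 1 ms draft steps, and a 30 ms target pass, each round yields about 3.7 tokens for a ~3.2x speedup — squarely in the 2-4x band the section quotes.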

Source Code

SOURCE CODE python/sglang/srt/speculative/dflash_utils.py
def compute_accept_reject(draft_tokens, target_logits, draft_logits):
    # draft_tokens: [batch, K]; target_logits: [batch, K+1, vocab]
    target_choices = torch.argmax(target_logits, dim=-1)       # [batch, K+1]
    accepted = draft_tokens == target_choices[:, :-1]          # ← accept if same token
    accept_mask = torch.cumprod(accepted.int(), dim=-1)        # ← consecutive matches only
    num_accepted = accept_mask.sum(dim=-1)                     # [batch]
    batch_idx = torch.arange(draft_tokens.size(0), device=draft_tokens.device)
    bonus_token = target_choices[batch_idx, num_accepted]      # ← always make progress
    return num_accepted, bonus_token
TRY IT

Speculative Decoding Simulator

Watch the draft model propose tokens and the target model verify them — the simulator tracks rounds, accepted tokens, accept rate, and speedup as you vary the acceptance probability.


Key Takeaway: Speculative decoding turns the memory-bandwidth bottleneck into an advantage. A cheap draft model proposes, the expensive target verifies in bulk. 2-4x faster with identical output distribution.
Tutorial Complete!