1. The Big Picture
Every LLM serving system must turn HTTP requests into GPU-generated tokens as fast as possible.
The Problem
You have a large language model. Users send prompts and want completions back with sub-second latency. The naive approach — tokenize, run forward passes, detokenize, all in one thread — collapses under concurrent load. Tokenization blocks the GPU. Detokenization blocks the scheduler. A single slow request holds up the entire pipeline.
The Idea
SGLang decomposes LLM serving into three asynchronous processes connected by ZMQ message-passing:
TokenizerManager receives HTTP requests, tokenizes text, and handles multimodal inputs. Scheduler manages batching, KV cache allocation, and GPU execution. DetokenizerManager converts output tokens back to text and streams results.
This means tokenization of new requests overlaps with GPU computation on current requests, which overlaps with detokenization of completed requests. No stage blocks another.
Why Separate Processes, Not Async Tasks?
| Approach | GIL Contention | Crash Isolation | CPU/GPU Separation | Verdict |
|---|---|---|---|---|
| Separate processes + ZMQ | None (each has own GIL) | One crash doesn't kill others | Clean boundary | Chosen |
| Async tasks in one process | GIL serializes CPU work | One crash kills everything | Shared memory contention | Too fragile |
How It Works
A request's journey through SGLang follows this pipeline. Each arrow represents a boundary where work can be pipelined — while the scheduler runs a GPU forward pass on batch N, the tokenizer is already preparing batch N+1, and the detokenizer is streaming results from batch N-1.
HTTP POST /generate
→ GenerateReqInput (raw user input: text, params, images)
→ TokenizerManager (tokenize, validate, normalize)
→ TokenizedGenerateReqInput (token IDs, sampling params, MM embeddings)
→ [ZMQ send to Scheduler]
→ Req (internal request object with KV cache tracking)
→ ScheduleBatch (grouped requests for one forward pass)
→ ForwardBatch (GPU tensors for the model)
→ output tokens (from GPU forward pass)
→ [ZMQ send to Detokenizer]
→ text response (streamed back to HTTP client)
The scheduler's event loop is the heart of the system. On each iteration it: (1) polls for new tokenized requests from ZMQ, (2) validates and queues them, (3) decides what to run next — a prefill batch for new requests, a decode batch for in-progress generation, or both, (4) executes the GPU forward pass, and (5) processes results, checking stop conditions and routing finished tokens downstream.
Notice the elegant simplicity of the loop below. No thread pools, no async/await, no callbacks. It is a pure synchronous loop that runs as fast as the GPU allows. recv_requests() is non-blocking (ZMQ poll with zero timeout), so the loop never stalls waiting for input. When there is no batch to run, on_idle() performs maintenance like cache eviction and defragmentation.
Source Code
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request):
    if obj.stream:

        async def stream_results():
            async for out in _global_state.tokenizer_manager.generate_request(
                obj, request
            ):  # ← async generator yields partial results
                yield b"data: " + dumps_json(out) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(stream_results(), media_type="text/event-stream")
    else:
        ret = await _global_state.tokenizer_manager.generate_request(
            obj, request
        ).__anext__()  # ← single-shot: take first result
        return orjson_response(ret)
@DynamicGradMode()
def event_loop_normal(self):
    while True:
        recv_reqs = self.recv_requests()           # ← poll ZMQ for new requests
        self.process_input_requests(recv_reqs)     # ← validate and enqueue
        batch = self.get_next_batch_to_run()       # ← scheduling decision
        self.cur_batch = batch
        if batch:
            result = self.run_batch(batch)         # ← GPU forward pass
            self.process_batch_result(batch, result)  # ← check stop conditions
        else:
            self.on_idle()
        self.last_batch = batch
Further Reading
- SGLang: Efficient Execution of Structured Language Model Programs — Introduced the SGLang system, RadixAttention, and the structured generation language.
2. Tokenization & the Request Lifecycle
Three input paths converge to one Req object — the single source of truth for every in-flight request.
The Problem
LLM APIs accept wildly diverse inputs: raw text, pre-tokenized IDs, raw embeddings, images, audio. Before the scheduler can do anything, all of this must be normalized into a uniform internal representation: the Req object.
The Idea
Path 1: Text. The most common case — send a string, get token IDs back from the tokenizer. Path 2: Pre-tokenized IDs. Power users send input_ids directly. Path 3: Raw embeddings. For retrieval or adapter models. For multimodal inputs, a fourth layer activates on top of any path: the MM processor converts images/audio into embedding tensors.
Source Code
class Req:
    def __init__(self, rid, origin_input_text, origin_input_ids, sampling_params, ...):
        self.rid = rid                             # ← unique request ID
        self.origin_input_ids = origin_input_ids   # ← tokenized prompt (frozen)
        self.output_ids = []                       # ← generated tokens (grows)
        self.fill_ids = []                         # ← input + output combined
        self.kv_committed_len = 0                  # ← tokens in radix cache
        self.kv_allocated_len = 0                  # ← tokens with GPU allocation
        self.finished_reason = None                # ← stop condition
        self.stream = stream
async def _tokenize_one_request(self, obj):
    if obj.input_embeds is not None:               # ← Path 3: raw embeddings
        input_embeds = obj.input_embeds
        input_ids = obj.input_ids
    elif obj.input_ids is not None:                # ← Path 2: pre-tokenized
        input_ids = obj.input_ids
    else:                                          # ← Path 1: text → tokenize
        input_ids, token_type_ids = await self._tokenize_texts(input_text)
    if self.mm_processor and obj.contains_mm_input():
        mm_inputs = await self.mm_processor.process_mm_data_async(
            image_data=obj.image_data, audio_data=obj.audio_data,
            input_text=(input_text or input_ids),
        )  # ← produces image/audio embeddings
The Req object starts with immutable input data and accumulates mutable state (output_ids, kv_committed_len) as generation proceeds.
3. KV Cache & GPU Memory Management
The KV cache is the largest memory consumer. Two pools and zero fragmentation make it manageable.
The Problem
To produce token N+1, the model needs KV pairs from all tokens 1 through N. A single Llama 70B request with 4K context consumes ~1.3 GB of KV cache. Naive allocation wastes memory on padding since most requests are shorter than the maximum.
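The ~1.3 GB figure can be checked with back-of-the-envelope arithmetic. A sketch, assuming Llama 2 70B's published shapes (80 layers, 8 KV heads with grouped-query attention, head dimension 128, FP16):

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Per-request KV cache size: a K and a V vector per layer, head, and position."""
    per_token_per_layer = 2 * num_kv_heads * head_dim * bytes_per_elem  # K + V
    return per_token_per_layer * num_layers * seq_len

# Llama 2 70B with GQA: 80 layers, 8 KV heads, head_dim 128, FP16, 4K context
total = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)
print(f"{total / 1e9:.2f} GB")  # → 1.34 GB, matching the ~1.3 GB figure
```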
The Idea
ReqToTokenPool is a 2D table mapping each request to its KV slot indices. TokenToKVPoolAllocator manages a free list of physical KV indices. This two-level design decouples logical structure from physical storage, enabling non-contiguous allocation with zero fragmentation.
req_to_token tensor IS the page table: req_to_token[req_slot, token_pos] = physical_kv_index. This indirection means requests can use non-contiguous physical memory without knowing or caring.
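The indirection can be sketched in a few lines of plain Python — a toy stand-in for the real torch-tensor pools (TwoLevelPool and its methods are illustrative names, not SGLang API):

```python
class TwoLevelPool:
    """Logical layer: req_to_token[req_slot][token_pos] -> physical KV index.
    Physical layer: a free list of KV slot indices, handed out per token."""

    def __init__(self, num_req_slots, max_context_len, num_kv_slots):
        self.req_to_token = [[-1] * max_context_len for _ in range(num_req_slots)]
        self.free_kv = list(range(num_kv_slots))   # physical free list

    def append_token(self, req_slot, token_pos):
        kv_idx = self.free_kv.pop()                # any free slot: no contiguity needed
        self.req_to_token[req_slot][token_pos] = kv_idx
        return kv_idx

pool = TwoLevelPool(num_req_slots=2, max_context_len=8, num_kv_slots=16)
for pos in range(3):
    pool.append_token(0, pos)   # request 0 gets three (possibly scattered) KV slots
```

The attention kernel only ever reads `req_to_token`, so the scattered physical layout never leaks into model code.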
Why Paged Pools, Not Contiguous Allocation?
| Approach | Fragmentation | Prefix Sharing | Memory Waste | Verdict |
|---|---|---|---|---|
| Two-level paged pool | Zero | Natural (shared indices) | None (per-token alloc) | Chosen |
| Contiguous pre-allocation | Severe | Impossible | High (max_len padding) | Wastes memory |
Source Code
class ReqToTokenPool:
    def __init__(self, size, max_context_len, device, enable_memory_saver):
        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int32, device=device
        )  # ← [req_slot, token_pos] → kv_index
        self.free_slots = list(range(size))

    def alloc(self, reqs):
        need_size = len(reqs) - len([r for r in reqs if r.req_pool_idx is not None])
        if need_size > len(self.free_slots):
            return None  # ← OOM: no free slots
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, req):
        self.free_slots.append(req.req_pool_idx)  # ← reclaim slot
Further Reading
- Efficient Memory Management for Large Language Model Serving with PagedAttention — Introduced paged KV cache management for LLM serving (vLLM). SGLang builds on this with radix tree indexing.
ReqToTokenPool maps requests to tokens; TokenToKVPoolAllocator manages physical indices. Non-contiguous allocation is invisible to attention kernels.
4. RadixAttention — Prefix Caching
When users send similar prompts, recomputing the same KV from scratch is pure waste. RadixAttention eliminates it.
The Problem
In real deployments, prompts share enormous prefixes — system prompts, RAG documents, few-shot examples. Without prefix caching, each request recomputes identical KV pairs. Looking up an N-token prefix in a hash table costs O(N) per probe, since the entire prefix must be hashed; a radix tree matches incrementally, amortized O(1) per additional token.
The Idea
The RadixCache stores KV data in a radix tree. Each node holds a token segment and KV indices. Prefix matching walks the tree comparing tokens. Node splitting handles partial matches. Only leaf nodes are evicted (LRU/LFU/priority).
Why Radix Tree, Not Hash Table?
| Approach | Prefix Lookup | Partial Matching | Memory Sharing | Verdict |
|---|---|---|---|---|
| Radix tree | O(1) incremental per token | Natural (split nodes) | Automatic (shared path) | Chosen |
| Hash table | O(N) per lookup (hash entire prefix) | None (exact match only) | Manual dedup needed | Too slow for long prefixes |
Source Code
def _match_prefix_helper(self, node, key):
    access_time = time.monotonic()
    node.last_access_time = access_time            # ← update LRU
    child_key = self.get_child_key_fn(key)
    value = []
    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]
        prefix_len = self.key_match_fn(child.key, key)  # ← compare tokens
        if prefix_len < len(child.key):
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)           # ← partial match: split
            node = new_node
            break
        else:
            value.append(child.value)              # ← full match: collect KV
            node = child
            key = key[prefix_len:]
            if len(key):
                child_key = self.get_child_key_fn(key)
    return value, node
Further Reading
- SGLang: Efficient Execution of Structured Language Model Programs — Introduced RadixAttention: radix tree-based KV cache for automatic prefix sharing across requests.
5. Scheduling — Who Gets the GPU Next?
The scheduler decides who goes first — and the wrong answer wastes everything the cache built.
The Problem
A naive FIFO queue ignores cache affinity. Request #47 might share a 90% prefix with cached data while Request #12 has zero overlap. The scheduler must make these decisions thousands of times per second.
The Idea
SGLang defaults to LPM (Longest Prefix Match), which sorts the waiting queue so requests with the longest RadixCache prefix hit go first. Instead of computing 4096 tokens of prefill, a request might only need 512 new tokens because the other 3584 are already cached. Alternative policies include FCFS (first-come-first-served), LOF (longest output first), and DFS-Weight.
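A simplified sketch of the LPM ordering — the real scheduler walks the radix tree rather than scanning a list of cached prefixes, but the principle is the same (match_len and lpm_sort are illustrative names):

```python
def match_len(a, b):
    """Length of the common prefix of two token lists."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def lpm_sort(waiting_queue, cached_prefixes):
    """Sort so the request with the longest cached-prefix hit runs first."""
    def best_hit(req_tokens):
        return max((match_len(req_tokens, p) for p in cached_prefixes), default=0)
    return sorted(waiting_queue, key=best_hit, reverse=True)

cached = [[1, 2, 3, 4]]                       # one cached prompt
queue = [[9, 9], [1, 2, 3, 7], [1, 2]]        # three waiting requests
print(lpm_sort(queue, cached))  # → [[1, 2, 3, 7], [1, 2], [9, 9]]
```

The request sharing 3 of its 4 tokens with the cache jumps the queue; the request with no overlap goes last.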
Why LPM, Not FCFS?
| Policy | Cache Utilization | Fairness | Overhead | Verdict |
|---|---|---|---|---|
| LPM (Longest Prefix Match) | Maximizes reuse | May delay new prompts | O(N) prefix scan | Default |
| FCFS (First Come First Served) | Ignores cache | Perfect fairness | O(1) | Fallback when queue > 128 |
Two queues govern the flow. The waiting queue holds requests that have arrived but not yet begun prefill. The running batch holds requests in the decode phase. Every iteration, the scheduler decides how many new requests to promote from waiting into a prefill batch, without starving decode.
How It Works
The scheduler loop executes these steps on every iteration:
1. Receive requests. New tokenized requests arrive via ZMQ from the tokenizer. Each carries input token IDs, sampling params, and max output length.
2. Apply scheduling policy. calc_priority reorders the waiting queue. For LPM, this computes the prefix match length for each request against the RadixCache. When the queue exceeds 128 requests, it falls back to FCFS since computing prefix matches becomes expensive.
3. Build a batch with PrefillAdder. Walks the sorted queue, adding requests one at a time. Tracks three budgets: total input tokens (rem_input_tokens), chunk tokens (rem_chunk_tokens), and available KV cache pages. When any budget is exhausted, it stops.
4. Execute the forward pass. The assembled batch is sent to the model runner. Prefill requests run in extend mode; decode requests run one token at a time.
5. Process results. After the forward pass returns logits, the scheduler samples the next token, checks finish conditions (EOS, max length), and routes finished requests to the detokenizer.
Pseudocode: One scheduler iteration
new_reqs = recv_from_tokenizer()
waiting_queue.extend(new_reqs)

# Reorder by policy (LPM puts best cache hits first)
prefix_computed = policy.calc_priority(waiting_queue)

# Build batch within budget
prefill_batch = PrefillAdder.build(waiting_queue, token_budget, memory_budget)

# Merge with ongoing decode batch
batch = merge(running_batch, prefill_batch)

# GPU work
logits = model_runner.forward(batch)

# Post-process
for req in batch:
    next_token = sample(logits[req])
    if is_finished(req, next_token):
        send_result(req)
        running_batch.remove(req)
Source Code
def calc_priority(self, waiting_queue, running_batch=None):
    if self.policy == CacheAgnosticPolicy.FCFS:
        if self.enable_priority_scheduling:
            SchedulePolicy._sort_by_priority_and_fcfs(waiting_queue, self.priority_sign)
        return False

    policy = self._determine_active_policy(waiting_queue)  # ← FCFS fallback if queue > 128
    prefix_computed = False
    if isinstance(policy, CacheAwarePolicy):
        self._compute_prefix_matches(waiting_queue, policy)  # ← find cache hits
        prefix_computed = True
        if policy == CacheAwarePolicy.LPM:
            SchedulePolicy._sort_by_longest_prefix(waiting_queue)  # ← best hit first
    return prefix_computed
Further Reading
- Orca: A Distributed Serving System for Transformer-Based Generative Models — Introduced iteration-level scheduling (continuous batching) for transformer serving.
6. Continuous Batching & Chunked Prefill
Don't make the GPU wait for stragglers — mix prefill and decode in the same breath.
The Problem
A long prefill (100K context) blocks all decode requests, spiking TTFB for everyone. Traditional batching waits for all requests to finish before starting the next batch.
The Idea
Why Continuous + Chunked, Not Static Batching?
| Approach | GPU Utilization | TTFB | Complexity | Verdict |
|---|---|---|---|---|
| Continuous + chunked | Near 100% | Bounded by chunk size | Medium | Chosen |
| Static batching | Low (waits for slowest) | Unbounded | Simple | Too wasteful |
How It Works
Step 1: Prefill completes, merge into decode. After a prefill forward pass finishes, completed requests join the running decode batch. This is the "continuous" part — new decode requests join mid-stream, no synchronization needed.
Step 2: Chunked requests saved for later. If a prompt was too long for one chunk, the PrefillAdder records it as chunked_req. Next iteration, this request gets priority and resumes from where it left off.
Step 3: Mixed batches. In a single forward call, SGLang processes both prefill tokens (new requests) and decode tokens (ongoing generation). The attention kernel handles variable-length inputs via the backend metadata.
Step 4: Budget interplay. The chunk budget (rem_chunk_tokens) limits per-request prefill. The input budget (rem_input_tokens) limits total prefill. Decode tokens from the running batch eat into the total, ensuring the GPU is not overloaded.
Pseudocode: Continuous batching iteration
# After last forward pass completed:
if last_batch was prefill:
    for req in last_batch:
        if req fully prefilled:
            running_batch.add(req)   # Join decode batch
        else:
            save as chunked_req      # Resume next iteration

# Build next batch:
new_prefill = PrefillAdder.build(
    waiting_queue,
    rem_input_tokens = budget - running_batch.decode_tokens,
    rem_chunk_tokens = 8192,
)
combined = merge(running_batch, new_prefill)
logits = model.forward(combined)
Source Code
if self.last_batch and self.last_batch.forward_mode.is_extend():  # ← prefill just finished
    if not self.last_batch.is_empty():
        if self.running_batch.is_empty():
            self.running_batch = self.last_batch             # ← move to decode
        else:
            self.running_batch.merge_batch(self.last_batch)  # ← merge into running
7. Model Execution — The Forward Pass
All roads lead to model.forward() — but CUDA graphs pave them with one kernel launch.
The Problem
Hundreds of small CUDA kernels per forward pass mean significant launch overhead, especially during decode when each request contributes just one token. The model runner minimizes this through CUDA graph caching.
The Idea
SGLang's ModelRunner sits between the scheduler and GPU. It receives a ForwardBatch and dispatches to one of three paths: Extend (prefill, many tokens), Decode (one token), or Idle (pipeline sync).
Why CUDA Graphs, Not torch.compile?
| Approach | Latency Reduction | Flexibility | Warmup Cost | Verdict |
|---|---|---|---|---|
| CUDA graph replay | 2-3x for decode | Fixed shapes only | Per-shape capture | Default for decode |
| torch.compile | Variable | Dynamic shapes | Compilation time | Used for extend/prefill |
How It Works
Stage 1: Dispatch. _forward_raw checks three conditions for CUDA graph eligibility: decode mode, GraphRunner initialized, and batch shape matches a captured graph. If all three, replay.
Stage 2: Metadata init. For non-graph paths, the attention backend's init_forward_metadata builds KV index structures. For FlashInfer, this means ragged tensor indices mapping requests to KV pages. Precomputed once per batch, reused across all 32+ attention layers.
Stage 3: Model execution. Embedding lookup, N transformer layers (attention + MLP each), final vocabulary projection. In graph mode, all of this is a single graph.replay() call.
Pseudocode: Forward pass dispatch
def forward(forward_batch):
    # Stage 1: Can we replay a CUDA graph?
    if mode == DECODE and graph_runner.has_graph(batch_shape):
        return graph_runner.replay(forward_batch)  # Single launch

    # Stage 2: Prepare attention metadata
    attn_backend.init_forward_metadata(forward_batch)  # Build KV indices

    # Stage 3: Run the model (same entry point for DECODE and EXTEND;
    # forward_batch carries the mode and the precomputed metadata)
    logits = model.forward(input_ids, positions, forward_batch)
    return logits
Source Code
def _forward_raw(self, forward_batch, ...):
    can_run_graph = (
        forward_batch.forward_mode.is_cuda_graph()
        and self.graph_runner
        and self.graph_runner.can_run(forward_batch)  # ← check CUDA graph eligibility
    )
    if can_run_graph:
        ret = self.graph_runner.replay(forward_batch)  # ← replay recorded ops
        return ModelRunnerOutput(logits_output=ret, can_run_graph=True)

    if forward_batch.forward_mode.is_decode():
        ret = self.forward_decode(forward_batch)       # ← single-token decode
    elif forward_batch.forward_mode.is_extend():
        ret, _ = self.forward_extend(forward_batch)    # ← multi-token prefill
    return ModelRunnerOutput(logits_output=ret, can_run_graph=can_run_graph)
8. Attention Backends
FlashInfer, Triton, and Friends — the same math, four ways to run it fast.
The Problem
Attention is up to 70% of inference compute. Different hardware needs different kernels. A serving system needs a pluggable abstraction so new backends can be integrated without touching model code.
The Idea
Why Pluggable Backends?
| Backend | Hardware | Speed | Portability | Verdict |
|---|---|---|---|---|
| FlashInfer | NVIDIA | Fastest | NVIDIA only | Default |
| Triton | Any GPU | 10-20% slower | Portable | AMD/Intel fallback |
| TorchNative | Any | Slowest | Universal | Debug only |
SGLang uses a two-layer design. The model-facing layer is RadixAttention, a PyTorch module that every transformer layer calls. It reshapes Q, K, V tensors, optionally saves KV to cache, and delegates computation to whatever backend is active. The model code never imports FlashInfer or Triton directly.
The backend layer is the AttentionBackend abstract class. Each concrete backend implements three methods: init_forward_metadata (precompute KV index structures), forward_decode (single-token against full KV cache), and forward_extend (multi-token for prefill with partially populated KV).
How It Works
RadixAttention's forward: When torch.compile is active, it routes to a custom op (unified_attention_with_output) for compiler compatibility. Otherwise, it calls forward_batch.attn_backend.forward(), dispatching to decode or extend based on forward mode.
FlashInfer's decode path (most common production path on NVIDIA): init_forward_metadata builds a ragged index — for each request, a list of KV cache page IDs and valid tokens in the last page. The FlashInfer kernel reads this index, gathers scattered KV vectors, computes fused softmax attention, and writes output in a single kernel. Because the index is precomputed once per batch, all 32+ layers reuse the same metadata.
Triton's path uses similar logic expressed in Triton's Python DSL, making it portable to AMD and Intel GPUs. Tradeoff: typically 10-20% lower throughput on NVIDIA vs FlashInfer's hand-tuned CUDA.
Pseudocode: Attention dispatch flow
# In transformer layer N:
output = radix_attention.forward(q, k, v, forward_batch)

# Inside RadixAttention:
def forward(q, k, v, forward_batch):
    reshape q, k, v to [tokens, heads, head_dim]
    if torch_compile_active:
        return unified_attention_custom_op(q, k, v)
    else:
        return forward_batch.attn_backend.forward(q, k, v)

# Inside AttentionBackend:
def forward(q, k, v):
    if mode == DECODE:
        return self.forward_decode(q, k, v)  # 1 query vs full KV
    else:
        return self.forward_extend(q, k, v)  # N queries vs partial KV
Source Code
def forward(self, q, k, v, forward_batch, save_kv_cache=True, **kwargs):
    if k is not None:
        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
    if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
        output = torch.empty_like(q)
        unified_attention_with_output(q, k, v, output, save_kv_cache, self.layer_id)
        return output  # ← torch.compile custom op path
    else:
        return forward_batch.attn_backend.forward(  # ← dispatch to active backend
            q, k, v, self, forward_batch, save_kv_cache, **kwargs)


class AttentionBackend(ABC):
    @abstractmethod
    def init_forward_metadata(self, forward_batch): ...  # ← precompute KV indices

    def forward(self, q, k, v, layer, forward_batch, ...):
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        else:
            return self.forward_extend(q, k, v, layer, forward_batch)

    @abstractmethod
    def forward_decode(self, q, k, v, layer, forward_batch): ...  # ← 1 token vs full KV

    @abstractmethod
    def forward_extend(self, q, k, v, layer, forward_batch): ...  # ← N tokens vs partial KV
Further Reading
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — The IO-aware tiling algorithm that made efficient attention practical.
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — Improved parallelism and 2x speedup over FlashAttention.
- FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving — The batched attention kernel library SGLang uses as its default backend.
The AttentionBackend abstraction decouples kernel innovation from model logic. SGLang can adopt a new kernel in days by implementing three methods.
9. Sampling — From Logits to Tokens
The model outputs 100,000 probabilities. How does SGLang pick exactly one token?
The Problem
After the forward pass, the model produces a logit vector over the full vocabulary. Greedy picks are repetitive; pure random is incoherent. Each request in a batch may have different sampling preferences.
The Idea
Sampling is a three-stage pipeline: scale, filter, pick. Temperature scaling controls distribution spread (0.1 = very confident, 2.0 = nearly uniform). Filtering — top-k, top-p (nucleus), and min-p — prunes away low-probability tokens. Finally, the surviving probabilities are normalized and a token is drawn via torch.multinomial.
The catch: SGLang serves a batch of requests, each with different parameters. SamplingBatchInfo packs per-request params into GPU tensors for vectorized execution. When all requests are greedy, a boolean flag is_all_greedy skips the entire softmax-filter-sample pipeline, jumping straight to argmax.
How It Works
1. Reshape and bias. Raw logits arrive as [batch_size, vocab_size]. Per-token logit biases (repetition penalty) are applied additively.
2. Greedy fast path. When is_all_greedy is true, call argmax and return. No softmax, no filtering, no random numbers. This is the hot path for many API workloads.
3. Temperature scaling. Divide each request's logits by its temperature: logits / temperatures. Softmax produces the probability distribution.
4. Min-p filtering. Mask out tokens with probability less than min_p * max_prob. Adaptive: when the model is confident, few tokens survive; when uncertain, many do.
5. Top-p (nucleus). Sort by descending probability, compute cumulative sum. Zero out tokens beyond the top_p threshold. A top-p of 0.95 keeps the smallest set with 95%+ combined probability.
6. Top-k. Keep only the top K tokens; zero the rest. Hard cutoff regardless of mass.
7. Multinomial draw. torch.multinomial(probs, 1) draws one token per request.
Each filtering stage is guarded by a batch-level flag (need_min_p_sampling, need_top_p_sampling, need_top_k_sampling). If no request in the batch uses top-k, the entire top-k kernel is skipped.
logits [batch, vocab]
↓
apply_logit_bias (additive)
↓
is_all_greedy? ──YES──> argmax → done
↓ NO
softmax(logits / temperature)
↓
need_min_p? ──YES──> zero out tokens < min_p * max_prob
↓
need_top_p? ──YES──> sort, cumsum, zero beyond threshold
↓
need_top_k? ──YES──> keep only top-k
↓
multinomial(probs, 1) → token_ids
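The scale-filter-pick pipeline above can be sketched in pure Python — a toy, unbatched stand-in for SGLang's vectorized Sampler (the function name and defaults are illustrative):

```python
import math
import random

def sample(logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0, rng=None):
    """Scale -> filter (min-p, top-p, top-k) -> pick, mirroring the diagram above."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    probs = [math.exp(l - m) for l in scaled]          # numerically stable softmax
    total = sum(probs)
    probs = [p / total for p in probs]

    if min_p > 0:                                      # drop tokens below a fraction of max prob
        cutoff = min_p * max(probs)
        probs = [p if p >= cutoff else 0.0 for p in probs]

    if top_p < 1.0:                                    # smallest set covering top_p of the mass
        order = sorted(range(len(probs)), key=lambda i: -probs[i])
        kept, cum = set(), 0.0
        for i in order:
            kept.add(i)
            cum += probs[i]
            if cum >= top_p:
                break
        probs = [p if i in kept else 0.0 for i, p in enumerate(probs)]

    if top_k > 0:                                      # hard cutoff on count
        keep = set(sorted(range(len(probs)), key=lambda i: -probs[i])[:top_k])
        probs = [p if i in keep else 0.0 for i, p in enumerate(probs)]

    total = sum(probs)                                 # renormalize survivors and draw
    probs = [p / total for p in probs]
    return (rng or random).choices(range(len(probs)), weights=probs)[0]
```

With `top_k=1` this degenerates to greedy argmax, which is why SGLang's `is_all_greedy` flag can skip the whole pipeline.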
Source Code
def forward(self, logits, forward_batch):
    logits = logits.contiguous().view(-1, self.vocab_size)
    if forward_batch.sampling_info.is_all_greedy:
        batch_next_token_ids = torch.argmax(logits, dim=-1)  # ← fast path
    else:
        probs = torch.softmax(logits / temperatures, dim=-1)
        if forward_batch.sampling_info.need_min_p_sampling:
            probs = self._apply_min_p(probs, min_ps)     # ← min-p filter
        if forward_batch.sampling_info.need_top_p_sampling:
            probs = self._apply_top_p(probs, top_ps)     # ← nucleus sampling
        if forward_batch.sampling_info.need_top_k_sampling:
            probs = self._apply_top_k(probs, top_ks)     # ← top-K filter
        batch_next_token_ids = torch.multinomial(probs, 1).squeeze(1)
    return batch_next_token_ids
Batch-level flags (is_all_greedy, need_top_p_sampling) skip entire GPU kernels. The greedy path is nearly free.
10. Constrained Decoding
How SGLang guarantees valid JSON from a 70-billion-parameter model, every single time.
The Problem
Models usually produce valid JSON, but "usually" is not enough for production. Retry-on-failure wastes GPU time. Constrained decoding makes invalid output structurally impossible.
The Idea
Constrained decoding runs an FSM alongside the language model. At each step, it computes allowed tokens as a vocabulary mask (1 = allowed, 0 = forbidden), applied before sampling by setting forbidden logits to negative infinity.
Why FSM Masking, Not Retry-on-Invalid?
| Approach | GPU Waste | Guarantee | Latency | Verdict |
|---|---|---|---|---|
| FSM vocabulary masking | Zero | Structural (impossible to fail) | Same as unconstrained | Chosen |
| Retry on invalid output | 2-10x | Probabilistic (may never converge) | Unbounded | Unreliable |
XGrammar is the fastest backend: compiles grammars into optimized FSMs with bit-level mask operations. Outlines provides a Python regex-to-FSM fallback. Compilation is async + cached — the first request pays the cost; subsequent requests reuse instantly.
Jump-forward decoding: When the FSM reaches a state where only one continuation is valid (e.g., after {"name": the next character must be "), the system skips sampling entirely and emits deterministic tokens directly, sometimes skipping 5-10 tokens in one jump.
How It Works
The constrained pipeline integrates at three points in the decode loop:
1. Grammar compilation (once per schema). GrammarManager checks its cache. Hit: clone the cached FSM (cheap pointer copy). Miss: dispatch compilation to ThreadPoolExecutor so the serving loop is not blocked.
2. Vocabulary masking (every decode step). Each request's grammar calls fill_next_token_bitmask, which consults the FSM state and sets bits for allowed tokens. Applied to logits: logits[~mask] = -inf. Sampling then proceeds normally.
3. State advancement (after sampling). The sampled token is fed to accept_token. The FSM transitions to a new state. Then try_jump_forward checks for unique continuations.
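The masking and jump-forward mechanics can be shown with a toy sketch, using a hypothetical four-token vocabulary and a hand-written FSM in place of a compiled grammar:

```python
NEG_INF = float("-inf")

# Hypothetical 4-token vocabulary and a tiny FSM that only accepts '{"a' then '"' or '}'
VOCAB = ['{', '"', 'a', '}']
TRANSITIONS = {        # state -> {allowed token id: next state}
    0: {0: 1},         # start: only '{'
    1: {1: 2},         # after '{': only '"'
    2: {2: 3},         # after '"': only 'a'
    3: {1: 4, 3: 4},   # close the quote or the brace
}

def decode_step(logits, state):
    allowed = TRANSITIONS[state]
    if len(allowed) == 1:                       # jump-forward: one legal continuation,
        token = next(iter(allowed))             # so skip sampling entirely
    else:                                       # mask forbidden tokens, then pick (greedy here)
        masked = [l if i in allowed else NEG_INF for i, l in enumerate(logits)]
        token = max(range(len(masked)), key=lambda i: masked[i])
    return token, allowed[token]                # accept_token -> next FSM state

state, out = 0, []
for _ in range(4):
    token, state = decode_step([0.1, 0.9, 0.2, 0.3], state)
    out.append(VOCAB[token])
print("".join(out))  # → {"a"
```

Three of the four steps here never consult the logits at all — that is the jump-forward win.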
Request arrives with json_schema={"type": "object", "properties": ...}
↓
GrammarManager.get_or_create_grammar()
|— cache hit? → clone cached FSM
|— cache miss? → compile in ThreadPoolExecutor
↓
For each decode step:
FSM.current_state → fill_vocab_mask() → bitmask [vocab_size]
logits[~bitmask] = -inf
sample from masked logits → token_id
FSM.accept_token(token_id)
FSM.try_jump_forward() → skip deterministic tokens
Source Code
class GrammarManager:
    def __init__(self, tokenizer, backend_type):
        self.grammar_cache = {}               # ← cache compiled grammars
        self.executor = ThreadPoolExecutor()  # ← background compilation

    async def get_or_create_grammar(self, constraint_type, constraint_value):
        cache_key = hash((constraint_type, constraint_value))
        if cache_key in self.grammar_cache:
            return self.grammar_cache[cache_key].clone()  # ← reuse compiled
        grammar = await asyncio.get_event_loop().run_in_executor(
            self.executor, self._compile_grammar, constraint_type, constraint_value
        )  # ← compile in background
        self.grammar_cache[cache_key] = grammar
        return grammar.clone()
Further Reading
- Grammar-Constrained Decoding for Structured NLP Tasks without Finetuning — Foundational work on using formal grammars to constrain LLM token generation.
11. Scaling Up — Data & Tensor Parallelism
One GPU is not enough. Here is how SGLang splits work across 2, 4, 8, or 64 GPUs.
The Problem
A 70B model in FP16 needs 140 GB — far more than one GPU. Even if the model fits, one GPU cannot serve production traffic. You need to either split the model or replicate it.
The Idea
Splitting a model across GPUs buys capacity at the price of communication: every synchronization point moves a [batch, seq_len, hidden_dim] tensor between GPUs. Over NVLink (900 GB/s) this takes microseconds; over Ethernet it becomes the bottleneck.
Why TP + DP, Not Just One?
| Strategy | Fits Larger Models | Throughput Scaling | Communication | Verdict |
|---|---|---|---|---|
| Tensor Parallelism only | Yes (split weights) | Limited by AllReduce | Per-layer, needs NVLink | For models > 1 GPU |
| Data Parallelism only | No (full replica) | Linear | None during inference | For throughput |
| Hybrid DP + TP | Yes | Linear × TP groups | Within TP group only | Production default |
Tensor Parallelism (TP) splits a single model across N GPUs. Each GPU holds 1/N of attention heads and 1/N of MLP width. After each layer, GPUs synchronize via AllReduce — a collective op that sums partial results. TP is the standard for models exceeding single-GPU memory. The tradeoff: AllReduce happens after every layer (twice — after attention and after MLP), so NVLink bandwidth (900 GB/s) is essential.
Data Parallelism (DP) replicates the entire model on N independent GPU groups. A DataParallelController routes requests to the least-loaded worker. Throughput scales linearly — no inter-replica communication during inference. Cost: each replica needs full model weights + KV cache.
Hybrid DP+TP: With 8 GPUs, run 2 DP replicas each using 4-way TP. DP routes requests to replicas; within each replica, TP splits the model. This serves models too large for 4 GPUs while doubling throughput.
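The rank layout is simple arithmetic; a sketch of how contiguous ranks are grouped (make_tp_groups is an illustrative name, not SGLang API):

```python
def make_tp_groups(world_size, tp_size):
    """Partition GPU ranks into TP groups; each group serves one DP replica."""
    assert world_size % tp_size == 0, "world size must be divisible by TP size"
    return [list(range(i, i + tp_size)) for i in range(0, world_size, tp_size)]

print(make_tp_groups(8, 4))  # → [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With 8 GPUs and TP=4 this yields the two replicas described above; DP routing picks a group, and AllReduce stays inside it.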
How It Works
TP internals: At init, SGLang creates TP groups via PyTorch's process group API. With 8 GPUs and TP=4: groups [0,1,2,3] and [4,5,6,7]. Within each group:
• Attention is split by heads. With 64 heads and TP=4, each GPU handles 16. QKV weights partitioned column-wise; output projection row-wise, followed by AllReduce.
• MLP split by hidden dim. Gate/up projection column-wise; down projection row-wise + AllReduce.
• Each AllReduce communicates a [batch, seq_len, hidden_dim] tensor. With hidden dim 4096 in FP16, that is 8 KB per token per AllReduce. Over 80 layers and 1000 tokens: ~640 MB per iteration (roughly double that once both AllReduces per layer are counted).
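The bullet's arithmetic, spelled out (numbers taken from the text; this counts payload only and ignores ring-AllReduce wire overhead):

```python
hidden_dim = 4096
bytes_per_elem = 2                                 # FP16
layers = 80
tokens = 1000                                      # batch * seq_len

per_token_per_layer = hidden_dim * bytes_per_elem  # bytes moved per token per AllReduce
total = per_token_per_layer * layers * tokens      # one AllReduce per layer

print(per_token_per_layer)                         # → 8192 (the "8 KB per token")
print(total / 2**20)                               # → 625.0 MiB, the text's "~640 MB"
```

At these volumes the 900 GB/s of NVLink keeps the AllReduce under a millisecond per iteration; a 10 GbE link (~1.25 GB/s) would spend half a second on the same traffic.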
DP internals: DataParallelController receives requests from the API server and dispatches. Two policies: round-robin (simple, fair) and shortest-queue (better for variable-length prompts). Each worker runs its own scheduler, KV cache, and RadixCache independently.
8 GPUs, TP=4, DP=2:

          Request stream
                ↓
      DataParallelController
         ↓             ↓
    Replica 0       Replica 1
  [GPU 0,1,2,3]   [GPU 4,5,6,7]
    TP group 0      TP group 1
    AllReduce       AllReduce
  within group    within group
Source Code
class DataParallelController:
    def __init__(self, workers, policy="round_robin"):
        # Simplified constructor for illustration; the real one also
        # wires up ZMQ sockets to the API server and each worker.
        self.workers = workers            # one handle per DP replica
        self.dp_size = len(workers)
        self.policy = policy
        self.next_worker = 0

    def event_loop(self):
        while True:
            recv_reqs = self.recv_requests()
            for req in recv_reqs:
                worker_idx = self._select_worker(req)  # ← load balance
                self.workers[worker_idx].send(req)

    def _select_worker(self, req):
        if self.policy == "round_robin":
            idx = self.next_worker
            self.next_worker = (self.next_worker + 1) % self.dp_size
            return idx
        else:  # shortest_queue
            return min(range(self.dp_size), key=lambda i: self.workers[i].queue_size)
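The two dispatch policies can be exercised in isolation. A minimal sketch; the Worker stub and its queue_size values are hypothetical, chosen only to show shortest-queue picking the least-loaded replica while round-robin ignores load entirely:

```python
from dataclasses import dataclass

@dataclass
class Worker:                 # stub: only the attribute the selector reads
    queue_size: int

workers = [Worker(3), Worker(0), Worker(5), Worker(1)]
dp_size = len(workers)

# shortest_queue: pick the replica with the fewest queued requests
best = min(range(dp_size), key=lambda i: workers[i].queue_size)
print(best)                   # → 1 (the worker with queue_size 0)

# round_robin: cycle through replicas regardless of load
next_worker = 0
order = []
for _ in range(6):
    order.append(next_worker)
    next_worker = (next_worker + 1) % dp_size
print(order)                  # → [0, 1, 2, 3, 0, 1]
```

With uniform prompt lengths the two policies converge; shortest-queue only pays off when a few long prompts would otherwise pile up behind one replica.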
Further Reading
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism — The foundational paper on tensor parallelism for large language models.
12. Speculative Decoding
A small model guesses the next 5 tokens. The large model checks them all at once.
The Problem
Autoregressive decoding is serial: each token requires a full forward pass. For large models, this is memory-bandwidth bound — the GPU waits for 140 GB of weights to stream from HBM. You pay the full cost to generate one token.
The Idea
Why Speculate, Not Just Decode Normally?
| Approach | Tokens per Target Forward | Cost | Net Speedup | Verdict |
|---|---|---|---|---|
| Speculative (K=5, 80% accept) | ~3.7 tokens (expected) | K draft passes + 1 target pass | 2-4x | Chosen for large models |
| Standard autoregressive | 1 token | 1 target pass | 1x (baseline) | Simple, always correct |
Speculative decoding exploits the fact that a small draft model (1B params) can predict what a large target model (70B) would generate with 70-80% accuracy. Let the draft model sprint ahead and generate K candidates quickly. Feed all K to the target model in a single forward pass — this works because the target can process K tokens in parallel, like processing a prompt. Accept consecutive matches, reject the first mismatch, sample a bonus token at the rejection point. Net speedup: 2-4x.
SGLang supports EAGLE (tree-structured drafting from target hidden states — explores multiple branches simultaneously), N-gram (frequency-based, zero-cost, no extra model), and DFlash (block-structured with hidden state reuse).
How It Works
1. Draft phase. The draft model generates K candidate tokens autoregressively. Fast because the model is small. EAGLE generates a tree of candidates rather than a single chain, so if the first choice is rejected, alternatives are available.
2. Verify phase. All K draft tokens are fed to the target model in one forward pass. The target processes positions 1..K in parallel, producing K sets of logits.
3. Accept/reject. Compare each draft token against the target's top choice. Accept all consecutive matches. At the first rejection, discard it and all subsequent drafts.
4. Bonus token. At the rejection point (or after all K accepted), sample one token from the target's distribution. This guarantees every verify pass produces at least one new token — speculative decoding is never slower than standard decoding.
5. Update state. KV cache updated with accepted + bonus tokens. Draft model resets to last accepted position. Loop repeats.
The key insight: the target model's forward pass costs the same for 1 token or K tokens, because the bottleneck is loading 140 GB of weights from HBM, not the actual matrix multiplications.
Standard decoding:    [target] → t1 → [target] → t2 → [target] → t3
                        30ms           30ms           30ms
                      Total: 90ms for 3 tokens (30 ms/token)

Speculative decoding: [draft ×5] → [target verifies 5] → accept 4, bonus 1
                        5ms            30ms
                      Total: 35ms for 5 tokens (7 ms/token, ~4.3x per-token speedup)
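The diagram's numbers generalize. Under a simplified model where each draft token is independently accepted with probability p, the expected tokens per verify pass is 1 + p + p² + … + p^K (the guaranteed bonus token plus the expected accepted prefix). A sketch; the 5ms/30ms costs are the illustrative figures from the diagram, not measurements:

```python
def expected_tokens(p, k):
    # bonus token (always emitted) + expected length of the accepted prefix
    return 1 + sum(p ** i for i in range(1, k + 1))

def speedup(p, k, draft_ms, target_ms):
    per_pass_ms = draft_ms + target_ms        # K draft steps + 1 verify pass
    baseline_ms_per_token = target_ms         # standard decoding: 1 token per pass
    return expected_tokens(p, k) * baseline_ms_per_token / per_pass_ms

print(round(expected_tokens(0.8, 5), 2))      # → 3.69 tokens per verify pass
print(round(speedup(0.8, 5, 5, 30), 2))       # → 3.16x over standard decoding
```

This is why acceptance rate dominates the choice of draft model: at p=0.5 the same K=5 yields only ~1.97 tokens per pass, and the draft overhead starts eating the gain.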
Source Code
import torch

def compute_accept_reject(draft_tokens, target_logits, draft_logits=None):
    # Greedy verification: draft_logits is only needed for the
    # rejection-sampling variant, so it is unused here.
    # draft_tokens: [K]. target_logits: [K+1, vocab] — the target scores one
    # extra position so a bonus token exists even when all K drafts pass.
    target_choices = torch.argmax(target_logits, dim=-1)      # [K+1]
    accepted = (draft_tokens == target_choices[:-1]).int()    # ← accept if same token
    accept_mask = torch.cumprod(accepted, dim=-1)             # ← consecutive matches only
    num_accepted = int(accept_mask.sum())
    bonus_token = target_choices[num_accepted]                # ← always make progress
    return num_accepted, bonus_token
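The cumprod trick deserves a concrete trace. A pure-Python equivalent, with hypothetical per-position match results, shows how the first mismatch zeroes out every later position, even later "lucky" matches:

```python
from itertools import accumulate
from operator import mul

# drafts match the target at positions 0, 1, 3, 4 but not 2 (hypothetical)
accepted = [1, 1, 0, 1, 1]
accept_mask = list(accumulate(accepted, mul))   # running product = cumprod
print(accept_mask)                              # → [1, 1, 0, 0, 0]
num_accepted = sum(accept_mask)                 # → 2 drafts kept (+1 bonus token)
```

Positions 3 and 4 matched, but they were generated conditioned on the rejected token at position 2, so keeping them would break the guarantee that output is distributed exactly as the target model alone would produce.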
Speculative Decoding Simulator
Watch the draft model propose tokens and the target model verify them.
Further Reading
- Fast Inference from Transformers via Speculative Decoding — Introduced the draft-then-verify paradigm for LLM inference acceleration.
- EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty — Tree-structured speculative decoding using target model hidden states. Implemented in SGLang's speculative module.