Batch inference workloads for causal transformer models frequently process sequences that share common prefixes, like system prompts, few-shot examples, or shared queries in reranking tasks. Standard inference engines treat each sequence independently, redundantly recomputing identical MLP activations for every copy of the shared prefix.
To eliminate this redundancy, we developed RadixMLP, a technique that exploits the position-wise nature of MLPs, LayerNorms, linear projections, and embeddings. RadixMLP dynamically maps batches to a prefix trie, gathering shared segments into a compressed representation for position-wise computation and scattering results back only at attention boundaries.
In end-to-end serving benchmarks on MS MARCO v1.1 with Qwen3 models (0.6B to 8B parameters), RadixMLP achieves 1.44–1.59x speedups in realistic reranking workloads, with up to 5x speedups on synthetic benchmarks with longer shared prefixes.
RadixMLP is open-source and production-ready: you can use it today on Baseten as part of Baseten Embeddings Inference (BEI; sample config below) as well as in TEI. Reach out if you have any questions.
On synthetic benchmarks with longer shared prefixes, we see 2.74x-4.98x speedups using RadixMLP with Qwen models of different sizes.

The problem: Redundant computation in batch inference
Consider a typical batch inference workload:
Embedding models prepend the same global context to every document.
Cross-encoder rerankers pair each candidate passage with the same query.
Classification systems reuse the same few-shot examples across inputs.
Even with ragged layouts that eliminate padding, duplicate prefix tokens are still materialized as distinct rows in GPU memory. The model therefore recomputes identical activations for each copy. This inefficiency is especially pronounced during prefill, where position-wise components (MLPs, LayerNorms, and linear projections) account for a large fraction of total FLOPs.
This observation motivates RadixMLP, a technique designed to eliminate redundant prefix computation by exploiting the position-wise structure of transformer layers.
RadixMLP: Technical deep dive
While causal self-attention is a sequence-mixing operation, the MLP, LayerNorm, linear projections (Q, K, V, O), and embeddings are position-wise. For tokens with identical causal history (i.e., the same prefix path), these per-token computations are identical and can be reused.
This invariance persists across layers: if tokens share a prefix, their position-wise outputs remain the same at each layer. Since attention for prefix tokens depends only on earlier tokens within that same prefix, its outputs are likewise identical. As a result, the shared representation can be preserved layer by layer.
RadixMLP integration into causal transformers. The hidden state remains in compact space (N′ tokens) between layers. Position-wise operations (pre-attention norm, projections, MLP) run in compact space. Only the attention operation requires the original space (N tokens), with scatter-gather operations after attention.

Prefix trie construction
RadixMLP constructs a prefix trie over the batch. Each trie node is identified by its parent and the next (token_id, position_id) pair, so nodes are path-specific. No GPU-resident tree structure is required; in practice, we only materialize the resulting index maps (I_gather, I_scatter) for GPU execution.
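The trie walk and index-map construction can be sketched as follows. This is a minimal Python illustration, not the production Rust scheduler; the function name and the toy batch are ours:

```python
# Build I_gather and I_scatter for a ragged batch by walking a prefix trie
# keyed on (parent, token_id, position_id), so nodes are path-specific.
def build_index_maps(batch):
    node_to_compact = {}          # (parent, token_id, pos) -> compact slot
    i_gather, i_scatter = [], []
    flat = 0                      # position in the original ragged layout
    for seq in batch:
        parent = None             # trie root
        for pos, tok in enumerate(seq):
            key = (parent, tok, pos)
            if key not in node_to_compact:
                node_to_compact[key] = len(i_gather)
                i_gather.append(flat)     # representative occurrence
            i_scatter.append(node_to_compact[key])
            parent = key
            flat += 1
    return i_gather, i_scatter

# Two sequences sharing the 3-token prefix [1, 2, 3]:
g, s = build_index_maps([[1, 2, 3, 7], [1, 2, 3, 9]])
# g -> [0, 1, 2, 3, 7]: five compact tokens instead of eight
# s -> [0, 1, 2, 3, 0, 1, 2, 4]: duplicates read from shared compact slots
```

Only the two flat index lists need to reach the GPU; the trie itself stays on the CPU.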
For a batch of B sequences with shared prefix length P and per-sequence suffix length S, the compression ratio is:

r = N / N′ = B(P + S) / (P + B·S)

where N = B(P + S) is the total number of tokens and N′ = P + B·S is the number of compacted tokens. In batch workloads with shared prefixes, typically N′ ≪ N.
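A quick calculation makes the compression concrete (the batch and length values are illustrative, not from our benchmarks):

```python
# Compression from prefix sharing for one example batch shape.
B, P, S = 32, 512, 64           # batch size, shared prefix, unique suffix
N = B * (P + S)                 # total tokens in the ragged layout: 18432
N_compact = P + B * S           # compacted tokens after deduplication: 2560
ratio = N / N_compact           # position-wise work shrinks by this factor
print(ratio)                    # -> 7.2
```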
Gather and scatter operations
We define two index mapping tensors:
Gather indices (I_gather): For each position in the compact buffer, this specifies which position in the original input to read, selecting a representative occurrence for each compact token.
Scatter indices (I_scatter): For each position in the original layout, this specifies which position in the compact result buffer to read, broadcasting computed results back to all duplicate positions.
The forward pass for the MLP block is modified as follows:
X_unique = X[I_gather]
Y_unique = MLP(X_unique)
Y_restored = Y_unique[I_scatter]

Since the MLP has complexity O(N · d_model · d_ff), the cost is reduced to O(N′ · d_model · d_ff). As the gather/scatter operations are memory-bound copies, the net speedup approaches the compression ratio as the hidden dimension grows.
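The three lines above can be checked end-to-end with a small PyTorch sketch. The shapes, the toy MLP, and the index maps are illustrative assumptions; the point is that the compact pass reproduces the dense pass exactly:

```python
import torch

d_model, d_ff = 8, 32
mlp = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_ff), torch.nn.GELU(),
    torch.nn.Linear(d_ff, d_model),
)

X = torch.randn(8, d_model)              # ragged batch, N = 8 tokens
X[4:7] = X[0:3]                          # rows 4-6 duplicate a shared prefix
I_gather = torch.tensor([0, 1, 2, 3, 7])           # N' = 5 representatives
I_scatter = torch.tensor([0, 1, 2, 3, 0, 1, 2, 4])

Y_unique = mlp(X[I_gather])              # MLP runs on 5 rows, not 8
Y_restored = Y_unique[I_scatter]         # broadcast back to all 8 rows

assert torch.allclose(Y_restored, mlp(X))   # matches the dense pass
```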
Integration with attention
A critical requirement is causal consistency. While position-wise layers are independent per token, the attention mechanism requires the full context window.
We place scatter and gather operations at the boundary between position-wise and sequence-mixing computations:
Compute all position-wise operations in compact space: pre-attention LayerNorm, RoPE, Embedding Layer, Q/K norms, and the Q, K, V projections on N′ tokens
Scatter the resulting Q/K/V tensors back to the full N-token ragged layout
Run the attention operation (FlashAttention) in full space
Gather the attention output back to a compact space for the O projection, post-attention LayerNorm, MLP, and residual additions
Position indices are preserved alongside compact tokens: we gather the position ID tensor using I_gather, so each compact token retains its original sequence position for RoPE computation.
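Putting the steps above together, one transformer layer under RadixMLP looks roughly like this. This is a schematic sketch, not runnable code: `qkv_proj`, `rope`, `attention`, `o_proj`, `norm1`, `norm2`, and `mlp` are stand-ins for the real layer components.

```python
def layer_forward(h_compact, I_gather, I_scatter, pos_compact):
    # 1) Position-wise ops in compact space (N' tokens)
    q, k, v = qkv_proj(norm1(h_compact))
    q, k = rope(q, k, pos_compact)        # positions travel with compact tokens
    # 2) Scatter Q/K/V to full space (N tokens) for sequence mixing
    attn_out = attention(q[I_scatter], k[I_scatter], v[I_scatter])
    # 3) Gather back to compact space for the remaining position-wise ops
    h_compact = h_compact + o_proj(attn_out[I_gather])
    return h_compact + mlp(norm2(h_compact))
```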
Key innovation: Stateless design
Unlike PagedAttention or RadixAttention, which require persistent KV caches with block tables, eviction policies, and distributed coordination, RadixMLP operates entirely within a single forward pass.
Beyond that, it is the first cache-like operator usable in training as well as inference, deduplicating tokens based on common prefixes within a single forward pass. This makes it valuable for batch inference, where maintaining caches for millions of heterogeneous documents is impractical. It enables libraries like Baseten Embeddings Inference (BEI) to achieve cache-like speedups, even when inference times are in the 5-millisecond range.
CPU-side scheduling
Constructing the trie and generating index maps on GPUs would introduce synchronization overheads. Instead, we utilize a pipelined scheduler:
While the GPU executes batch t, the CPU scheduler analyzes the request queue for batch t+1
The scheduler constructs the prefix trie and pre-calculates I_gather and I_scatter
With an efficient Rust implementation, building the index maps for approximately 16,384 tokens takes between 129μs and 750μs on a single Intel Xeon CPU thread (depending on the prefix ratio). This is roughly three orders of magnitude below the corresponding GPU inference time, making even synchronous usage feasible.
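The overlap pattern can be sketched with a one-worker thread pool. This is an illustrative Python skeleton of the scheduling idea, not our Rust scheduler; `build_index_maps` and `run_gpu` are stand-ins supplied by the caller:

```python
from concurrent.futures import ThreadPoolExecutor

def serve(batches, build_index_maps, run_gpu):
    """Prepare index maps for batch t+1 on the CPU while the GPU runs batch t."""
    with ThreadPoolExecutor(max_workers=1) as cpu:
        pending = cpu.submit(build_index_maps, batches[0])
        for t in range(len(batches)):
            maps = pending.result()          # ready by the time the GPU needs it
            if t + 1 < len(batches):
                pending = cpu.submit(build_index_maps, batches[t + 1])
            run_gpu(batches[t], maps)        # overlaps with CPU work on t+1
```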
Results
MS MARCO v1.1 benchmarks
We evaluated a real serving pipeline by benchmarking TEI on MS MARCO v1.1 validation query–passage pairs. Each request embeds multiple query–document pairs formatted with a reranker-style Qwen3 chat template.
These gains are smaller than synthetic forward-pass speedups because prefix sharing is less extreme in this workload and because end-to-end latency includes non-model overheads. Still, the improvement remains substantial in a practical serving setting.
Synthetic benchmarks
For controlled evaluation, we used a synthetic fixed batch size of B=32 sequences, varying the prefix length (shared across all sequences) and suffix length (unique per sequence).
We define the compact-token ratio ρ = N′/N and the corresponding compression ratio r = N/N′ = 1/ρ. The results demonstrate three key trends:
Prefix length: Longer shared prefixes yield lower ρ (higher r) and greater speedups
Model size: Larger models benefit more from RadixMLP, as MLP compute dominates over gather/scatter overhead
Suffix ratio: Shorter suffixes (higher prefix:suffix ratio) maximize compression benefits
Comparison to vLLM
We compared TEI with and without RadixMLP against vLLM v0.13.0 on two MS MARCO configurations:
TEI without RadixMLP is the worst-performing inference system. With RadixMLP enabled, TEI outperforms vLLM for the Qwen3-0.6B model. For larger models, vLLM outperforms TEI by a 3-7% margin.
A key difference is memory usage: TEI uses ~17GB vs ~79GB for vLLM with Qwen3-8B, as no additional memory is allocated for a paged KV cache.
Implementation: Deploying RadixMLP with Truss
Here's how to deploy a Qwen3 model with RadixMLP using Truss:
# Qwen3 with RadixMLP

model_name: BEI-qwen3-radixmlp
python_version: py39

model_metadata:
  example_model_input:
    encoding_format: float
    input:
      type: string
    model: model

resources:
  accelerator: H100
  cpu: 1
  memory: 10Gi
  use_gpu: true

trt_llm:
  build:
    base_model: encoder_bert
    checkpoint_repository:
      repo: michaelfeil/Qwen3-Embedding-0.6B-auto
      revision: main
      source: HF
    max_num_tokens: 32768
  runtime:
    webserver_default_route: /v1/embeddings

Deploy with:

truss push --publish

When to use RadixMLP
RadixMLP excels in scenarios with high prefix redundancy in prefill-heavy workloads. Common examples include:
Reranking: Query-passage pairs with shared system prompts
Embeddings: Global context prepended to every document
Post-Training: Shared system templates in RLHF
Classification: Few-shot examples reused across inputs
Instruction following: Shared instruction templates
Related work
RadixMLP operates orthogonally to several existing optimization techniques that are utilized by Baseten today:
PagedAttention: Introduces paged, block-based KV management with block-level sharing. These stateful approaches require complex scheduling and are limited by block granularity, whereas RadixMLP offers block-free, stateless reuse within individual batches.
RadixAttention: Organizes prefixes in a radix tree for cross-request KV reuse. Requires persistent state management, while RadixMLP operates within a single forward pass.
FlashInfer Cascade Inference: Computes attention over shared prefix segments once, then merges partial attention states with per-request suffix attention. RadixMLP complements this by optimizing position-wise components. This is similar to `HydraGen` kernels.
These approaches are fully complementary—RadixMLP can stack with KV cache, attention optimizations, or batching/scheduling optimizations.
Training compatibility
While the focus of this work is inference, RadixMLP is compatible with training because the compaction operators are differentiable. The compacting gather is an index_select, and its backward pass is a scatter-add into the original layout (index_add), accumulating gradients when multiple original tokens map to the same compact representative.
We validated both forward and backward correctness on a 2-layer Qwen3-style model across synthetic prefix-sharing patterns. With PyTorch SDPA attention as the reference implementation, enabling RadixMLP produces numerically identical logits up to the specified tolerances (rtol=1e-4, atol=1e-4), with gradient differences of the same order or lower.
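The gradient-accumulation behavior is easy to verify in isolation. The sketch below uses toy index maps of our own choosing to show that the backward pass of the compacting gather sums gradients from all duplicate positions onto the shared representative:

```python
import torch

X = torch.randn(4, 3, requires_grad=True)
I_gather = torch.tensor([0, 1, 3])           # compact representatives
I_scatter = torch.tensor([0, 1, 0, 2])       # positions 0 and 2 share slot 0

Y = X[I_gather][I_scatter]                   # compact pass, broadcast back
Y.sum().backward()                           # autograd turns gather into index_add

# Row 0 accumulates gradient from both of its duplicate positions.
assert torch.allclose(X.grad[0], 2 * torch.ones(3))
# A non-representative row receives no gradient through this path.
assert torch.allclose(X.grad[2], torch.zeros(3))
```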
RadixMLP is open-source
RadixMLP and Index-Select Kernels, as well as the experimental results, are open-source under the MIT License. You can find the GitHub repository here.
You can read the full paper on arXiv here. Reach out to talk to our engineers about optimizing model performance for your workloads.