Four Bits

This writeup on FLUX.2 [dev] inference optimization reveals exactly how we achieved a 1.6x speed improvement with no perceptible quality loss.

One of these images was generated by a model compressed to a third of its size. The other was generated by the same model in full precision. Which is which?

Image A is the full-precision model, image B is the 4-bit model

Shrinking a model’s weights to a third of their original size implies, by definition, a loss of information and some degradation in overall quality.

In this post, we explain exactly how we optimized the best image generation model on the market today with 4-bit quantization to achieve a 1.6x speed improvement with (almost) no quality impact.

The FLUX.2 [dev] model

First, a deep dive into the architecture of FLUX.2 [dev]. It is a SOTA diffusion model and follows the typical diffusion architecture with some tweaks.

We run our inference, denoising steps, and kernel optimizations in SGLang.

Here’s an overview of the entire process.

Full process of a single denoising step for FLUX.2 [dev]. We will examine this diagram in chunks throughout this writeup.

How does FLUX.2 [dev] run inference?

Image generation models start by passing the prompt through a text encoder for a single forward pass. For FLUX.2 [dev], this text encoder is a Mistral model. In this piece, we’ll skip an examination of the text encoder, and focus on the main inference loop for image generation.

The Mistral model processes the text prompt (e.g., “a blue tiger with amber eyes”), with each successive layer building a richer representation of its meaning.

FLUX.2 [dev] then extracts the hidden latents/activations from three separate layers of the text encoder and concatenates them into a 512x5120x3 embedding tensor. This captures the semantic meaning of the image we want to generate.
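
To make this concrete, here is a minimal sketch of that extraction step, assuming a Hugging Face-style encoder that returns all hidden states. The layer indices and batch handling are illustrative placeholders, not the production code:

```python
import torch

def build_text_embedding(text_encoder, input_ids, layer_indices=(8, 16, 31)):
    """Illustrative only: stack hidden states from three encoder layers."""
    with torch.no_grad():
        out = text_encoder(input_ids, output_hidden_states=True)
    # out.hidden_states holds one (batch, 512, 5120) activation tensor per layer.
    picked = [out.hidden_states[i] for i in layer_indices]
    # Stack into the (512, 5120, 3) conditioning tensor described above (per prompt).
    return torch.stack(picked, dim=-1)
```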

Simplified embedding matrix capturing semantic meaning

At its core, FLUX.2 [dev] takes in this embedding matrix, plus one more input: the noisy latent. The noisy latent is an additional matrix instantiated with random numbers. It is our starting point; it has no meaning and can be thought of as an empty canvas onto which the model will draw its output.

It then feeds both of these to two separate embedders (MLPs at heart):

  1. The X-embedder for the latent canvas

  2. The context embedder for the embedding matrix

Inputs to the embedders

The embedders output two separate hidden-state matrices, which the model feeds into a stack of transformer blocks. One state represents the canvas (what has been drawn so far), the other the prompt (what still needs to be drawn).

8 Dual Transformer Blocks

The output of these 8 dual blocks is then fed into a second, much larger stack of transformer blocks, which takes as input the concatenation of the two output states.

48 Single Transformer Blocks

The output of this second transformer block gets transformed via an MLP (linear projection) into a matrix that matches the dimensions of our latent canvas.

More precisely, the final hidden matrix is just a delta that will be added to the original random canvas (the noisy latent instantiated at the start of inference).

Final output gets projected into the latent space

Along with this delta comes a weight multiplier that emphasizes how seriously the model should take this nudge.

Take the analogy of an artist. At the beginning, the canvas is empty, and the artist can do a lot of work (tracing, outlining, filling in the background, coloring the sky) in a few passes, with a large paintbrush and heavy strokes. Subsequent passes refine the details (adding wrinkles to a human face) and are done with care, in a few light strokes.

Similarly, at the start of inference (i.e., this is the first pass and our input latent canvas is complete noise), the model multiplies this nudge by a very high value. There is very little on the canvas, and the model needs to do a lot of work. By the 8th and final pass, on the other hand, it is probably just refining very minor details and shouldn’t change much: the core image is already in place.

The inference process then takes this final answer, the updated latent canvas, and feeds it back as the input for the next pass. The number of times this process is repeated, i.e., the number of passes the model goes through, is the number of denoising steps, and is usually 8.
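
Put together, the loop is conceptually tiny. Below is a simplified paraphrase of the process, with `transformer` and `scheduler` standing in for the real pipeline components; this is a sketch of the idea, not the SGLang implementation:

```python
import torch

def denoise(transformer, scheduler, text_embedding, num_steps=8,
            latent_shape=(1, 128, 64, 64)):
    # The "empty canvas": a latent tensor of pure noise.
    latents = torch.randn(latent_shape)
    for t in scheduler.timesteps[:num_steps]:
        # The model predicts a delta (the "nudge") for the current canvas.
        delta = transformer(latents, text_embedding, timestep=t)
        # The scheduler applies the nudge, weighted by the timestep:
        # heavy strokes early, light touches late.
        latents = scheduler.step(delta, t, latents)
    return latents
```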

Each step also needs its own weight, referred to as a timestep. The earlier the denoising step, the higher the timestep. These timesteps are used to create the shift, the scale, and the gate: modulation matrices the transformer uses to manipulate its outputs (the weight multipliers).
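
As a rough sketch of how such modulation is typically wired (an adaptive layer-norm pattern; the module and parameter names here are assumptions, not FLUX.2 source code):

```python
import torch
import torch.nn as nn

class TimestepModulation(nn.Module):
    """Illustrative shift/scale/gate modulation derived from the timestep embedding."""
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)

    def forward(self, x, t_emb, sublayer):
        # One projection of the timestep embedding yields all three modulators.
        shift, scale, gate = self.proj(t_emb).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale) + shift   # shift/scale condition the input
        return x + gate * sublayer(h)            # gate weighs the block's contribution
```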

Full FLUX.2 [dev] diagram

For your convenience, here is a zoomed-in look at the two transformer blocks.

Single Transformer block
Dual Transformer block

Profiling FLUX.2 [dev] inference

Within this architecture, the most common operation during inference is a matrix multiplication. In fact, a profiling run of the model showed that it spent 67% of its time running GEMM (General Matrix Multiplication) kernels.

Profiler output of a naive implementation shows matmul as the main bottleneck

To be precise, the model spent 1.82 ms in each matmul kernel, and took, from prompt to image, 2.776s.

What if, instead, we ran these BF16 GEMM kernels in a lower precision, specifically FP4?

Why would that make inference faster?

On the memory side, memory transactions on the GPU don't care how many numbers they grab; they think in terms of bytes. Memory transactions move fixed-size chunks, so using smaller data types means we can fetch more values in the same transaction. Smaller data types increase effective bandwidth and cache residency (memory bound).

On the compute side, Tensor Cores can perform more operations with lower precision formats. By moving from FP16 to FP4, we access cores with 4x higher FLOPS (compute bound).

Whether our inference process is compute-bound or memory-bound (which depends on batch size), quantization makes inference faster.
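
A quick back-of-envelope check of the memory-side argument (the 128-byte transaction size is a typical illustrative figure, not a measured B200 value):

```python
TRANSACTION_BYTES = 128                    # illustrative fixed-size memory transaction

bf16_per_txn = TRANSACTION_BYTES // 2      # 2 bytes per BF16 value  -> 64
fp4_per_txn = TRANSACTION_BYTES * 2        # 0.5 bytes per FP4 value -> 256

print(f"BF16 values per transaction: {bf16_per_txn}")
print(f"FP4 values per transaction:  {fp4_per_txn}")
print(f"Effective bandwidth gain:    {fp4_per_txn // bf16_per_txn}x")
```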

Matrix multiplication using the tiling algorithm in FP16 vs FP4

Quantizing the model

The model’s default precision is BFloat16, but we want it to run as fast as possible, so we use the smallest precision available on a B200: NVFP4.

BFloat16 is a floating point format with 16 bits: 1 sign bit, 8 exponent bits, and 7 mantissa bits. FP4 has only 4 bits, which means it can represent just 16 distinct values. The challenge we're faced with is: how do we compress our weights into only 16 possible values without destroying the model?

  • BFloat16 range: $[\min_b, \max_b]$

  • FP4 range: $[\min_f, \max_f]$

The key insight is that we don't care about the theoretical range of BFloat16. We care about the actual distribution of our weights. Neural network weights tend to cluster around zero, and we can exploit this. If we can figure out which 16 values best represent our weight distribution, we can map each weight to its nearest representative with minimal loss.

The linear mapping

Let’s look at a simplified example to understand mapping in quantization.

If we were mapping from a domain of 4 bits to a domain of 3 bits, we’d expect our relationship to look a bit like this.

A simplified illustration showing a mapping from a 4-bit space to a 3-bit space

In our case we’re mapping from a domain of 16 bits to a domain of 4 bits, which makes it look a lot more like this:

Full range mapping from BF16 to FP4

Here each ‘bucket’, so to speak, absorbs 4,096 of the original values into a single quantized value. It is tempting to say the new distribution retains just under 0.4% of the information it originally had, although this is not exactly true.

Let us examine, mathematically, how this relationship works:

We assume a linear relationship:

$y = m(x + b)$

*Not a typo, the bias is also scaled.

Then, we can establish the boundaries:

$\max_{\text{unquantized}} = m(\max_q) + b$
$\min_{\text{unquantized}} = m(\min_q) + b$

We call $m$, the scale, $s$, and we set $z$ (the zero point) to be $-b$, and let alpha ($\alpha$) and beta ($\beta$) represent the min and max of the tensors being mapped. Thus:

$\beta = s(\max_q) - z$
$\alpha = s(\min_q) - z$

Using the definitions above, we derive the core formulas:

Quantization (True → Quantized)

$x_q = \text{round}\left(\frac{1}{s} \cdot x + z\right)$

De-quantization (Quantized → True*)

$x = s(x_q - z)$

*or the closest representation of truth we can get. This conversion is not lossless, but we will see ways to minimize the loss.

To find the specific values for ss and zz, we use the min and max bounds:

$s = \frac{\beta - \alpha}{\beta_q - \alpha_q} \qquad z = \text{round}\left( \frac{\beta\alpha_q - \alpha\beta_q}{\beta - \alpha} \right)$
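
A minimal NumPy sketch of these formulas (the quantized range here is a generic signed-integer stand-in for the real FP4 codebook):

```python
import numpy as np

def quantize(x, q_min=-8, q_max=7):
    """Affine (asymmetric) quantization: x_q = round(x / s + z)."""
    alpha, beta = x.min(), x.max()
    s = (beta - alpha) / (q_max - q_min)
    z = np.round((beta * q_min - alpha * q_max) / (beta - alpha))
    x_q = np.clip(np.round(x / s + z), q_min, q_max)
    return x_q, s, z

def dequantize(x_q, s, z):
    """De-quantization: x = s * (x_q - z)."""
    return s * (x_q - z)

x = np.random.default_rng(0).normal(size=(4, 4))
x_q, s, z = quantize(x)
print("max abs error:", np.abs(x - dequantize(x_q, s, z)).max())
```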

A working example

Let us create two matrices, A and B, of dimensions MxK and KxN, respectively. Their output, matrix C, has dimensions MxN. Using these formulas, we are able to find the scaling factor and zero point for each matrix. We then apply the quantization formulas to quantize and store matrices A and B.

How do we get matrix C?

Matrix Multiplication in MxKxN

Our goal is to approximate the output C of the original un-quantized multiplication A@B. We want the result to be as close as possible to what we would have gotten with full precision, but all we have access to are the quantized versions. So what do we do?

To recover the final unquantized product, we expand the quantized representations.

$A = s_1(A_q - z_A) \qquad B = s_2(B_q - z_B)$
$AB = s_1 s_2 \left[ A_q B_q - z_B \sum_{k} A_{q,ik} - z_A \sum_{k} B_{q,kj} + K z_A z_B \right]$

Visually, this means we have to do 3 additional operations to unbias our result back into what it should have been:

Quantizing and de-quantizing a matrix

This is the key to the unlocked performance: using fast integer arithmetic (A_quantized @ B_quantized), and then recovering from the errors using three correction terms.

By applying these corrections and then multiplying the entire integer result by the product of the scales, we can cast the matrix back to its original floating-point range with surprisingly little error.
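
The same correction, written out in NumPy so the three terms are visible (an illustration of the algebra above, reusing the affine quantizer from the earlier sketch, not the fused GPU kernel):

```python
import numpy as np

def affine_quantize(x, q_min=-8, q_max=7):
    alpha, beta = x.min(), x.max()
    s = (beta - alpha) / (q_max - q_min)
    z = np.round((beta * q_min - alpha * q_max) / (beta - alpha))
    return np.clip(np.round(x / s + z), q_min, q_max), s, z

rng = np.random.default_rng(0)
A, B = rng.normal(size=(4, 8)), rng.normal(size=(8, 5))
A_q, s1, z_a = affine_quantize(A)
B_q, s2, z_b = affine_quantize(B)
K = A.shape[1]

raw = A_q @ B_q                                        # cheap low-precision matmul
corrections = (- z_b * A_q.sum(axis=1, keepdims=True)  # -z_B * sum_k A_q[i, k]
               - z_a * B_q.sum(axis=0, keepdims=True)  # -z_A * sum_k B_q[k, j]
               + K * z_a * z_b)                        # +K * z_A * z_B
approx = s1 * s2 * (raw + corrections)

true = A @ B
print("relative error:", np.linalg.norm(approx - true) / np.linalg.norm(true))
```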

This type of 4-bit quantization is known as GGUF (GPT-Generated Unified Format) quantization. GPTQ (Generative Pre-Trained Quantization) is more nuanced and may be covered in a subsequent post.

The problem with linear mapping

This is too expensive. We’re doing too many additions and too many corrections. Previously, we had to do 2·MNK operations per matmul (each element in the output matrix requires K additions and K multiplications, and there are M×N elements in the output matrix).

Now we have to perform 5·MNK operations (on top of the 2·MNK, we apply 3 corrective additions per element), resulting in 2.5x more work.

We can solve this if we assume symmetry. Then only one corrective matrix is applied for scaling and no bias correction is needed: 3·MNK total work. To get rid of all the intermediate terms, we set the zero point to exactly 0 every time, regardless of where it ‘truly’ should be:

$z = \text{round}\left( \frac{\beta\alpha_q - \alpha\beta_q}{\beta - \alpha} \right) \;\Rightarrow\; 0$

This forces our mapping to be symmetric. Going back to our buckets example, we quantize from a 4-bit domain to a 3-bit domain. Let our original data points be -1, 0, 1, 2, 3, 4, 5.

If we choose not to bias, in order to reduce the amount of work done, we find the following:

Mapping asymmetric data

If we wanted to capture true precision, we would scale and shift our data such that the zero point is centered on the quantized zero point and the data is spread across the full range.

Biasing asymmetric data to use full range

In this case no data would be lost, as each of our data points would be mapped to a distinct quantized value: a bijective mapping that can be reversed with our scale and bias terms.

Without shifting, each bucket in the quantized range now holds two values while half of our buckets are unused; we wasted valuable quantized range and lost information. Our new system can no longer see the difference between the number 2 and the number 3.

The reason this works is that activations (almost) never have such skewed distributions, so this problem rarely arises. Even on the rare occasion it does, any loss in quality is offset by the extreme gains in performance.

*This is not always true. In the next blog, we explore in depth layers that have very skewed distributions, and the fix for those.

With the bias now always set to zero, we take as our running example a 2x3 matrix X and a 3x4 matrix W, quantize them, do the multiplication, de-quantize, and compare the final output with the true result of X@W (if we had just done a direct BF16 matmul).

A running example of symmetric quantization

Noting that, for de-quantization of the output matrix, our equations were:

$S_1 = \frac{\alpha_q}{\alpha_{\max,\,A}}, \qquad S_2 = \frac{\alpha_q}{\alpha_{\max,\,B}}, \qquad x = \frac{1}{S_1 \cdot S_2} \cdot x_q$

An attempt at doing so shows us that we get 31.2% error.
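
A reproduction of this experiment in NumPy is only a few lines (with random data the exact error will differ from our 31.2% figure, and rounding to integers stands in for snapping to the 16-value FP4 codebook):

```python
import numpy as np

Q_MAX = 6.0   # largest magnitude representable in FP4 (E2M1)

def sym_quantize(t):
    """Symmetric per-tensor quantization: one scale, zero point forced to 0."""
    s = np.abs(t).max() / Q_MAX
    return np.round(t / s), s

rng = np.random.default_rng(0)
X, W = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
X_q, s_x = sym_quantize(X)
W_q, s_w = sym_quantize(W)

approx = (s_x * s_w) * (X_q @ W_q)   # de-quantize the product with the two scales
true = X @ W
print("relative error:", np.linalg.norm(approx - true) / np.linalg.norm(true))
```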

A better approach: Blockwise Quantization

What if instead of maintaining one single scale per matrix, we divide the matrix into submatrices and compute a scaling factor per submatrix? Specifically, for our use case of NVFP4, each submatrix is a block of 16 elements.

Here's how it works: instead of multiplying the full quantized matrices directly, we tile them. We tile across the K dimension, taking block-sized (16-element) chunks at a time. For each tile, we perform the matmul on these tiles and get a partial result.

We then take the scales of each tile used for X along its blocks and the scales used for W along its blocks, and multiply them together to get a per-element scale factor. We apply this scale to the tile output. We repeat this for all tiles across the full matrix multiplication, then sum all the scaled outputs together.

Using this approach, we reduce our error rate to 6.3%.
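
Here is a NumPy sketch of that one-level blockwise scheme (the tensor shapes are made up so K is a multiple of the 16-element block, rounding to integers again stands in for the FP4 codebook, and the real kernel fuses this loop into the GEMM):

```python
import numpy as np

BLOCK, Q_MAX = 16, 6.0

def quantize_blocks(x):
    """One scale per 16-element block along the last (K) axis of a (rows, K) matrix."""
    rows, K = x.shape
    xb = x.reshape(rows, K // BLOCK, BLOCK)
    scales = np.abs(xb).max(axis=-1) / Q_MAX               # (rows, K // BLOCK)
    x_q = np.round(xb / scales[..., None]).reshape(rows, K)
    return x_q, scales

def blockwise_matmul(X, W):
    X_q, sx = quantize_blocks(X)                           # X: (M, K)
    W_qT, swT = quantize_blocks(W.T)                       # quantize W along K as well
    W_q, sw = W_qT.T, swT.T                                # (K, N), (K // BLOCK, N)
    M, K = X.shape
    out = np.zeros((M, W.shape[1]))
    for b in range(K // BLOCK):
        blk = slice(b * BLOCK, (b + 1) * BLOCK)
        partial = X_q[:, blk] @ W_q[blk, :]                # low-precision tile matmul
        out += partial * np.outer(sx[:, b], sw[b, :])      # per-element scale, accumulate
    return out

rng = np.random.default_rng(0)
X, W = rng.normal(size=(2, 64)), rng.normal(size=(64, 4))
err = np.linalg.norm(blockwise_matmul(X, W) - X @ W) / np.linalg.norm(X @ W)
print(f"relative error: {err:.2%}")
```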

Matmul with one-level scaling

The overhead problem

The problem with blockwise quantization is the overhead we incur from storing all these scales. Each scale factor is stored as a 32-bit FP number, so that's 4 bytes per scale. If we have M × N blocks, that's 4 × M × N bytes of overhead.

Blockwise Quantization: two-level scaling

A simple solution arises: what if we quantized our scaling matrices?

We take our regular scaling matrix and divide it by a global scale factor, the ratio of the maximum value in the scaling tensor over the FP8 max. Because we've divided by this global factor, the scales now fit in the FP8 range and can be stored as FP8 instead of FP32.

Matmul with two-level scaling

We quantize the matrices using these FP8 scales and tile them as before. When we need to dequantize, we multiply the FP8 scales by the global scale factor to reconstruct the original scale values. We then apply these reconstructed scales element-wise to each tile, and sum the results.

This approach performs with the exact same error percentage, but with one quarter of the storage overhead.
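
Sketching just the scale-compression step (NumPy has no FP8 dtype, so a float16 cast stands in for the 8-bit storage; 448 is the E4M3 maximum):

```python
import numpy as np

FP8_E4M3_MAX = 448.0

def compress_scales(scales):
    """Two-level scaling: 8-bit block scales plus a single FP32 global factor."""
    global_scale = scales.max() / FP8_E4M3_MAX                # one FP32 number per tensor
    fp8_scales = (scales / global_scale).astype(np.float16)   # stand-in for FP8 storage
    return fp8_scales, global_scale

def decompress_scales(fp8_scales, global_scale):
    return fp8_scales.astype(np.float32) * global_scale
```

Swapping the decompressed scales into the blockwise matmul sketch above in place of the raw FP32 scales cuts per-block scale storage from 4 bytes to 1, which is where the four-fold overhead reduction comes from.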

We could reduce overhead even further by increasing the block size. Instead of 16 elements per scale factor, we could use 32. In fact, this is exactly what the MXFP4 standard does, but NVIDIA's FP4 implementation chose to stick with 16-element blocks, trading storage efficiency for better numerical precision. Smaller blocks mean more scale factors to store, but each block can more accurately capture the local range of its values.

Performance after quantization

Armed with this understanding, we write our quantization kernels and wire a quantization config throughout our model, redirecting all linear layers to use our FP4 matmul kernels. The model’s weights are already pre-quantized, so the only overhead incurred is quantizing the activations on the fly. The technical term for this process is Post-Training Quantization.

A profiler showing a forward pass through the quantized model

This new implementation runs end-to-end inference in 2.0s, a 1.38x speedup from the non-quantized version. But can we do better?

Profiling shows that most of our time is wasted computing the global maximum, having to:

  • Go through the entire tensor to get the maximum value

  • Return said value

  • Compute the scaling factor based on said value

  • Move the scale-factor tensor to the GPU

  • Launch the GEMM kernel

This is too slow. How can we go even faster without impacting the quality of the model?

So far, we have been doing Dynamic Post Training Quantization, where we collect the distributions of the activations to calculate the zero point and scale factor needed, and do so for each activation. After all, we need the global value to perform high-quality block-wise quantization.

But what if there existed a magic number that would serve as a good approximation…

A fixed global max

The new process we’re about to follow is Static Post Training Quantization.

In contrast to dynamic quantization, static quantization does not calculate the activation statistics on the fly at inference time; the scales are fixed ahead of time.

The goal is simple: find one value that can serve as a good estimate for the global scale across all tensors.

The cleanest approach would be distribution mapping using Kullback-Leibler (KL) divergence: run a calibration dataset, tune the scaling factor to minimize the divergence/distribution drift, and use that as the scaling factor set in stone for each layer moving forward, maintaining one scale per layer.

In our case, however, we found the model so well-behaved that a simple grid-search auto-tuning algorithm was sufficient: a script generated 10 different images from 10 different high-complexity prompts, repeated across 30 different scale values. We then compared each set of images to find the scale value that best preserved quality.

We then used that parameter for all layers…and it worked.
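
In pseudocode, the grid search looks roughly like this; `generate_image`, `similarity`, and `load_prompts` are hypothetical placeholders for the real generation harness and image-comparison metric, and the candidate grid is simplified:

```python
# Simplified sketch of the auto-tuning loop over hard-coded global a_max values.
candidate_amax = [2 ** i for i in range(5, 15)]            # e.g. 32 ... 16384
prompts = load_prompts(n=10)                               # hypothetical helper
reference = {p: generate_image(p, quantized=False) for p in prompts}

best_amax, best_score = None, float("-inf")
for amax in candidate_amax:
    score = sum(
        similarity(generate_image(p, quantized=True, global_amax=amax), reference[p])
        for p in prompts
    )
    if score > best_score:
        best_amax, best_score = amax, score

print("selected global a_max:", best_amax)
```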

Here is an example:

Image comparison with different hard-coded max values

Once we settled on an auto-tuned value (in this case a_max=1024), we were able to skip directly to the GEMM kernels:

A profiler showing an optimized forward pass through the quantized model

This gets us a total latency of 1.81 seconds, a 1.6x speedup over the original implementation.

Conclusion

It was surprising to see a model trained without QAT (Quantization Aware Training) perform this well with static quantization, the weakest form of PTQ (Post-Training Quantization). In our findings, this strategy does not generalize, and FLUX.2 [dev] is unique in this regard. In the next blog, we explore other strategies when static quantization fails.

Custom kernel engineering work is an important piece of the Baseten Inference Stack, as evidenced by our recent performance gains on Wan 2.2 for video generation. If you’re interested in doing deep technical work at the frontier of inference optimization, join our engineering team.

Appendix

Appendix 1.0: Why use SGLang

SGLang Overview

We can't expect everyone to have the hardware required to serve the model themselves. So we need to establish an endpoint through which users can send requests. The model runs inference on our hardware, and we send back the result. But we also need a framework that has certain optimizations already implemented:

  • Paged KV caching

  • Radix attention

  • Batching

… to name a few. These won't be covered here, but suffice it to say these optimizations make serving faster, and we make use of libraries that already implement them.

SGLang architecture relevant to diffusion models

SGLang Architecture

In SGLang, the entry point is the launch server file, which spawns the GPU worker and then launches the HTTP server. The server initializes itself and connects to the backend scheduler, which sets up routers including the image API.

When a user sends a request to the HTTP endpoint, it gets converted to an internal request and forwarded to the scheduler. The scheduler acts as the bridge between HTTP and GPU. It runs an event loop that continuously receives requests, queues them, and dispatches them to workers. The GPU worker receives the request from the scheduler and has all the models loaded. It handles the forward pass through the actual model architecture via a composed pipeline that executes all the stages.

Pipeline stages

SGLang Pipeline Stages Visualized

In our case, the pipeline stages are:

  1. Input validation: Sanity checks on the request

  2. Text encoding: A full forward pass through the Mistral model to get our embedding matrix

  3. Latent preparation: Create the random noisy latents that serve as our starting canvas

  4. Timestep preparation: Set up the denoising scheduler. The number of timesteps equals the number of denoising steps, and each timestep defines the strength of randomness at that step

  5. Denoising: The core stage where the transformer runs for each timestep. A forward pass produces a prediction, and the latents are updated accordingly

  6. Decoding: Take the final latent output and convert it to an image

On that last stage: the latent space is a compressed representation of the image. A 1024×1024×3 image (3.1M values) gets compressed to 64×64×128 (524K values). The VAE learned to encode images into this compressed space, the DiT operates entirely within it, and then the VAE decodes back to pixels.

Appendix 2.0: A proof for accumulating in higher precision

The value of an FP4 number is calculated using the following hardware representation:

Value = Sign × 2^Exponent × 1.M

Where 1.M is the mantissa with an implicit leading 1:

  • 1.1 in binary represents 1.5 (1 + 1/2)

  • 1.11 in binary represents 1.75 (1 + 1/2 + 1/4)

So:

  • Max Value (0111): 6.0

    • Calculation: 2^(3-1) × 1.5 = 2^2 × 1.5 = 4 × 1.5 = 6

  • Min Value (1111): -6.0

When performing a matrix multiplication, we iterate through K elements to calculate the dot product for a single output value, and we need to do this for M×N values.

Matrix Multiplication

When multiplying two FP4 numbers, we need 8 bits to store the result. Why? Let's multiply the max FP4 value by itself, expressing 6 as a product of exponent and mantissa: 0111 × 0111 = (2^2 × 1.5) × (2^2 × 1.5) = 2^4 × 2.25

The decimal 2.25 is 10.01 in binary. However, floating point format requires normalized notation (1.something). We cannot store 10.01 directly; we must shift to normalize: 10.01 = 1.001 × 2^1

So we need to add 1 to whatever the exponent value would have been.

Breaking down the storage requirements:

  • Mantissa: We need 4 bits to store .001 (plus the implicit leading 1)

  • Exponent: The exponent was 2^4, but after normalization it becomes 2^5. To store 5, we need 3 bits minimum

  • Sign: We need 1 sign bit

Total: 4 (mantissa) + 3 (exponent) + 1 (sign) = 8 bits

This is a case-specific result. For larger K (hidden) dimensions, even FP16 accumulators could be used.
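
A quick Python check of the worked numbers above (verifying only the normalization step, not the full bit-budget argument):

```python
product = 6.0 * 6.0                       # max FP4 value times itself
mantissa, exponent = product / 2 ** 5, 5  # normalize: 36 = 1.125 x 2^5
assert product == mantissa * 2 ** exponent
print(f"{product} = {mantissa} x 2^{exponent}   # 1.125 = 1.001 in binary")
```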
