A 70 GB model on a 48 GB MacBook

Running Qwen3.5-35B-A3B on Apple Silicon with 3-bit weight compression. What fit, what broke, and why 1 tok/s still mattered.

Last night I generated coherent text from a 35B mixture-of-experts model on my MacBook Pro.

The checkpoint is about 70 GB. The machine has 48 GB of unified memory.

It ran. The output made sense. The speed was about 1 token per second.

The machine is an M4 Pro MacBook Pro. Nothing exotic.

That is slow enough to disappoint you immediately. It is also enough to prove something useful.

This post is about what that result actually means: the compression math that made the model fit, the memory wall that showed up next, the design choices that mattered, and why a system can still be worth building even when the first number looks bad.

The real problem

New open models now arrive faster than their quantized variants.

A model drops. Then comes the familiar delay. Someone needs to run GPTQ or AWQ, pick calibration data, validate the result, upload the checkpoint, and make sure it did not get mangled on the way. Usually that means a lag of a couple of days before the model becomes easy to serve on constrained hardware.

That delay matters.

If the goal is to evaluate or serve a model the same day it appears, waiting for a separate quantization pipeline is friction you do not need. The more interesting path is to load the raw BF16 checkpoint and compress weights at startup.

That is what TurboQuant does in vLLM. It applies HIGGS-style compression at load time, using a Walsh-Hadamard rotation and a small codebook, with no calibration data and no pre-quantized checkpoint.

On CUDA this gives about 4x memory reduction with limited quality loss.

The question I wanted to answer was simpler: can the same idea make a model fit on a Mac that otherwise has no chance of loading?

Porting the path to Apple Silicon

The GPU implementation uses Triton. That is not available on Apple Silicon, so I ported the core path to MLX.

That part went better than expected.

MLX already has mx.hadamard_transform, which covers the key rotation. The rest is familiar array work: codebook lookup, 3-bit packing and unpacking, shape-gain normalization.
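To make the packing step concrete, here is a hypothetical pure-Python sketch of 3-bit index packing and unpacking. The real path works on MLX uint8 arrays; the helper names and the 8-indices-per-3-bytes grouping here are illustrative, not the actual implementation.

```python
# Pack groups of 8 three-bit indices (24 bits) into 3 bytes each.
def pack_3bit(indices):
    # indices: list of ints in [0, 7]; length must be a multiple of 8.
    out = bytearray()
    for i in range(0, len(indices), 8):
        bits = 0
        for j, idx in enumerate(indices[i:i + 8]):
            bits |= (idx & 0b111) << (3 * j)
        out += bits.to_bytes(3, "little")
    return bytes(out)

def unpack_3bit(packed, n):
    # Inverse of pack_3bit: recover n indices from the packed bytes.
    out = []
    for i in range(0, len(packed), 3):
        bits = int.from_bytes(packed[i:i + 3], "little")
        for j in range(8):
            out.append((bits >> (3 * j)) & 0b111)
    return out[:n]

idx = [0, 7, 3, 5, 1, 6, 2, 4]
assert unpack_3bit(pack_3bit(idx), len(idx)) == idx
assert len(pack_3bit(idx)) == 3   # 8 indices -> 3 bytes at 3 bits each
```

The round trip is lossless, and the 3/8 byte-per-index ratio is what the later memory arithmetic rests on.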

The useful trick was this: do not inverse-rotate every weight row separately. Rotate the input once instead.

Because the Walsh-Hadamard transform is self-inverse, rotating the activation gives the same result as inverse-rotating each weight row. One transform replaces a large number of smaller ones.

That change matters.
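The identity behind the trick is easy to check numerically. A minimal pure-Python sketch with the orthonormal 4-point transform (the `hadamard4` helper here is illustrative, not MLX API):

```python
# Why rotating the activation once is equivalent to inverse-rotating
# every stored weight row: the orthonormal Walsh-Hadamard matrix H is
# self-inverse, so (H w) . (H x) == w . x.

def hadamard4(v):
    # Orthonormal 4-point Walsh-Hadamard transform.
    a, b, c, d = v
    s = 0.5  # 1/sqrt(4) keeps H orthonormal, so H @ H = I
    return [s * (a + b + c + d), s * (a - b + c - d),
            s * (a + b - c - d), s * (a - b - c + d)]

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

x = [1.0, 2.0, -3.0, 0.5]   # activation
w = [0.2, -1.0, 0.7, 4.0]   # one weight row

# Self-inverse: applying the transform twice recovers the input.
assert all(abs(a - b) < 1e-12 for a, b in zip(hadamard4(hadamard4(x)), x))

# Storing the rotated row and rotating the activation once gives the
# same result as using the raw row.
assert abs(dot(hadamard4(w), hadamard4(x)) - dot(w, x)) < 1e-12
```

With thousands of quantized weight rows per layer, this is one transform on the activation instead of one per row.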

On a dense 0.5B model, Qwen2.5-0.5B, this path gives about 26 tok/s on an M4 Pro.
On a small MoE model, IBM Granite 1B with 40 experts, it reaches about 84 tok/s.

So the basic machinery works. The trouble starts when the model gets large and sparse at the same time.

Where the memory went

Qwen3.5-35B-A3B has 256 experts per layer, with top-8 routing. Each expert carries gate, up, and down projections. That means 768 weight matrices per MoE layer, across 40 layers.

My first implementation unpacked all 3-bit indices to int32 at model load.

That was a mistake.

The arithmetic was brutal. A 3-bit index occupies 3 bits when packed, but 32 bits once expanded to int32: more than a 10x inflation of index storage, applied up front to every expert in every layer.

That was the wrong shape of optimization for a 48 GB machine.

The model did not need all 256 experts unpacked. It only needed the 8 active ones for the current token.
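A rough, illustrative version of that arithmetic, assuming about 35B quantized weights and glossing over the exact dense/expert split:

```python
# Back-of-envelope only: one 3-bit index per weight when packed,
# one 4-byte int32 per weight when eagerly unpacked.
GiB = 1024 ** 3
n_weights = 35e9

packed_3bit = n_weights * 3 / 8 / GiB      # ~12 GiB of packed bytes
unpacked_int32 = n_weights * 4 / GiB       # ~130 GiB of int32 indices

print(f"packed 3-bit:   {packed_3bit:6.1f} GiB")
print(f"unpacked int32: {unpacked_int32:6.1f} GiB  (vs 48 GiB of RAM)")
```

Whatever the exact split, eager unpacking lands an order of magnitude past the machine's memory, while the packed form fits with room to spare.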

The change that made it fit

The fix was to unpack only the active experts.

Instead of materializing full int32 index tensors for every expert in advance, the runtime gathers the packed bytes for the selected experts, unpacks those on demand, looks up the codebook values, applies the norms, and runs the matmul.

```python
ids_flat = indices.reshape(-1)                       # routed expert ids for this token
active_packed = mx.take(self._packed_per_expert, ids_flat, axis=0)  # gather packed bytes
active_unpacked = unpack_indices_3bit_mlx(active_packed, ...)       # 3-bit -> int indices
w = self.quant_state.centroids[active_unpacked] * active_norms      # codebook lookup + norms
out = mx.einsum("ki,koi->ko", x_rotated, w)          # matmul against the rotated input
```

That one change moved the model from impossible to runnable.
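To make the gather-only-the-active-experts idea concrete, here is an illustrative pure-Python stand-in. All names and toy shapes are made up for the sketch; the real path additionally keeps the indices 3-bit packed until this point and applies the per-row norms.

```python
# Only the routed experts' weights are ever decoded; nothing is
# materialized for the other experts.
import random

random.seed(0)
E, D_IN, D_OUT = 16, 4, 4                  # toy: 16 experts, 4x4 projections
centroids = [random.gauss(0, 1) for _ in range(8)]   # 3-bit codebook

# Per-expert weight matrices stored as codebook indices.
indices_per_expert = [
    [[random.randrange(8) for _ in range(D_IN)] for _ in range(D_OUT)]
    for _ in range(E)
]

def expert_matmul(x_rotated, expert_ids):
    # Gather + decode only the active experts (the key memory saving).
    out = []
    for e in expert_ids:
        rows = []
        for row in indices_per_expert[e]:
            w_row = [centroids[i] for i in row]          # codebook lookup
            rows.append(sum(wi * xi for wi, xi in zip(w_row, x_rotated)))
        out.append(rows)
    return out                                           # shape (k, D_OUT)

x = [1.0, -0.5, 0.25, 2.0]
out = expert_matmul(x, [3, 7])             # top-2 of 16 experts routed
assert len(out) == 2 and len(out[0]) == D_OUT
```

With top-8 routing over 256 experts, this means decoding 8/256 of the expert weights per token instead of all of them.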

After it, the total resident model state was about 19 GB.

That fits comfortably enough on a 48 GB MacBook Pro to avoid swap and produce coherent output.

The speed number

This is the part where many writeups become selective.

The generation speed was about 1 token per second.

That number was worse than my early projection.

From a micro-benchmark on a single MoE projection, I estimated much higher throughput. The active expert gather, unpack, and matmul path looked cheap in isolation. But the full forward pass has many other costs around it, and those costs add up quickly.

The rough decode-time picture told a different story from the micro-benchmark.

The early benchmark had focused on the MoE path. It did not include the hundreds of dense linear operations around it. Those turned out to matter at least as much as the sparse part.

I also tried mx.compile, hoping MLX would fuse enough of the small Metal kernels to change the picture.

It helped, but only slightly. About 1.1x.

That tells you something important. The bottleneck is not mainly launch overhead. The kernels are doing substantial work, and too much of that work is separated into memory-heavy stages.

In the MoE path alone, the runtime still walks through multiple discrete operations:

take -> unpack -> centroid lookup -> norm scale -> reshape -> einsum

Each stage moves data through memory again. On a model this size, that is where the time goes.

Why I kept the code

1. The model fits at all

That sounds trivial until you remember the starting point.

The original checkpoint cannot load on this machine. The compressed path produces about 19 GB of working state and generates valid output.

That is a meaningful boundary crossing. A 3.7x effective reduction turns "cannot run" into "can run."
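The quoted figure checks out against the round numbers in the post:

```python
# Effective reduction: checkpoint size over resident working state.
checkpoint_gb, resident_gb = 70, 19
reduction = checkpoint_gb / resident_gb
print(f"{reduction:.1f}x")   # 3.7x
```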

2. The small-model path is already useful

The same implementation runs Granite-1B at 84 tok/s and Qwen2.5-0.5B at 26 tok/s.

So this is not a dead-end prototype. It is a path that scales poorly at 30B+, but works well below that range.

That distinction matters. It tells you where the architecture is sound and where the backend still needs real systems work.

3. It validates the algorithm outside CUDA

Every successful forward pass on MLX exercises the same core compression logic as the CUDA path: packing, lookup, scaling, routing, reconstruction.

When you are building Triton kernels and a model suddenly starts producing garbage, having a second implementation that produces coherent text is extremely useful. Even at 1 tok/s.

That kind of reference system saves time.

What comes next

The path to interactive speed on large MoE models is not mysterious.

The MLX implementation needs a fused Metal kernel that reads packed uint8 weights, performs codebook lookup, applies norms, and emits the matmul result without bouncing through intermediate tensors.

That is the same architectural shape the CUDA path wants as well. Fewer memory round-trips. More fusion. Less staging between small kernels.

There is a nice side note here too. I originally came at this from a broader "TurboQuant" direction. During implementation I kept simplifying the machinery, and each simplification pushed the result closer to the scalar form of HIGGS. Dan Alistarh pointed that out during review of vLLM PR #39970. The CLI name stayed --quantization turboquant for compatibility, but the underlying path is HIGGS.

So the split is fairly clear now.

On Mac, this remains a research and validation path, plus a genuinely useful inference engine for smaller models.

On CUDA, this is where the production value compounds.

The pace of open-weight releases is accelerating. Every week brings another checkpoint that somebody needs to evaluate, serve, or reject. The tools that matter are the ones that compress it the moment it arrives: no calibration data, no separate pipeline, no delay. Whether that compression runs at 1 tok/s or 100, the ability to load and test the model immediately is what removes the bottleneck. Speed is engineering. Access is architecture.


The code is here: TurboQuant
The upstream PR is here: vLLM PR #39970
