We pointed our automatic agent builder at TPU docs. It came back with 4x faster kernels.
April 10, 2026
Charles Hong (UC Berkeley)
Autocomp now supports Google TPU! We built the optimization agent automatically from public documentation, and used it to speed up production Pallas kernels, including FlashAttention, by up to 1.41x and vanilla JAX workloads by up to 4.37x.
Building the TPU Agent
A quick recap: Autocomp is our LLM-driven code optimization framework for tensor accelerators. It integrates domain knowledge, hardware correctness/performance feedback, and novel strategies for response diversity to automatically search for performant code.
Adding a new hardware target to Autocomp requires two things: a hardware-aware optimization agent and an evaluation backend. For previous targets (Gemmini, AWS Trainium, NVIDIA GPUs, RISC-V Vector), building the agent involved significant manual effort: copy-and-pasting documentation, writing optimization strategies by hand, and encoding hardware-specific constraints.
For TPU, we used Autocomp’s Agent Builder to generate the entire agent automatically. We pointed it at four documentation sources:
- Pallas overview — pallas_call, grid/BlockSpec API
- TPU Pallas guides — matmul, pipelining, DMA
- TPU Pallas API reference — full API surface
- TPU hardware docs — architecture, memory hierarchy
From these, the Agent Builder synthesized 22 TPU-specific optimization strategies (on top of 15 default strategies), a 3300-line ISA reference, an architecture summary, and correctness rules. This is the first Autocomp agent built from scratch with the Agent Builder; we never wrote a hand-built agent for TPU. Strategies like “mark grid dimensions as parallel” and “fuse RHS transpose into dot_general” turned out to be directly useful in the optimizations we describe below.
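To make the "fuse RHS transpose into dot_general" strategy concrete: instead of materializing Bᵀ and then multiplying, a dot_general-style contraction sums over the matching last axes of both operands directly, so no transposed copy is ever written. A minimal pure-Python sketch with illustrative names and plain lists (not the actual JAX/Pallas code):

```python
def matmul_bt_fused(a, b):
    """Compute A @ B^T by contracting the last axis of both operands,
    i.e. a dot_general with contracting dims ((1,), (1,)) -- no transposed
    copy of B is ever materialized."""
    m, k = len(a), len(a[0])
    n = len(b)
    assert len(b[0]) == k
    return [[sum(a[i][p] * b[j][p] for p in range(k)) for j in range(n)]
            for i in range(m)]

def matmul_bt_naive(a, b):
    """Reference: materialize B^T first, then do a plain matmul."""
    bt = [list(col) for col in zip(*b)]  # explicit transpose (extra memory traffic)
    k = len(bt)
    return [[sum(a[i][p] * bt[p][j] for p in range(k)) for j in range(len(b))]
            for i in range(len(a))]
```

Both produce identical results; the fused form simply reads B in its original layout.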
Benchmarks
We evaluate on two categories of workloads, all running on a TPU v6e-1 (Trillium) with JAX 0.6.2:
Category 1 — Optimizing hand-tuned Pallas kernels. Production kernels from upstream JAX, already hand-optimized by Google engineers. Specifically, we optimized the Flash Attention and Ragged Paged Attention kernels, with model shapes drawn from Llama-3.1-8B. These are hard baselines, since the starting code is already well-optimized.
Category 2 — Translating and optimizing vanilla JAX. Four workloads from JAXBench starting as unoptimized JAX code, which Autocomp first translates into Pallas and then optimizes. These include MLA Attention, RetNet Retention, Sparse MoE, and Mamba-2 SSD. Here the baseline is the original JAX implementation running through XLA, and there is significantly more headroom for optimization.
Flash Attention: Eliminating 37.5% of wasted compute
We optimized Google’s Flash Attention implementation, pulled directly from JAX’s codebase and already heavily hand-tuned. Autocomp found a 3-step optimization chain that speeds it up by 1.41x (0.371 ms → 0.264 ms):
Step 1: Unnormalized online softmax (0.371 → 0.332 ms). The baseline normalizes running softmax statistics on every K/V block iteration, dividing by the running sum of exponentials and rescaling the accumulator. Autocomp defers all normalization to a single pass after the loop, eliminating per-iteration reciprocal computations and matrix-vector multiplies.
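The idea can be sketched in plain Python for a single query row. `attn_row_normalized_each_step` and `attn_row_deferred` are illustrative stand-ins for the baseline and optimized kernels, not the actual Pallas code:

```python
import math

def attn_row_normalized_each_step(scores_blocks, values_blocks):
    """Baseline: keep the output normalized after every K/V block, so each
    iteration divides by the running sum and rescales the previous output."""
    m = float("-inf")  # running max
    l = 0.0            # running sum of exponentials
    out = 0.0          # normalized output so far
    for scores, values in zip(scores_blocks, values_blocks):
        new_m = max(m, max(scores))
        corr = math.exp(m - new_m)
        exps = [math.exp(s - new_m) for s in scores]
        new_l = l * corr + sum(exps)
        # un-normalize, rescale, accumulate, re-normalize: extra work per block
        out = (out * l * corr + sum(e * v for e, v in zip(exps, values))) / new_l
        m, l = new_m, new_l
    return out

def attn_row_deferred(scores_blocks, values_blocks):
    """Optimized: accumulate unnormalized; divide by the sum once at the end."""
    m = float("-inf")
    l = 0.0
    acc = 0.0  # unnormalized accumulator
    for scores, values in zip(scores_blocks, values_blocks):
        new_m = max(m, max(scores))
        corr = math.exp(m - new_m)
        exps = [math.exp(s - new_m) for s in scores]
        l = l * corr + sum(exps)
        acc = acc * corr + sum(e * v for e, v in zip(exps, values))
        m = new_m
    return acc / l  # single normalization after the loop
```

Both compute the same softmax-weighted sum; the deferred version drops the per-block division and rescale from the hot loop.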
Step 2: Causal wavefront microtiling (0.332 → 0.271 ms). For causal attention with Q and KV sequence lengths of 2048, the Q×K matmul computes a 4×4 grid of subtiles. The causal mask zeroes out the 6 subtiles strictly above the diagonal entirely. The baseline computes all 16 subtiles and masks afterward. Autocomp rewrites the inner loop to skip the 6 structurally zero subtiles, computing only the 10 that contribute to the output. This is an algorithmic insight, not a micro-optimization, and eliminates 37.5% of the MXU work.
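A sketch of the resulting schedule, assuming equal Q and KV tile counts as in the 2048-length case above (illustrative, not the actual kernel loop):

```python
def causal_subtile_schedule(num_q_tiles, num_kv_tiles):
    """Enumerate only the (q_tile, kv_tile) subtiles that can contribute
    under a causal mask: a KV tile strictly above the diagonal is entirely
    masked, so it is skipped before it ever reaches the MXU."""
    return [(q, kv)
            for q in range(num_q_tiles)
            for kv in range(num_kv_tiles)
            if kv <= q]  # strictly-upper subtiles are structurally zero

work = causal_subtile_schedule(4, 4)
saved = 1 - len(work) / (4 * 4)
print(len(work), saved)  # prints: 10 0.375
```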
Step 3: Head-axis coarsening (0.271 → 0.264 ms). The v6e-1 has a single TensorCore, so per-head kernel launch overhead is nontrivial. Autocomp batches 2 heads per program, reducing launch count by half.
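A toy sketch of the coarsening, with an assumed `heads_per_program` parameter (not the real kernel's grid logic):

```python
def head_grid(num_heads, heads_per_program):
    """Batch several attention heads into one kernel program so a single-core
    chip launches fewer programs; each program then loops over its heads."""
    assert num_heads % heads_per_program == 0
    num_programs = num_heads // heads_per_program
    # program i handles heads [i*hpp, (i+1)*hpp)
    return [list(range(i * heads_per_program, (i + 1) * heads_per_program))
            for i in range(num_programs)]
```

With `heads_per_program=2`, the launch count halves while every head is still covered exactly once.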
Different LLMs contributed different steps: Gemini 3 Flash planned the softmax rewrite, GPT-5.4 planned the wavefront tiling and head coarsening, and Claude Opus 4.5 wrote all three implementations. We’ve uploaded the full optimization trace and final generated kernel for you to explore.
Ragged Paged Attention: The long tail of optimization
Ragged Paged Attention (RPA) is vLLM’s decode-phase attention kernel for batched inference with a paged KV cache. Unlike Flash Attention, RPA is memory-bound, so there is no single algorithmic win to be had. Instead, Autocomp found 11 incremental optimizations over 15 search iterations, each shaving off fractions of a millisecond:
- Hoisting loop-invariant computations
- Pre-folding query scaling into the Q tensor
- Removing redundant VMEM-to-VMEM transfers
- Restructuring data layouts for contiguous access
- Enabling parallel grid dimensions

Each change is small on its own, but they compound to a 1.10x speedup (0.644 ms → 0.587 ms).
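Two of these are easy to illustrate in plain Python: hoisting a loop-invariant multiply by pre-folding the 1/√d softmax scale into Q. Function names are illustrative, not the RPA kernel's:

```python
import math

def scores_baseline(q, keys, head_dim):
    """Per-key scaling: multiply each dot product by 1/sqrt(d) in the loop."""
    scale = 1.0 / math.sqrt(head_dim)
    return [sum(qi * ki for qi, ki in zip(q, k)) * scale for k in keys]

def scores_prefolded(q, keys, head_dim):
    """Pre-fold the scale into Q once, outside the loop over keys; the dot
    products then need no per-key multiply."""
    scale = 1.0 / math.sqrt(head_dim)
    q_scaled = [qi * scale for qi in q]  # hoisted: done once per query
    return [sum(qi * ki for qi, ki in zip(q_scaled, k)) for k in keys]
```

The results are numerically identical; the win is purely fewer operations in the inner loop, which is exactly the kind of fraction-of-a-millisecond change that compounds here.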
This kind of improvement matters at serving scale: RPA runs on every decode step for every request, so even a 10% latency reduction translates directly to higher throughput and lower tail latency. See the full optimization trace and final generated kernel.
Results
Category 1 — Optimizing hand-tuned Pallas
| Kernel | Baseline | Autocomp | Speedup |
|---|---|---|---|
| flash_attention | 0.371 ms | 0.264 ms | 1.41x |
| ragged_paged_attention | 0.644 ms | 0.587 ms | 1.10x |
Category 2 — Translating and optimizing vanilla JAX
| Kernel | JAX Baseline | Autocomp | Speedup |
|---|---|---|---|
| mamba2_ssd | 1.587 ms | 0.363 ms | 4.37x |
| retnet_retention | 0.520 ms | 0.199 ms | 2.61x |
| mla_attention | 4.543 ms | 2.458 ms | 1.85x |
| sparse_moe | 8.268 ms | 6.357 ms | 1.30x |
For Category 2 workloads, Autocomp first translates vanilla JAX code into Pallas kernels and then iteratively optimizes them. The largest win is on Mamba-2 SSD (4.37x), where the translation to Pallas with explicit memory management provides a large baseline improvement, and subsequent optimizations further close the gap to hardware limits.
Conclusion
TPU is Autocomp’s 5th hardware target (after Gemmini, AWS Trainium, NVIDIA GPUs, and RISC-V Vector processors) and the first where the optimization agent was built fully autonomously by the Agent Builder from public documentation. The results show that this auto-generated agent is effective: it speeds up already-optimized production kernels and produces large gains on workloads translated from vanilla JAX.
Check out the Autocomp repo, our paper, the TPU agent configuration, and the generated kernels and traces. Feel free to reach out at charleshong@berkeley.edu if you have any questions or want help getting started.