# Prefill Benchmark

> Compare PyTorch Standard, Naive Triton, and Flash Attention implementations during the prefill phase, including crossover analysis and kernel launch overhead.

The prefill benchmark measures how fast each attention implementation processes the full input prompt before any tokens are generated. Run it with:

```bash theme={null}
uv run python benchmark_prefilling.py
```

## What prefill does

During prefill, every token in the input sequence attends to every previous token. For a sequence of length `N`, this produces an `N × N` attention matrix. The cost of this step dominates short-context latency, and the memory required to hold that matrix determines which implementations can scale.
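To make the quadratic growth concrete, the sketch below computes the size of one `N × N` score matrix. The helper name `attention_matrix_bytes` and the float32 element size are illustrative assumptions, not part of the benchmark script.

```python
def attention_matrix_bytes(seq_len: int, num_heads: int = 1, bytes_per_element: int = 4) -> int:
    """Bytes needed to materialize one seq_len x seq_len score matrix per head."""
    return seq_len * seq_len * num_heads * bytes_per_element


for n in (64, 1024, 4096):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"seq_len={n:>5}: {mib:10.3f} MiB per head")
```

Quadrupling the sequence length from 1024 to 4096 multiplies the matrix size by 16, which is the scaling the rest of this page is concerned with.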

## Implementations compared

<CardGroup cols={3}>
  <Card title="PyTorch Standard" icon="square-1">
    O(N²) memory. Materializes the full attention matrix in GPU global memory using standard `torch.matmul` and `torch.softmax`. Works at any sequence length but memory usage grows quadratically.
  </Card>

  <Card title="Naive Triton" icon="square-2">
    O(N²) memory. A custom Triton kernel that loads the entire Q, K, V sequence into shared memory and computes the full attention matrix there. Limited to sequences no longer than `BLOCK_SIZE` (at most 128 tokens) by the GPU shared memory budget.
  </Card>

  <Card title="Flash Attention" icon="square-3">
    O(N) memory. Tiled computation with online softmax: accumulates the output block by block without ever materializing the full N×N matrix. Memory footprint is proportional to sequence length, not sequence length squared.
  </Card>
</CardGroup>

## Function signatures

<CodeGroup>
  ```python pytorch_standard_attention theme={null}
  def pytorch_standard_attention(
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
      cu_seqlens: torch.Tensor,
      scale: float,
      num_heads: int,
      num_kv_heads: int,
      head_dim: int,
  ) -> torch.Tensor:
      """Standard PyTorch attention - O(N²) memory"""
  ```

  ```python flash_attention theme={null}
  def flash_attention(
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
      cu_seqlens: torch.Tensor,
      scale: float,
      num_heads: int,
      num_kv_heads: int,
      head_dim: int,
  ) -> torch.Tensor:
      """Flash Attention - online softmax optimization"""
  ```
</CodeGroup>

Both functions receive packed token tensors of shape `(total_tokens, num_heads, head_dim)` and a `cu_seqlens` offsets array that marks where each sequence starts and ends in the packed layout.
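The packed layout can be illustrated with plain Python lists. This sketch assumes the common convention that `cu_seqlens` starts at 0 and has `num_seqs + 1` entries, so sequence `i` occupies rows `cu_seqlens[i]:cu_seqlens[i + 1]`. The helper `build_cu_seqlens` is hypothetical and the script may construct the offsets differently.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative offsets: sequence i occupies rows cu[i]:cu[i + 1] of the packed tensor."""
    return [0, *accumulate(seq_lens)]


seq_lens = [60, 60]                      # two 60-token sequences
cu_seqlens = build_cu_seqlens(seq_lens)  # [0, 60, 120]

# Stand-in for the packed token dimension (real tensors also carry heads and head_dim).
packed_tokens = list(range(sum(seq_lens)))
for i in range(len(seq_lens)):
    start, end = cu_seqlens[i], cu_seqlens[i + 1]
    tokens = packed_tokens[start:end]
    print(f"sequence {i}: rows {start}..{end - 1} ({len(tokens)} tokens)")
```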

## Benchmark configurations

The script runs four configurations to expose different regimes:

| `num_seqs` | `seq_len` | Total tokens | What it tests                                                      |
| ---------- | --------- | ------------ | ------------------------------------------------------------------ |
| 2          | 60        | 120          | Short sequences — kernel launch overhead dominates                 |
| 4          | 64        | 256          | Small batch at Naive Triton's shared memory limit                  |
| 2          | 1024      | 2048         | Medium sequences — Naive Triton is skipped                         |
| 1          | 4096      | 4096         | Long sequences — Flash Attention's efficiency advantage is largest |

Each configuration runs a warmup pass before timing, and the timing loop uses `torch.cuda.synchronize()` to get accurate wall-clock measurements.
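A warmup-then-time loop of this kind might look like the sketch below. The helper `time_ms` and its `sync` parameter are assumptions made for illustration, and the script's own timing code may differ. The demonstration runs on the CPU with a no-op `sync`.

```python
import time


def time_ms(fn, sync=lambda: None, warmup=3, iters=10):
    """Average wall-clock milliseconds per call of fn().

    sync is called before starting and before stopping the clock so that
    asynchronous work (for example queued GPU kernels) is included.
    """
    for _ in range(warmup):  # warmup: trigger compilation and caching
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    elapsed = time.perf_counter() - start
    return elapsed * 1000.0 / iters


# CPU-only demonstration with a trivial workload
print(f"{time_ms(lambda: sum(range(10_000))):.4f} ms per call")
```

On a GPU, passing `sync=torch.cuda.synchronize` makes the measurement wait for queued kernels to finish instead of timing only the launch calls.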

## Shared memory constraint for Naive Triton

The Naive Triton kernel stores the entire `BLOCK_SIZE × BLOCK_SIZE` attention matrix in GPU shared memory. The memory cost in bytes is:

```
attention_matrix_bytes = BLOCK_SIZE² × 4  (float32)
```

With a GPU shared memory budget of roughly 48 KB per block:

| `BLOCK_SIZE` | Attention matrix          | Status                                   |
| ------------ | ------------------------- | ---------------------------------------- |
| 64           | 64 × 64 × 4 = **16 KB**   | Safe                                     |
| 128          | 128 × 128 × 4 = **64 KB** | Exceeds limit — results may be incorrect |

When `head_dim > 64`, the kernel selects `BLOCK_SIZE = 64`. Sequences longer than `BLOCK_SIZE` are skipped at runtime.

<Warning>
  The Naive Triton kernel is automatically skipped for any configuration where `seq_len > BLOCK_SIZE`. You will see a "SKIPPED" message for the `seq_len=1024` and `seq_len=4096` runs.
</Warning>
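The arithmetic in the table above can be checked with a few lines of Python. The constant and function names here are hypothetical, and the check counts only the attention matrix, not the Q, K, and V tiles that also occupy shared memory.

```python
SHARED_MEMORY_BUDGET_BYTES = 48 * 1024  # the roughly 48 KB per-block figure quoted above


def attention_tile_bytes(block_size: int, bytes_per_element: int = 4) -> int:
    """Size of a block_size x block_size float32 attention matrix."""
    return block_size * block_size * bytes_per_element


def fits_in_shared_memory(block_size: int, budget: int = SHARED_MEMORY_BUDGET_BYTES) -> bool:
    return attention_tile_bytes(block_size) <= budget


for block_size in (64, 128):
    kib = attention_tile_bytes(block_size) // 1024
    status = "fits" if fits_in_shared_memory(block_size) else "exceeds budget"
    print(f"BLOCK_SIZE={block_size}: {kib} KB -> {status}")
```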

## Crossover analysis

The script includes a `find_crossover_point()` function that sweeps `seq_len` from 16 to 1024 (with `num_seqs=2`, `num_heads=32`, `num_kv_heads=8`, `head_dim=128`) to find exactly where Flash Attention overtakes Naive Triton:

```
Seq Len  |  Naive (ms)  |  Flash (ms)  |  Winner  |  Speedup
-----------------------------------------------------------------
     16  |       0.xxx  |       0.xxx  |   Naive  |  x.xxX
     32  |       0.xxx  |       0.xxx  |   Naive  |  x.xxX
     64  |       0.xxx  |       0.xxx  |   Flash  |  x.xxX   <-- crossover
    128  |         OOM  |       0.xxx  |   Flash  |  N/A
    ...
   1024  |         OOM  |       0.xxx  |   Flash  |  N/A
```

At short sequences Naive Triton wins because it launches fewer kernels and does less work per kernel. Once the sequence is long enough that the O(N²) matrix cost compounds, Flash Attention's tiled approach wins — and beyond the shared memory limit, Flash Attention is the only option.
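The crossover itself is just the first row of the sweep where Flash Attention is faster or Naive Triton cannot run. The sketch below is not the script's `find_crossover_point()`. The function `first_flash_win` is a hypothetical helper, and the timings are made-up placeholders rather than measurements.

```python
def first_flash_win(rows):
    """Return the smallest seq_len where Flash is strictly faster or Naive could not run.

    rows: iterable of (seq_len, naive_ms, flash_ms); naive_ms is None when the
    naive kernel did not produce a timing (for example OOM or skipped).
    """
    for seq_len, naive_ms, flash_ms in sorted(rows, key=lambda r: r[0]):
        if naive_ms is None or flash_ms < naive_ms:
            return seq_len
    return None


# Made-up placeholder timings, NOT measured results.
example_rows = [
    (16, 0.10, 0.20),
    (32, 0.15, 0.22),
    (64, 0.40, 0.30),
    (128, None, 0.45),
]
print(first_flash_win(example_rows))
```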

## Kernel launch analysis

Flash Attention's grid has three dimensions; Naive Triton's grid has two:

```python theme={null}
from math import ceil

# Example values: 2 sequences of 60 tokens, 32 heads, BLOCK_M=32
num_seqs, seq_len, num_heads, BLOCK_M = 2, 60, 32, 32

# Naive Triton grid: one thread block per (sequence, head)
naive_grid = (num_seqs, num_heads)
naive_kernels = num_seqs * num_heads  # 64

# Flash Attention grid: one thread block per (query tile, head, sequence)
num_blocks_m = ceil(seq_len / BLOCK_M)  # 2
flash_grid = (num_blocks_m, num_heads, num_seqs)
flash_kernels = num_blocks_m * num_heads * num_seqs  # 128
```

For 2 sequences of 60 tokens each, with `BLOCK_M=32` and 32 heads:

| Implementation  | Grid         | Total kernels |
| --------------- | ------------ | ------------- |
| Naive Triton    | `(2, 32)`    | 64            |
| Flash Attention | `(2, 32, 2)` | 128           |

Each kernel launch carries \~5–20 µs of fixed overhead. For 64 extra launches that is roughly 0.32–1.28 ms of extra latency, which is why Naive Triton can be faster at short sequences despite doing the same mathematical work.
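The estimate is a multiplication of the extra launch count by the per-launch cost. The sketch below takes this page's cost model at face value, counting every grid entry as one launch. The function name and the default 5 to 20 µs range are assumptions lifted from the paragraph above, not values read from the script.

```python
from math import ceil


def extra_launch_overhead_ms(num_seqs, seq_len, num_heads, block_m,
                             per_launch_us=(5.0, 20.0)):
    """(low, high) extra latency in ms from Flash Attention's additional grid entries."""
    naive = num_seqs * num_heads
    flash = ceil(seq_len / block_m) * num_heads * num_seqs
    extra = flash - naive
    low_us, high_us = per_launch_us
    return extra * low_us / 1000.0, extra * high_us / 1000.0


low, high = extra_launch_overhead_ms(num_seqs=2, seq_len=60, num_heads=32, block_m=32)
print(f"extra overhead: {low:.2f} to {high:.2f} ms")
```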

<Note>
  This overhead becomes negligible at longer sequences where the compute time per kernel dwarfs the launch cost. At `seq_len=1024`, Flash Attention processes each tile independently with an O(N) memory footprint, which is why it wins decisively.
</Note>

## Why Flash Attention wins at long sequences

<Steps>
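  <Step title="O(N) memory footprint">
    Flash Attention never writes a full N×N matrix to HBM. Each tile of Q streams the K and V tiles through on-chip memory, accumulates into a running output, and discards the intermediate scores. The extra memory it needs is O(N) rather than O(N²), and the reads and writes of the N×N score matrix that the standard implementation pays for disappear.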
  </Step>

  <Step title="Online softmax keeps precision">
    The kernel tracks a running maximum `m_i` and a running normalizer `l_i` per row. When a new K/V tile arrives, it rescales the previous accumulator with `alpha = exp(m_i_old - m_i_new)` before adding the new contribution. The final output is divided by `l_i` once. No second pass over the data is needed.
  </Step>

  <Step title="Shared memory is reused, not exhausted">
    Each tile is small enough to fit entirely in shared memory. The kernel loops over tiles, reusing the same shared memory budget for each one. Naive Triton allocates a slot for the entire N×N matrix upfront, which caps the maximum sequence length.
  </Step>

  <Step title="No sequence length ceiling">
    Because memory is O(N), Flash Attention can handle arbitrarily long sequences (subject only to total HBM capacity). Naive Triton skips sequences that exceed `BLOCK_SIZE`.
  </Step>
</Steps>
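The online softmax update described in the steps above can be sketched in pure Python for a single query row. This is an illustration of the technique, not the Triton kernel. The names `online_softmax_attention`, `reference_attention`, and the `tile` parameter are assumptions, and the inputs are arbitrary example numbers.

```python
import math


def online_softmax_attention(scores, values, tile=2):
    """Single query row: softmax(scores) @ values, processed tile by tile."""
    dim = len(values[0])
    m_i = float("-inf")   # running maximum
    l_i = 0.0             # running normalizer
    acc = [0.0] * dim     # running unnormalized output
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m_i, max(s_tile))
        alpha = math.exp(m_i - m_new)  # rescales the previous accumulator
        p = [math.exp(s - m_new) for s in s_tile]
        l_i = l_i * alpha + sum(p)
        acc = [a * alpha + sum(pj * vj[d] for pj, vj in zip(p, v_tile))
               for d, a in enumerate(acc)]
        m_i = m_new
    return [a / l_i for a in acc]  # single division by the normalizer


def reference_attention(scores, values):
    """Two-pass softmax over the full row, for comparison."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(pj * vj[d] for pj, vj in zip(p, values)) / total
            for d in range(len(values[0]))]


scores = [0.5, 2.0, -1.0, 3.0, 0.1]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, 0.7]]
print(online_softmax_attention(scores, values))
print(reference_attention(scores, values))
```

Both functions should agree up to floating-point rounding, while the tiled version only ever holds `tile` scores at a time.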
