Profiler
We offer a profiling toolkit in NATTEN, designed to make it easy to measure NATTEN's performance on your device and for your desired use case. You can also use it to compare against baselines available through PyTorch's SDPA (cuDNN Attention and Flash Attention v2), as well as NATTEN's own self attention operators.
It uses the PyTorch profiler API directly to extract the trace of operations, maps known symbol names to human-readable operation names, and filters out less relevant and shared operations (e.g. device synchronization).
It also provides an interface for exploring backends and their configurations, and for automatically searching for the best configuration.
Getting Started
You can access the profiler with python -m natten.profiler.
Dependencies
We recommend installing rich and tqdm for the best visual experience, but they are not required.
For example, let's say we want to profile a 3-D use case, where our token layout (feature map) shape is (8, 16, 24), and we have head dim 128.
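A sketch of the command, using only the -i (input shape) and -d (head dim) flags documented under Arguments below:
python -m natten.profiler \
-i 8 16 24 \
-d 128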
Since no window size is specified, this will report self attention time.
Now let's say we want to profile this with window size 2 x 4 x 3. We can use -w to specify kernel_size:
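For instance (again, a sketch using only the flags documented below):
python -m natten.profiler \
-i 8 16 24 \
-d 128 \
-w 2 4 3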
There are many more options available. Not only can you specify all the neighborhood attention parameters (kernel_size, dilation, stride, is_causal), you can also toggle the backward pass, play around with different backends, data types, and backend configurations, and find out which backends and configurations are available on your GPU and for your specific use case.
Refer to Arguments for a detailed list of the options, and to our examples below highlighting dry run mode, which lists all available backends and their configurations, and optimize mode, which picks a backend and configuration for you by running reasonable choices and selecting the fastest one.
There are also some examples highlighting the performance of our new Hopper FNA and Blackwell FNA kernels.
Arguments
| Option | Description |
|---|---|
| `-h`, `--help` | Show help message and exit. |
| `-i` | Required QKV token layout shape (i.e. sequence length in 1-D; height and width in 2-D; depth, height, and width in 3-D). |
| `-b` | QKV batch size. |
| `-n` | QKV number of heads (GQA/MQA are not supported in NATTEN at this time). |
| `-d` | QKV head dim. |
| `-w` | Neighborhood attention window size (shape), also referred to as `kernel_size`. |
| `-s` | Neighborhood attention stride values. This must be a tuple with the same number of elements as `kernel_size`. |
|  | Neighborhood attention dilation values. This must be a tuple with the same number of elements as `kernel_size`. |
|  | Causal masking values. This must be a boolean tuple with the same number of elements as `kernel_size`. |
| `--dtype` | Element (data) type (e.g. `bf16`). |
| `--backprop` | Profile backward pass as well as forward pass. |
|  | Number of additional KV tokens, if desired. Defaults to 0. |
| `--backend` | Backend / kernel to run. Choices include NATTEN backends (e.g. `cutlass-fna`, `hopper-fna`, `blackwell-fna`, `flex-fna`) and PyTorch SDPA backends (which can only perform self attention). |
|  | Backend / kernel for cross-attention (additional KV) and fast-path self attention in NATTEN (e.g. `cutlass-fmha`). |
| `--q-tile` | Q tile shape in the kernel (varies between different backends). Run with `--dry-run` to see valid choices for your use case. |
| `--kv-tile` | KV tile shape in the kernel (varies between different backends). Run with `--dry-run` to see valid choices for your use case. |
|  | Q tile shape in the backward pass kernel (only respected by the `cutlass-fna` backend). |
|  | KV tile shape in the backward pass kernel (only respected by the `cutlass-fna` backend). |
|  | Kernel schedule (`hopper-fna` backend only). |
|  | Use persistent tile scheduler, in backends that support it. |
|  | Enables compiling the Flex Attention block-sparse mask and kernel in the `flex-fna` backend. |
|  | Try to fuse additional KV (cross attention) into the FNA kernel, if any. Experimental feature: this may be removed in the future, as it rarely provides any benefit over separate attention branches and attention merging. |
|  | Number of profiling warmup steps. |
| `--dry-run` | Display valid forward and backward pass configurations for this use case and exit. Note: your default CUDA device (or CPU if CUDA is not available) is a determining factor in the tile shapes / configurations shown. For instance, some tile shapes may only be available on specific GPU architectures, and if your GPU architecture does not support a specific backend, you will see an empty list. |
| `--max-configs` | Maximum number of tile configurations to display. Default: 10. |
| `--optimize` | Find the best configuration (and backend, if unspecified) for your use case by profiling all available choices and selecting the fastest one. Experimental feature: this may change in future releases. We hope to eventually integrate / replace it with the NATTEN Simulator. |
|  | Number of warmup steps for the optimizer. |
Dry run
Figuring out tile sizes for your use case.
We offer many different backends, and each backend can offer various configurations, mainly tile sizes / shapes, for each unique use case. Factors that determine those include, but are not limited to:
- GPU architecture
- Element type (float16 / bfloat16 vs float32)
- Attention head dim
Due to this, we highly recommend first trying to understand what options you have, by using the profiler's dry run mode:
python -m natten.profiler \
--dry-run \
--dtype bf16 \
-i 16 16 16 \ #(1)!
-d 128 #(2)!
- Sample feature map shape: (16, 16, 16)
- Head dim: 128
Since a backend was not specified, the profiler will first detect all backends compatible with the specified options on your default GPU, and print compatible tile shapes for each available backend in a separate table.
Sample output from running on H100
Use case is compatible with backend hopper-fna.
Backend: hopper-fna
Forward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃ kernel_schedule ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ (4, 4, 8) │ (4, 4, 8) │ KernelSchedule.WarpSpecializedCooperative │
│ (4, 4, 8) │ (4, 4, 8) │ KernelSchedule.WarpSpecializedPingpong │
│ (2, 8, 8) │ (2, 8, 8) │ KernelSchedule.WarpSpecializedCooperative │
│ (2, 8, 8) │ (2, 8, 8) │ KernelSchedule.WarpSpecializedPingpong │
└──────────────┴───────────────┴───────────────────────────────────────────┘
Use case is compatible with backend cutlass-fna.
Backend: cutlass-fna
Forward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ (8, 2, 2) │ (32, 2, 2) │
│ (8, 2, 2) │ (16, 4, 2) │
│ (8, 2, 2) │ (16, 2, 4) │
│ (8, 2, 2) │ (8, 8, 2) │
│ (8, 2, 2) │ (8, 4, 4) │
│ (8, 2, 2) │ (8, 2, 8) │
│ (4, 4, 2) │ (16, 4, 2) │
│ (4, 4, 2) │ (8, 8, 2) │
│ (4, 4, 2) │ (8, 4, 4) │
│ (4, 4, 2) │ (4, 16, 2) │
│ ... │ ... │
└──────────────┴───────────────┘
Backend: cutlass-fna
Backward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ (16, 2, 2) │ (16, 2, 2) │
│ (16, 2, 2) │ (8, 4, 2) │
│ (16, 2, 2) │ (8, 2, 4) │
│ (16, 2, 2) │ (4, 8, 2) │
│ (16, 2, 2) │ (4, 4, 4) │
│ (16, 2, 2) │ (4, 2, 8) │
│ (16, 2, 2) │ (2, 16, 2) │
│ (16, 2, 2) │ (2, 8, 4) │
│ (16, 2, 2) │ (2, 4, 8) │
│ (16, 2, 2) │ (2, 2, 16) │
│ ... │ ... │
└──────────────┴───────────────┘
Use case is compatible with backend flex-fna.
Backend: flex-fna
Forward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ (4, 4, 4) │ (4, 4, 4) │
│ (2, 4, 8) │ (2, 4, 8) │
│ (2, 4, 8) │ (4, 4, 4) │
└──────────────┴───────────────┘
Sample output from running on B200
Use case is compatible with backend blackwell-fna.
Backend: blackwell-fna
Forward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ (8, 4, 8) │ (4, 4, 8) │
│ (8, 4, 8) │ (2, 8, 8) │
│ (2, 8, 16) │ (4, 4, 8) │
│ (2, 8, 16) │ (2, 8, 8) │
│ (4, 4, 16) │ (2, 4, 16) │
│ (2, 16, 8) │ (2, 8, 8) │
│ (4, 8, 8) │ (2, 8, 8) │
└──────────────┴───────────────┘
Use case is compatible with backend cutlass-fna.
Backend: cutlass-fna
Forward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ (8, 2, 2) │ (32, 2, 2) │
│ (8, 2, 2) │ (16, 4, 2) │
│ (8, 2, 2) │ (16, 2, 4) │
│ (8, 2, 2) │ (8, 8, 2) │
│ (8, 2, 2) │ (8, 4, 4) │
│ (8, 2, 2) │ (8, 2, 8) │
│ (4, 4, 2) │ (16, 4, 2) │
│ (4, 4, 2) │ (8, 8, 2) │
│ (4, 4, 2) │ (8, 4, 4) │
│ (4, 4, 2) │ (4, 16, 2) │
│ ... │ ... │
└──────────────┴───────────────┘
Backend: cutlass-fna
Backward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ (16, 2, 2) │ (16, 2, 2) │
│ (16, 2, 2) │ (8, 4, 2) │
│ (16, 2, 2) │ (8, 2, 4) │
│ (16, 2, 2) │ (4, 8, 2) │
│ (16, 2, 2) │ (4, 4, 4) │
│ (16, 2, 2) │ (4, 2, 8) │
│ (16, 2, 2) │ (2, 16, 2) │
│ (16, 2, 2) │ (2, 8, 4) │
│ (16, 2, 2) │ (2, 4, 8) │
│ (16, 2, 2) │ (2, 2, 16) │
│ ... │ ... │
└──────────────┴───────────────┘
Use case is compatible with backend flex-fna.
Backend: flex-fna
Forward pass configurations
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ q_tile_shape ┃ kv_tile_shape ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ (4, 4, 4) │ (4, 4, 4) │
│ (2, 4, 8) │ (2, 4, 8) │
│ (2, 4, 8) │ (4, 4, 4) │
└──────────────┴───────────────┘
Note that some backends offer many combinations, and you may want to use --max-configs to change the default limit of 10, or set it to 0 to display everything.
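For example, to rerun the dry run above and list every available configuration (a sketch using only the flags shown in this document):
python -m natten.profiler \
--dry-run \
--dtype bf16 \
-i 16 16 16 \
-d 128 \
--max-configs 0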
Tip
You can also do this directly in your code by using our get[_bwd]_configs_for_{BACKEND} APIs. All you need to do is pass one of your query, key, or value tensors, and they'll return a list of valid options. Read more about them in backends.
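As a rough sketch, assuming the cutlass-fna backend, that the corresponding helper is exposed under the top-level natten namespace (check the backends page for the exact name and signature in your version), and NATTEN's heads-last tensor layout:
import torch
# Assumed import path; the name follows the get[_bwd]_configs_for_{BACKEND} pattern.
from natten import get_configs_for_cutlass_fna

# One of Q/K/V in (batch, *token_layout_shape, heads, head_dim) layout,
# matching the 3-D dry run example above: (16, 16, 16) tokens, head dim 128.
q = torch.randn(1, 16, 16, 16, 1, 128, device="cuda", dtype=torch.bfloat16)

# Returns a list of valid tile-shape configurations for this use case.
configs = get_configs_for_cutlass_fna(q)
print(configs)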
Optimize
You can also use the profiler to search for the fastest backend and configuration for your use case. This can sometimes bring very significant speedups. However, since running this search on some backends can be time consuming, it is a good idea to run it ahead of your actual program, and then hard-code your chosen configuration (this is why we removed our autotuner feature).
As an example, here we demonstrate a 3-D use case on Ampere, where the cutlass-fna backend specifically is the best choice, but also comes with over 70 forward pass and 300 backward pass configurations just for this use case.
python -m natten.profiler \
--backprop \
-b 1 \
-n 24 \
-d 128 \
-i 30 48 80 \
-w 18 24 24 \
-s 16 8 8
Sample output from running on A100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm80 │ FnaForward │ 1 │ 557.289ms │
│ CUTLASS │ attention │ Sm80 │ FnaBackward │ 1 │ 1.445s │
│ CUTLASS │ reduction │ - │ Reduction │ 1 │ 1.461ms │
│ PyTorch │ elementwise │ - │ vectorized_elementwise_kernel │ 5 │ 2.404ms │
│ │ │ │ Total │ │ 2.006s │
└───────────┴─────────────────┴──────┴───────────────────────────────┴─────────┴───────────┘
python -m natten.profiler \
--backprop \
-b 1 \
-n 24 \
-d 128 \
-i 30 48 80 \
-w 18 24 24 \
-s 16 8 8 \
--optimize
Sample output from running on A100
Best configuration
┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Parameter ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ backend │ cutlass-fna │
│ fmha_backend │ cutlass-fmha │
│ q_tile_shape │ (2, 8, 4) │
│ kv_tile_shape │ (2, 8, 8) │
│ backward_q_tile_shape │ (2, 8, 8) │
│ backward_kv_tile_shape │ (2, 8, 8) │
└────────────────────────┴──────────────┘
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm80 │ FnaForward │ 1 │ 221.512ms │
│ CUTLASS │ attention │ Sm80 │ FnaBackward │ 1 │ 1.158s │
│ CUTLASS │ reduction │ - │ Reduction │ 1 │ 1.464ms │
│ PyTorch │ elementwise │ - │ vectorized_elementwise_kernel │ 2 │ 737.376us │
│ │ │ │ Total │ │ 1.382s │
└───────────┴─────────────────┴──────┴───────────────────────────────┴─────────┴───────────┘
Warning
Running this took over an hour, since there are so many unique configurations for the backward pass, and, well, it's a pretty big use case. But remember, this only needs to be done once.
Alternatively, you can always set batch and heads to 1 (when there are at least a few thousand tokens), run more quickly, and then reuse the config from the batch=1, heads=1 case. It can get you most of the way, especially on newer architectures that have persistent scheduling.
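For instance, the same search restricted to a single batch and head (a sketch; the configuration it finds is then reused for the full-size case):
python -m natten.profiler \
--backprop \
-b 1 \
-n 1 \
-d 128 \
-i 30 48 80 \
-w 18 24 24 \
-s 16 8 8 \
--optimize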
In this case, the default configuration's runtime was approximately 2 seconds, while the optimized configuration came in at approximately 1.4 seconds, a 1.45X speedup. If we look at just the forward pass, it's ~ 557 ms vs ~ 222 ms, a 2.5X speedup!
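Once you have the winning configuration, you can hard-code it at the call site. Below is a rough sketch, assuming the na3d op is importable from the top-level natten namespace and accepts the backend and tile-shape keyword arguments reported above; check the API reference for the exact signature in your version:
from natten import na3d  # assumed import path

# q, k, v: (batch, 30, 48, 80, heads, 128) tensors, as in the use case above.
out = na3d(
    q, k, v,
    kernel_size=(18, 24, 24),
    stride=(16, 8, 8),
    # Values reported by `--optimize` above (assumed keyword argument names).
    backend="cutlass-fna",
    q_tile_shape=(2, 8, 4),
    kv_tile_shape=(2, 8, 8),
    backward_q_tile_shape=(2, 8, 8),
    backward_kv_tile_shape=(2, 8, 8),
)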
However, we plan to eventually integrate our new NATTEN Simulator with this feature, so that we can optimize faster by ruling out obviously terrible configurations and only searching a fraction of the total per use case. In addition, the Simulator can also make other recommendations, like what stride and dilation values give you the best speedups!
Hopper and Blackwell Examples
1-D use case
In this example, we try to measure the runtime of self attention, two forms of neighborhood attention, and blocked attention (also implemented by NATTEN), on a 32K-token 1-D sequence with head dim 128.
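We start with the self attention baseline. A sketch of the command, using only the flags documented above (no window size means self attention):
python -m natten.profiler \
-d 128 \
-i 32768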
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ cuDNN │ attention │ Sm90 │ sm90_flash_fprop_wgmma_f16_knob_7_64x128x128_... │ 1 │ 787.392us │
│ │ │ │ Total │ │ 787.392us │
└───────────┴─────────────────┴──────┴──────────────────────────────────────────────────┴─────────┴───────────┘
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ cuDNN │ attention │ Sm100 │ sm100_flash_fprop_f16_knob_7_128x128x128_4x1x... │ 1 │ 424.766us │
│ │ │ │ Total │ │ 424.766us │
└───────────┴─────────────────┴───────┴──────────────────────────────────────────────────┴─────────┴───────────┘
Now let's try standard neighborhood attention, with a 2K window size, which gives us 93.75% sparsity (each query attends to at most 2048 of the 32768 tokens, i.e. 1/16th of them).
python -m natten.profiler \
--backend hopper-fna \
-d 128 \
-i 32768 \
-w 2048
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm90 │ FnaForward │ 1 │ 94.848us │
│ │ │ │ Total │ │ 94.848us │
└───────────┴─────────────────┴──────┴────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with 93.75% sparsity: 16X
Speedup: 787.392 us / 94.848 us = 8.3X
Note
1-D non-dilated neighborhood attention does not require token permute, hence the only runtime is the kernel runtime.
python -m natten.profiler \
--backend blackwell-fna \
-d 128 \
-i 32768 \
-w 2048
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm100 │ FnaForward │ 1 │ 58.111us │
│ │ │ │ Total │ │ 58.111us │
└───────────┴─────────────────┴───────┴────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with 93.75% sparsity: 16X
Speedup: 424.766 us / 58.111 us = 7.3X
Note
1-D non-dilated neighborhood attention does not require token permute, hence the only runtime is the kernel runtime.
Let's try a strided variant of neighborhood attention. We ran the NATTEN Simulator (coming soon), and found that stride=256 is the minimum stride that results in a fully block-sparse mask for both Hopper and Blackwell tile sizes.
python -m natten.profiler \
--backend hopper-fna \
-d 128 \
-i 32768 \
-w 2048 \
-s 256
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm90 │ FnaForward │ 1 │ 62.687us │
│ │ │ │ Total │ │ 62.687us │
└───────────┴─────────────────┴──────┴────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with 93.75% sparsity: 16X
Speedup: 787.392 us / 62.687 us = 12.6X
Note
1-D non-dilated neighborhood attention does not require token permute, hence the only runtime is the kernel runtime.
python -m natten.profiler \
--backend blackwell-fna \
-d 128 \
-i 32768 \
-w 2048 \
-s 256
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm100 │ FnaForward │ 1 │ 37.888us │
│ │ │ │ Total │ │ 37.888us │
└───────────┴─────────────────┴───────┴────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with 93.75% sparsity: 16X
Speedup: 424.766 us / 37.888 us = 11.2X
Note
1-D non-dilated neighborhood attention does not require token permute, hence the only runtime is the kernel runtime.
Finally, we know that when stride == kernel_size, neighborhood attention numerically matches blocked attention. In this use case, blocked attention is also fully block-sparse, and given that it uses the same window size as the neighborhood attention case, we should expect identical performance (excluding runtime variance).
python -m natten.profiler \
--backend hopper-fna \
-d 128 \
-i 32768 \
-w 2048 \
-s 2048
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm90 │ FnaForward │ 1 │ 59.264us │
│ │ │ │ Total │ │ 59.264us │
└───────────┴─────────────────┴──────┴────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with 93.75% sparsity: 16X
Speedup: 787.392 us / 59.264 us = 13.3X
Note
1-D non-dilated neighborhood attention does not require token permute, hence the only runtime is the kernel runtime.
python -m natten.profiler \
--backend blackwell-fna \
-d 128 \
-i 32768 \
-w 2048 \
-s 2048
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm100 │ FnaForward │ 1 │ 38.240us │
│ │ │ │ Total │ │ 38.240us │
└───────────┴─────────────────┴───────┴────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with 93.75% sparsity: 16X
Speedup: 424.766 us / 38.240 us = 11.1X
Note
1-D non-dilated neighborhood attention does not require token permute, hence the only runtime is the kernel runtime.
2-D use case (FLUX)
In this example we take the problem size from Flux.1-dev (4K), with 24 attention heads, and a
(256, 256)
token layout (feature map) shape.
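We start with the self attention baseline. A sketch of the command, using only the flags documented above:
python -m natten.profiler \
-n 24 \
-d 128 \
-i 256 256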
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ cuDNN │ attention │ Sm90 │ sm90_flash_fprop_wgmma_f16_knob_7_64x128x128_... │ 1 │ 97.435ms │
│ │ │ │ Total │ │ 97.435ms │
└───────────┴─────────────────┴──────┴──────────────────────────────────────────────────┴─────────┴──────────┘
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ cuDNN │ attention │ Sm100 │ sm100_flash_fprop_f16_knob_7_128x128x128_4x1x... │ 1 │ 43.313ms │
│ │ │ │ Total │ │ 43.313ms │
└───────────┴─────────────────┴───────┴──────────────────────────────────────────────────┴─────────┴──────────┘
Now let's try standard neighborhood attention, with a window size of (80, 80), which is approximately 90% sparsity (the same configuration as in the GNA paper).
We explicitly set the query and KV tile shapes to (16, 8), a combination that requires no additional input padding and, given the window size, can become fully block-sparse when using stride (next up). Q and KV tile shapes of (16, 8) are a combination supported by our Hopper FNA kernel.
python -m natten.profiler \
--backend hopper-fna \
--q-tile 16 8 \
--kv-tile 16 8 \
-n 24 \
-d 128 \
-i 256 256 \
-w 80 80
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm90 │ FnaForward │ 1 │ 16.201ms │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 2.402ms │
│ │ │ │ Total │ │ 18.603ms │
└───────────┴─────────────────┴──────┴────────────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with ~90% sparsity: 10.2X
Speedup w/o token permute time: 97.435 ms / 16.201 ms = 6.0X
Speedup w/ token permute time: 97.435 ms / 18.603 ms = 5.2X
We explicitly set the query and KV tile shapes to (16, 16) and (16, 8), respectively. This is because the Blackwell forward pass kernel presently only supports a Q tile size of 256 and a KV tile size of 128. Query padding cannot be avoided in this case, but KV padding is, which allows the mask to become fully block-sparse when using stride (next up).
python -m natten.profiler \
--backend blackwell-fna \
--q-tile 16 16 \
--kv-tile 16 8 \
-n 24 \
-d 128 \
-i 256 256 \
-w 80 80
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm100 │ FnaForward │ 1 │ 8.692ms │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 2.239ms │
│ │ │ │ Total │ │ 10.931ms │
└───────────┴─────────────────┴───────┴────────────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with ~90% sparsity: 10.2X
Speedup w/o token permute time: 43.313 ms / 8.692 ms = 5.0X
Speedup w/ token permute time: 43.313 ms / 10.931 ms = 4.0X
Finally, let's try strided neighborhood attention, with stride (16, 16) (the same configuration as in the GNA paper). Given the use case and the Q and KV tile shapes, this stride results in a fully block-sparse mask.
python -m natten.profiler \
--backend hopper-fna \
--q-tile 16 8 \
--kv-tile 16 8 \
-n 24 \
-d 128 \
-i 256 256 \
-w 80 80 \
-s 16 16
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm90 │ FnaForward │ 1 │ 7.914ms │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 2.687ms │
│ │ │ │ Total │ │ 10.601ms │
└───────────┴─────────────────┴──────┴────────────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with ~90% sparsity: 10.2X
Speedup w/o token permute time: 97.435 ms / 7.914 ms = 12.3X
Speedup w/ token permute time: 97.435 ms / 10.601 ms = 9.2X
Note
The baseline is cuDNN Attention, which in this particular case is slightly slower than CUTLASS FMHA, which is why the speedup without token permute exceeds the 10.2X FLOP-wise limit. Additionally, runtime variance and other factors may affect observable speedups.
python -m natten.profiler \
--backend blackwell-fna \
--q-tile 16 16 \
--kv-tile 16 8 \
-n 24 \
-d 128 \
-i 256 256 \
-w 80 80 \
-s 16 16
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
│ CUTLASS │ attention │ Sm100 │ FnaForward │ 1 │ 4.114ms │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 2.275ms │
│ │ │ │ Total │ │ 6.389ms │
└───────────┴─────────────────┴───────┴────────────────────┴─────────┴─────────┘
FLOP-wise speedup (upper bound) with ~90% sparsity: 10.2X
Speedup w/o token permute time: 43.313 ms / 4.114 ms = 10.5X
Speedup w/ token permute time: 43.313 ms / 6.389 ms = 6.8X
Note
The speedup without token permute only appears to exceed the 10.2X FLOP-wise limit due to runtime variance and other factors that may affect observable speedups.
3-D use case (Hunyuan Video)
In this example we take the problem size from Hunyuan Video, with 24 attention heads, head dim 128, and a (30, 48, 80) token layout (feature map) shape.
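We start with the self attention baseline. A sketch of the command, using only the flags documented above:
python -m natten.profiler \
-n 24 \
-d 128 \
-i 30 48 80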
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ cuDNN │ attention │ Sm90 │ sm90_flash_fprop_wgmma_f16_knob_7_64x128x128_... │ 1 │ 283.327ms │
│ │ │ │ Total │ │ 283.327ms │
└───────────┴─────────────────┴──────┴──────────────────────────────────────────────────┴─────────┴───────────┘
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ cuDNN │ attention │ Sm100 │ sm100_flash_fprop_f16_knob_7_128x128x128_4x1x... │ 1 │ 135.265ms │
│ │ │ │ Total │ │ 135.265ms │
└───────────┴─────────────────┴───────┴──────────────────────────────────────────────────┴─────────┴───────────┘
Now let's try standard neighborhood attention, with a window size of (18, 24, 24), which is approximately 91% sparsity (the same configuration as in the GNA paper).
We explicitly set the query and KV tile shapes to (2, 8, 8), a combination that requires no additional input padding and, given the window size, can become fully block-sparse when using stride (next up). Q and KV tile shapes of (2, 8, 8) are a combination supported by our Hopper FNA kernel.
python -m natten.profiler \
--backend hopper-fna \
--q-tile 2 8 8 \
--kv-tile 2 8 8 \
-n 24 \
-d 128 \
-i 30 48 80 \
-w 18 24 24
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm90 │ FnaForward │ 1 │ 79.043ms │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 5.374ms │
│ │ │ │ Total │ │ 84.417ms │
└───────────┴─────────────────┴──────┴────────────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with ~91% sparsity: 11.1X
Speedup w/o token permute time: 283.327 ms / 79.043 ms = 3.6X
Speedup w/ token permute time: 283.327 ms / 84.417 ms = 3.4X
We explicitly set the query and KV tile shapes to (4, 8, 8) and (2, 8, 8), respectively. This is because the Blackwell forward pass kernel presently only supports a Q tile size of 256 and a KV tile size of 128. Query padding cannot be avoided in this case, but KV padding is, which allows the mask to become fully block-sparse when using stride (next up).
python -m natten.profiler \
--backend blackwell-fna \
--q-tile 4 8 8 \
--kv-tile 2 8 8 \
-n 24 \
-d 128 \
-i 30 48 80 \
-w 18 24 24
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm100 │ FnaForward │ 1 │ 42.243ms │
│ PyTorch │ elementwise │ - │ vectorized_elementwise_kernel │ 1 │ 191.360us │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 4.919ms │
│ │ │ │ Total │ │ 47.353ms │
└───────────┴─────────────────┴───────┴───────────────────────────────┴─────────┴───────────┘
FLOP-wise speedup (upper bound) with ~91% sparsity: 11.1X
Speedup w/o token permute time: 135.265 ms / 42.243 ms = 3.2X
Speedup w/ token permute time: 135.265 ms / 47.353 ms = 2.9X
Finally, let's try strided neighborhood attention, with stride (16, 8, 8) (the same configuration as in the GNA paper). Given the use case and the Q and KV tile shapes, this stride results in a fully block-sparse mask.
python -m natten.profiler \
--backend hopper-fna \
--q-tile 2 8 8 \
--kv-tile 2 8 8 \
-n 24 \
-d 128 \
-i 30 48 80 \
-w 18 24 24 \
-s 16 8 8
Sample output from running on H100
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm90 │ FnaForward │ 1 │ 23.359ms │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 5.750ms │
│ │ │ │ Total │ │ 29.109ms │
└───────────┴─────────────────┴──────┴────────────────────┴─────────┴──────────┘
FLOP-wise speedup (upper bound) with ~91% sparsity: 11.1X
Speedup w/o token permute time: 283.327 ms / 23.359 ms = 12.1X
Speedup w/ token permute time: 283.327 ms / 29.109 ms = 9.7X
Note
The baseline is cuDNN Attention, which in this particular case is slightly slower than CUTLASS FMHA, which is why the speedup without token permute exceeds the 11.1X FLOP-wise limit. Additionally, runtime variance and other factors may affect observable speedups.
python -m natten.profiler \
--backend blackwell-fna \
--q-tile 4 8 8 \
--kv-tile 2 8 8 \
-n 24 \
-d 128 \
-i 30 48 80 \
-w 18 24 24 \
-s 16 8 8
Sample output from running on B200
Profiler results
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓
┃ Framework ┃ Kernel category ┃ Arch ┃ Operation ┃ # calls ┃ Runtime ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩
│ CUTLASS │ attention │ Sm100 │ FnaForward │ 1 │ 13.010ms │
│ PyTorch │ elementwise │ - │ vectorized_elementwise_kernel │ 1 │ 206.623us │
│ PyTorch │ elementwise │ - │ elementwise_kernel │ 4 │ 5.321ms │
│ │ │ │ Total │ │ 18.538ms │
└───────────┴─────────────────┴───────┴───────────────────────────────┴─────────┴───────────┘
FLOP-wise speedup (upper bound) with ~91% sparsity: 11.1X
Speedup w/o token permute time: 135.265 ms / 13.010 ms = 10.4X
Speedup w/ token permute time: 135.265 ms / 18.538 ms = 7.3X
Limitations
We will be expanding the current profiling toolkit to support third-party backends (e.g. Flash Attention 3), more complete argument bindings, automatic comparison to reference with speedup measurements, and more.