Backends

On this page, we list our available implementations for standard attention (FMHA) and neighborhood attention (FNA).

CUTLASS FNA / FMHA

Supported modes

  • Inference (forward pass)
  • Training (backward pass)

Figure: Visualization of FNA, as proposed in Faster Neighborhood Attention (2024).

Based on xFormers FMHA (a.k.a. memory-efficient attention), this backend is built on the CUTLASS 2.X API and targets multiple architectures: SM50 (Maxwell), SM70 (Volta), SM75 (Turing), and SM80 (Ampere). You can use these kernels on any NVIDIA GPU with compute capability >= 5.0, for both training and inference.

Some newer architectures, such as Hopper (SM90) and Blackwell (SM100), have much more performant dedicated kernels, but those are limited to inference for now.

This implementation fuses multi-dimensional tiling directly into the kernel, but may suffer from the additional overhead of software predication. To read more about this, we refer you to our Generalized Neighborhood Attention paper, in which we also propose solutions such as Token Permutation, which we use to build our Hopper and Blackwell kernels.

Finding configurations

You can use profiler dry runs to find configurations for any of our backends, and to find backends compatible with your device and use case. You can also use the following functions in your code.

Finding configurations for CUTLASS FMHA/FNA

natten.get_configs_for_cutlass_fmha

get_configs_for_cutlass_fmha(input_tensor)

Returns CUTLASS FMHA configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a device with compute capability >= 5.0, and if so, returns forward-pass configurations compatible with that compute capability, the tensor dtype, and head dim.

Each configuration for this operation is a tuple of two integers: (q_tile_size, kv_tile_size). These are arguments to natten.attention.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[int, int]]: List of tuples of two integers corresponding to query and KV tile sizes.
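
For example, here is a minimal sketch of querying forward-pass configurations and passing one to natten.attention. Only the q_tile_size and kv_tile_size arguments are documented above; the tensor layout and the rest of the call are illustrative assumptions.

    import torch
    import natten

    # Illustrative self-attention inputs; the (batch, seqlen, heads, head_dim)
    # layout here is an assumption (see the op documentation for exact layouts).
    q = torch.randn(2, 4096, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_cutlass_fmha(q)
    if configs:
        q_tile_size, kv_tile_size = configs[0]  # any listed configuration is valid
        out = natten.attention(q, k, v, q_tile_size=q_tile_size, kv_tile_size=kv_tile_size)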

natten.get_bwd_configs_for_cutlass_fmha

get_bwd_configs_for_cutlass_fmha(input_tensor)

Returns CUTLASS FMHA backward pass configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a device with compute capability >= 5.0, and if so, returns backward-pass configurations compatible with that compute capability, the tensor dtype, and head dim.

Each configuration for this operation is a tuple of two integers: (backward_q_tile_size, backward_kv_tile_size). These are arguments to natten.attention.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[int, int]]: List of tuples of two integers corresponding to query and KV tile sizes in the backward pass.
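
Backward-pass tile sizes can be selected the same way. Continuing the sketch above (only the backward_q_tile_size and backward_kv_tile_size arguments are documented; the rest is illustrative):

    bwd_configs = natten.get_bwd_configs_for_cutlass_fmha(q)
    if bwd_configs:
        backward_q_tile_size, backward_kv_tile_size = bwd_configs[0]
        # These only affect the backward pass, i.e. when out is backpropagated through.
        out = natten.attention(
            q, k, v,
            backward_q_tile_size=backward_q_tile_size,
            backward_kv_tile_size=backward_kv_tile_size,
        )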

natten.get_configs_for_cutlass_fna

get_configs_for_cutlass_fna(input_tensor)

Returns CUTLASS FNA configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a device with compute capability >= 5.0, and if so, returns forward-pass configurations compatible with that compute capability, the tensor dtype and head dim, and the rank of the token layout (1D/2D/3D).

Each configuration for this operation is a tuple of two integer tuples: (q_tile_shape, kv_tile_shape). These are arguments to natten.na1d, natten.na2d, and natten.na3d.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, and NA operations require matching token layouts, Q, K, and V are guaranteed to have the same shape, therefore it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[tuple, tuple]]: List of tuples of two integer tuples corresponding to query and KV tile shapes.
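
As an example, here is a sketch for a 2D token layout. The q_tile_shape and kv_tile_shape arguments are documented above; the tensor layout and the kernel_size value are illustrative assumptions.

    import torch
    import natten

    # Illustrative 2D token layout; the (batch, X, Y, heads, head_dim) layout and
    # the kernel_size value are assumptions (see the na2d documentation).
    q = torch.randn(2, 64, 64, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_cutlass_fna(q)
    if configs:
        q_tile_shape, kv_tile_shape = configs[0]
        out = natten.na2d(
            q, k, v,
            kernel_size=(7, 7),
            q_tile_shape=q_tile_shape,
            kv_tile_shape=kv_tile_shape,
        )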

natten.get_bwd_configs_for_cutlass_fna

get_bwd_configs_for_cutlass_fna(input_tensor)

Returns CUTLASS FNA backward pass configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a device with compute capability >= 5.0, and if so, returns backward-pass configurations compatible with that compute capability, the tensor dtype, and head dim.

Each configuration for this operation is a tuple of two integer tuples: (backward_q_tile_shape, backward_kv_tile_shape). These are arguments to natten.na1d, natten.na2d, and natten.na3d.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, and NA operations require matching token layouts, Q, K, and V are guaranteed to have the same shape, therefore it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[tuple, tuple]]: List of tuples of two integer tuples corresponding to query and KV tile shapes in the backward pass.
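
Backward tile shapes can be selected similarly. Continuing the sketch above (only backward_q_tile_shape and backward_kv_tile_shape are documented; the rest is illustrative):

    bwd_configs = natten.get_bwd_configs_for_cutlass_fna(q)
    if bwd_configs:
        backward_q_tile_shape, backward_kv_tile_shape = bwd_configs[0]
        out = natten.na2d(
            q, k, v,
            kernel_size=(7, 7),
            backward_q_tile_shape=backward_q_tile_shape,
            backward_kv_tile_shape=backward_kv_tile_shape,
        )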

Hopper FNA / FMHA

Supported modes

  • Inference (forward pass)
  • Training (backward pass): not yet supported

Figure: Performance levels of Hopper FNA as of version 0.20.0.

Based on CUTLASS's Hopper FMHA kernel (3.X API), this backend offers non-persistent, warp-specialized cooperative, and warp-specialized ping-ponging kernels, similar to Flash Attention 3, and it exhibits comparable forward-pass performance.

This backend does not fuse multi-dimensional tiling into the kernel, and instead uses Token Permutation.

Finding configurations

You can use profiler dry runs to find configurations for any of our backends, and to find backends compatible with your device and use case. You can also use the following functions in your code.

Finding configurations for Hopper FMHA/FNA

natten.get_configs_for_cutlass_hopper_fmha

get_configs_for_cutlass_hopper_fmha(input_tensor)

Returns Hopper FMHA configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a Hopper datacenter GPU (SM90; compute capability 9.0), and if so, returns forward-pass configurations compatible with the tensor dtype and head dim.

Each configuration for this operation is a tuple of a tile-size pair and a kernel schedule: ((q_tile_size, kv_tile_size), kernel_schedule). These are arguments to natten.attention. kernel_schedule is specific to Hopper FNA/FMHA only.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[Tuple[int, int], KernelSchedule]]: List of tuples of one tuple of two integers corresponding to query and KV tile sizes, and a kernel schedule enum type.
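
For example, here is a sketch of unpacking a Hopper FMHA configuration and passing it to natten.attention. The configuration structure and the kernel_schedule argument follow the description above; the tensor layout and the rest of the call are illustrative assumptions.

    import torch
    import natten

    # Illustrative inputs on a Hopper (SM90) device; the layout is an assumption.
    q = torch.randn(2, 8192, 8, 128, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_cutlass_hopper_fmha(q)
    if configs:
        (q_tile_size, kv_tile_size), kernel_schedule = configs[0]
        out = natten.attention(
            q, k, v,
            q_tile_size=q_tile_size,
            kv_tile_size=kv_tile_size,
            kernel_schedule=kernel_schedule,
        )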

natten.get_configs_for_cutlass_hopper_fna

get_configs_for_cutlass_hopper_fna(input_tensor)

Returns Hopper FNA configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a Hopper datacenter GPU (SM90; compute capability 9.0), and if so, returns forward-pass configurations compatible with the tensor dtype and head dim.

Each configuration for this operation is a tuple of a tile-shape pair and a kernel schedule: ((q_tile_shape, kv_tile_shape), kernel_schedule). These are arguments to natten.na1d, natten.na2d, and natten.na3d. kernel_schedule is specific to Hopper FNA/FMHA only.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, and NA operations require matching token layouts, Q, K, and V are guaranteed to have the same shape, therefore it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[Tuple[tuple, tuple], KernelSchedule]]: List of tuples of one tuple of two shape tuples, corresponding to query and KV tile shapes, and a kernel schedule enum type.
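
A similar sketch for a 3D token layout; the tile shapes and kernel_schedule follow the configuration structure above, while the tensor layout and kernel_size are illustrative assumptions.

    import torch
    import natten

    # Illustrative 3D token layout on a Hopper (SM90) device; the layout and
    # kernel_size value are assumptions.
    q = torch.randn(1, 16, 32, 32, 8, 128, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_cutlass_hopper_fna(q)
    if configs:
        (q_tile_shape, kv_tile_shape), kernel_schedule = configs[0]
        out = natten.na3d(
            q, k, v,
            kernel_size=(5, 7, 7),
            q_tile_shape=q_tile_shape,
            kv_tile_shape=kv_tile_shape,
            kernel_schedule=kernel_schedule,
        )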

Blackwell FNA / FMHA

Supported modes

  • Inference (forward pass)
  • Training (backward pass): not yet supported

Figure: Performance levels of Blackwell FNA as of version 0.20.0 (also reported in Generalized Neighborhood Attention (2025)).

Based on CUTLASS's Blackwell FMHA kernel (3.X API), this backend offers excellent forward-pass performance, comparable to that of cuDNN's Blackwell attention.

This backend does not fuse multi-dimensional tiling into the kernel, and instead uses Token Permutation.

Finding configurations

You can use profiler dry runs to find configurations for any of our backends, and to find backends compatible with your device and use case. You can also use the following functions in your code.

Finding configurations for Blackwell FMHA/FNA

natten.get_configs_for_cutlass_blackwell_fmha

get_configs_for_cutlass_blackwell_fmha(input_tensor)

Returns Blackwell FMHA configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a Blackwell datacenter GPU (SM100; compute capability 10.0), and if so, returns forward-pass configurations compatible with the tensor dtype and head dim.

Each configuration for this operation is a tuple of two integers: (q_tile_size, kv_tile_size). These are arguments to natten.attention.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[int, int]]: List of tuples of two integers corresponding to query and KV tile sizes.
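
For example, a sketch of picking a Blackwell FMHA configuration; q_tile_size and kv_tile_size are the documented arguments, and the tensor layout and the rest of the call are illustrative assumptions.

    import torch
    import natten

    # Illustrative inputs on a Blackwell (SM100) device; the layout is an assumption.
    q = torch.randn(2, 16384, 8, 128, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_cutlass_blackwell_fmha(q)
    if configs:
        q_tile_size, kv_tile_size = configs[0]
        out = natten.attention(q, k, v, q_tile_size=q_tile_size, kv_tile_size=kv_tile_size)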

natten.get_configs_for_cutlass_blackwell_fna

get_configs_for_cutlass_blackwell_fna(input_tensor)

Returns Blackwell FNA configurations compatible with input tensor, if any.

First checks whether the input is a CUDA tensor on a Blackwell datacenter GPU (SM100; compute capability 10.0), and if so, returns forward-pass configurations compatible with the tensor dtype and head dim, and the rank of the token layout (1D/2D/3D).

Each configuration for this operation is a tuple of two integer tuples: (q_tile_shape, kv_tile_shape). These are arguments to natten.na1d, natten.na2d, and natten.na3d.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, and NA operations require matching token layouts, Q, K, and V are guaranteed to have the same shape, therefore it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[tuple, tuple]]: List of tuples of two integer tuples corresponding to query and KV tile shapes.
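
And a sketch for neighborhood attention on a 2D token layout; the tile shapes are the documented arguments, while the tensor layout and kernel_size are illustrative assumptions.

    import torch
    import natten

    # Illustrative 2D token layout on a Blackwell (SM100) device; the layout and
    # kernel_size value are assumptions.
    q = torch.randn(1, 128, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_cutlass_blackwell_fna(q)
    if configs:
        q_tile_shape, kv_tile_shape = configs[0]
        out = natten.na2d(
            q, k, v,
            kernel_size=(13, 13),
            q_tile_shape=q_tile_shape,
            kv_tile_shape=kv_tile_shape,
        )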

Flex FNA / FMHA

Warning

This feature is experimental.

Supported modes

  • Inference (forward pass)
  • Training (backward pass)

Info

This feature requires PyTorch >= 2.7.

This backend is PyTorch-native, and supports some non-NVIDIA devices as well (CPU and ROCm). It is based on Flex Attention.

Since this backend is implemented in PyTorch, fusion of multi-dimensional tiling is not possible. However, this backend does support Token Permutation, similar to the Hopper and Blackwell backends.

By default, if tile shapes are not specified, token permutation is disabled and our legacy Flex mask is used. If tile shapes are specified, token permutation is used.
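
As a sketch of the difference, the two calls below only differ in whether tile shapes are specified. The backend selector, tensor layout, kernel_size, and the example tile shapes are illustrative assumptions; valid tile shapes can be queried with the functions below.

    import torch
    import natten

    # Illustrative 2D token layout; the layout, kernel_size, backend selector, and
    # tile shapes are assumptions.
    q = torch.randn(1, 64, 64, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # No tile shapes: token permutation is disabled and the legacy Flex mask is used.
    out = natten.na2d(q, k, v, kernel_size=(7, 7), backend="flex-fna")

    # Explicit tile shapes: token permutation is used. Query valid shapes with
    # natten.get_configs_for_flex_fna instead of hard-coding them.
    out = natten.na2d(
        q, k, v,
        kernel_size=(7, 7),
        backend="flex-fna",
        q_tile_shape=(8, 8),
        kv_tile_shape=(8, 8),
    )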

Finding configurations

You can use profiler dry runs to find configurations for any of our backends, and to find backends compatible with your device and use case. You can also use the following functions in your code.

Finding configurations for Flex FMHA/FNA

natten.get_configs_for_flex_fmha

get_configs_for_flex_fmha(input_tensor)

Returns Flex FMHA configurations compatible with input tensor, if any.

Each configuration for this operation is a tuple of two integers: (q_tile_size, kv_tile_size). These are arguments to natten.attention. If these arguments are not specified while the backend is Flex, they default to q_tile_size = 64 and kv_tile_size = 64.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[int, int]]: List of tuples of two integers corresponding to query and KV tile sizes.
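
For example, a sketch of querying Flex FMHA tile sizes; the tensor layout, the rest of the call, and the backend selector are illustrative assumptions.

    import torch
    import natten

    # Illustrative inputs; the layout and the backend selector are assumptions.
    q = torch.randn(2, 4096, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_flex_fmha(q)
    if configs:
        q_tile_size, kv_tile_size = configs[0]  # unspecified tile sizes default to 64 and 64
        out = natten.attention(
            q, k, v,
            backend="flex-fmha",
            q_tile_size=q_tile_size,
            kv_tile_size=kv_tile_size,
        )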

natten.get_configs_for_flex_fna

get_configs_for_flex_fna(input_tensor)

Returns Flex FNA configurations compatible with input tensor, if any.

Each configuration for this operation is a tuple of two integer tuples: (q_tile_shape, kv_tile_shape). These are arguments to natten.na1d, natten.na2d, and natten.na3d. If these arguments are not specified while the backend is Flex, NATTEN defaults to single-dimensional tiling and does not use our Token Permutation approach. By explicitly specifying tile shapes, you automatically use our Token Permutation approach, which saves the most compute.

Parameters:

  • input_tensor (Tensor, required): Input torch tensor. Either Q, K, or V. Since NATTEN does not support GQA/MQA, or V with a different head dim, and NA operations require matching token layouts, Q, K, and V are guaranteed to have the same shape, therefore it doesn't make a difference which one is passed in.

Returns:

  • List[Tuple[tuple, tuple]]: List of tuples of two integer tuples corresponding to query and KV tile shapes.
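
And a sketch for neighborhood attention on a 1D token layout; tile shapes come from the documented function, while the tensor layout, kernel_size, and the backend selector are illustrative assumptions.

    import torch
    import natten

    # Illustrative 1D token layout; the layout, kernel_size, and backend selector
    # are assumptions.
    q = torch.randn(2, 8192, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    configs = natten.get_configs_for_flex_fna(q)
    if configs:
        q_tile_shape, kv_tile_shape = configs[0]
        # Specifying tile shapes enables the Token Permutation path.
        out = natten.na1d(
            q, k, v,
            kernel_size=255,
            backend="flex-fna",
            q_tile_shape=q_tile_shape,
            kv_tile_shape=kv_tile_shape,
        )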