Operations

This page lists our PyTorch Autograd-compatible operations. These operations come with performance knobs (configurations), some of which are specific to certain backends.

Changing those knobs is completely optional, and NATTEN will remain functionally correct in all cases. However, to squeeze out the maximum achievable performance, we highly recommend reading about backends, or simply using our profiler toolkit and its dry run feature to explore the available backends and their valid configurations for your specific use case and GPU architecture. You can also use the profiler's optimize feature to search for the best configuration.

Neighborhood Attention

natten.na1d

na1d(
    query,
    key,
    value,
    kernel_size,
    stride=1,
    dilation=1,
    is_causal=False,
    scale=None,
    additional_keys=None,
    additional_values=None,
    attention_kwargs=None,
    backend=None,
    q_tile_shape=None,
    kv_tile_shape=None,
    backward_q_tile_shape=None,
    backward_kv_tile_shape=None,
    backward_kv_splits=None,
    backward_use_pt_reduction=False,
    run_persistent_kernel=True,
    kernel_schedule=None,
    torch_compile=False,
    try_fuse_additional_kv=False,
)

Computes 1-D neighborhood attention.

Parameters:

query (Tensor), required

4-D query tensor, with the heads-last layout ([batch, seqlen, heads, head_dim]).

key (Tensor), required

4-D key tensor, with the heads-last layout ([batch, seqlen, heads, head_dim]).

value (Tensor), required

4-D value tensor, with the heads-last layout ([batch, seqlen, heads, head_dim]).

kernel_size (Tuple[int] | int), required

Neighborhood window (kernel) size.

Note: kernel_size must be smaller than or equal to seqlen.

stride (Tuple[int] | int), default: 1

Sliding window step size. Defaults to 1 (standard sliding window).

Note: stride must be smaller than or equal to kernel_size. When stride == kernel_size, there is no overlap between sliding windows, which is equivalent to blocked attention (a.k.a. window self attention).

dilation (Tuple[int] | int), default: 1

Dilation step size. Defaults to 1 (standard sliding window).

Note: The product of dilation and kernel_size must be smaller than or equal to seqlen.

is_causal (Tuple[bool] | bool), default: False

Toggle causal masking. Defaults to False (bi-directional).

scale (float), default: None

Attention scale. Defaults to head_dim ** -0.5.

additional_keys (Optional[Tensor]), default: None

None, or a 4-D key tensor with the heads-last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to key tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood as well as some fixed additional set of tokens.

additional_values (Optional[Tensor]), default: None

None, or a 4-D value tensor with the heads-last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to value tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood as well as some fixed additional set of tokens.

Note: additional_keys and additional_values must either both be Tensors or both be None, and must match in shape.

Other Parameters:

backend (str)

Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fna", "hopper-fna", "blackwell-fna", "flex-fna". Refer to backends for more information.

q_tile_shape (Tuple[int])

1-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

kv_tile_shape (Tuple[int])

1-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_q_tile_shape (Tuple[int])

1-D tile shape for the query token layout in the backward pass kernel. Only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_tile_shape (Tuple[int])

1-D tile shape for the key/value token layout in the backward pass kernel. Only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_splits (Tuple[int])

Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, with its size matching the number of dimensions (rank) of the token layout. Only respected by the "cutlass-fna" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool)

Whether to use PyTorch eager mode instead of the CUTLASS kernel for computing the dO * O product required by the backward pass. Only applies to the "cutlass-fna" backend.

run_persistent_kernel (bool)

Whether to use persistent tile scheduling in the forward pass kernel. Only applies to the "blackwell-fna" backend.

kernel_schedule (Optional[str])

Kernel schedule (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool)

Whether to JIT-compile the attention kernel. Applies only to the "flex-fna" backend. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

attention_kwargs (Optional[Dict])

Arguments to the attention operator, if it is used to implement neighborhood cross-attention, or self attention as a fast path for neighborhood attention.

If additional_{keys,values} are specified, NATTEN usually performs a separate cross-attention using our attention operator, and merges the results.

If for a given use case the neighborhood attention problem is equivalent to self attention (not causal, kernel_size == seqlen), NATTEN will also attempt to directly use attention.

You can override arguments to attention by passing a dictionary here.

Example

out = na1d(
    q, k, v, kernel_size=kernel_size,
    ...,
    attention_kwargs={
        "backend": "blackwell-fmha",
        "run_persistent_kernel": True,
    }
)

try_fuse_additional_kv (bool)

Some backends can fuse cross-attention over the additional KV into the FNA kernel, instead of running a separate attention and then merging. For now this is only possible in backends that use Token Permutation, which means that when dilation is used, the fusion may incur additional memory operations and memory usage. Currently, only the "blackwell-fna" backend supports this. We recommend using the profiler to check whether this option is suitable for your use case before enabling it.

Returns:

output (Tensor)

4-D output tensor, with the heads-last layout ([batch, seqlen, heads, head_dim]).
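
Example: a minimal sketch of calling na1d on randomly initialized, heads-last tensors, including the neighborhood cross-attention path through additional_keys and additional_values. The shapes, dtype, and CUDA device below are illustrative assumptions rather than requirements stated on this page.

import torch
from natten import na1d

# Illustrative shapes: batch=2, seqlen=64, heads=4, head_dim=32 (heads-last layout).
q = torch.randn(2, 64, 4, 32, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Standard sliding-window neighborhood attention with a window of 7 tokens.
out = na1d(q, k, v, kernel_size=7)  # [2, 64, 4, 32]

# Neighborhood cross-attention: every query also attends to 16 additional tokens.
extra_k = torch.randn(2, 16, 4, 32, device="cuda", dtype=torch.float16)
extra_v = torch.randn_like(extra_k)
out = na1d(
    q, k, v, kernel_size=7,
    additional_keys=extra_k,
    additional_values=extra_v,
)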

natten.na2d

na2d(
    query,
    key,
    value,
    kernel_size,
    stride=1,
    dilation=1,
    is_causal=False,
    scale=None,
    additional_keys=None,
    additional_values=None,
    attention_kwargs=None,
    backend=None,
    q_tile_shape=None,
    kv_tile_shape=None,
    backward_q_tile_shape=None,
    backward_kv_tile_shape=None,
    backward_kv_splits=None,
    backward_use_pt_reduction=False,
    run_persistent_kernel=True,
    kernel_schedule=None,
    torch_compile=False,
    try_fuse_additional_kv=False,
)

Computes 2-D neighborhood attention.

Parameters:

query (Tensor), required

5-D query tensor, with the heads-last layout ([batch, X, Y, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y).

key (Tensor), required

5-D key tensor, with the heads-last layout ([batch, X, Y, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y).

value (Tensor), required

5-D value tensor, with the heads-last layout ([batch, X, Y, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y).

kernel_size (Tuple[int, int] | int), required

Neighborhood window (kernel) size/shape. If an integer, it is repeated for both dimensions; for example, kernel_size=3 is reinterpreted as kernel_size=(3, 3).

Note: kernel_size must be smaller than or equal to the token layout shape ((X, Y)) along every dimension.

stride (Tuple[int, int] | int), default: 1

Sliding window step size/shape. Defaults to 1 (standard sliding window). If an integer, it is repeated for both dimensions; for example, stride=2 is reinterpreted as stride=(2, 2).

Note: stride must be smaller than or equal to kernel_size along every dimension. When stride == kernel_size, there is no overlap between sliding windows, which is equivalent to blocked attention (a.k.a. window self attention).

dilation (Tuple[int, int] | int), default: 1

Dilation step size/shape. Defaults to 1 (standard sliding window). If an integer, it is repeated for both dimensions; for example, dilation=4 is reinterpreted as dilation=(4, 4).

Note: The product of dilation and kernel_size must be smaller than or equal to the token layout shape ((X, Y)) along every dimension.

is_causal (Tuple[bool, bool] | bool), default: False

Toggle causal masking. Defaults to False (bi-directional). If a boolean, it is repeated for both dimensions; for example, is_causal=True is reinterpreted as is_causal=(True, True).

scale (float), default: None

Attention scale. Defaults to head_dim ** -0.5.

additional_keys (Optional[Tensor]), default: None

None, or a 4-D key tensor with the heads-last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to key tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood as well as some fixed additional set of tokens.

additional_values (Optional[Tensor]), default: None

None, or a 4-D value tensor with the heads-last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to value tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood as well as some fixed additional set of tokens.

Note: additional_keys and additional_values must either both be Tensors or both be None, and must match in shape.

Other Parameters:

backend (str)

Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fna", "hopper-fna", "blackwell-fna", "flex-fna". Refer to backends for more information.

q_tile_shape (Tuple[int, int])

2-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

kv_tile_shape (Tuple[int, int])

2-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_q_tile_shape (Tuple[int, int])

2-D tile shape for the query token layout in the backward pass kernel. Only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_tile_shape (Tuple[int, int])

2-D tile shape for the key/value token layout in the backward pass kernel. Only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_splits (Tuple[int, int])

Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, with its size matching the number of dimensions (rank) of the token layout. Only respected by the "cutlass-fna" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool)

Whether to use PyTorch eager mode instead of the CUTLASS kernel for computing the dO * O product required by the backward pass. Only applies to the "cutlass-fna" backend.

run_persistent_kernel (bool)

Whether to use persistent tile scheduling in the forward pass kernel. Only applies to the "blackwell-fna" backend.

kernel_schedule (Optional[str])

Kernel schedule (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool)

Whether to JIT-compile the attention kernel. Applies only to the "flex-fna" backend. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

attention_kwargs (Optional[Dict])

Arguments to the attention operator, if it is used to implement neighborhood cross-attention, or self attention as a fast path for neighborhood attention.

If additional_{keys,values} are specified, NATTEN usually performs a separate cross-attention using our attention operator, and merges the results.

If for a given use case the neighborhood attention problem is equivalent to self attention (not causal along any dims, kernel_size == (X, Y)), NATTEN will also attempt to directly use attention.

You can override arguments to attention by passing a dictionary here.

Example

out = na2d(
    q, k, v, kernel_size=kernel_size,
    ...,
    attention_kwargs={
        "backend": "blackwell-fmha",
        "run_persistent_kernel": True,
    }
)

try_fuse_additional_kv (bool)

Some backends can fuse cross-attention over the additional KV into the FNA kernel, instead of running a separate attention and then merging. For now this is only possible in backends that use Token Permutation, which means that when dilation is used, the fusion may incur additional memory operations and memory usage. Currently, only the "blackwell-fna" backend supports this. We recommend using the profiler to check whether this option is suitable for your use case before enabling it.

Returns:

output (Tensor)

5-D output tensor, with the heads-last layout ([batch, X, Y, heads, head_dim]).
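
Example: a minimal sketch of calling na2d on a randomly initialized 2-D token layout, also showing the blocked attention special case mentioned under stride. The feature map shape, head count, head_dim, dtype, and CUDA device are illustrative assumptions.

import torch
from natten import na2d

# Illustrative 16x16 feature map: batch=1, heads=4, head_dim=32 (heads-last layout).
q = torch.randn(1, 16, 16, 4, 32, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# 7x7 neighborhood window; an integer kernel_size is broadcast to (7, 7).
out = na2d(q, k, v, kernel_size=7)  # [1, 16, 16, 4, 32]

# stride == kernel_size: non-overlapping 4x4 windows (blocked / window self attention).
out_blocked = na2d(q, k, v, kernel_size=(4, 4), stride=(4, 4))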

natten.na3d

na3d(
    query,
    key,
    value,
    kernel_size,
    stride=1,
    dilation=1,
    is_causal=False,
    scale=None,
    additional_keys=None,
    additional_values=None,
    attention_kwargs=None,
    backend=None,
    q_tile_shape=None,
    kv_tile_shape=None,
    backward_q_tile_shape=None,
    backward_kv_tile_shape=None,
    backward_kv_splits=None,
    backward_use_pt_reduction=False,
    run_persistent_kernel=True,
    kernel_schedule=None,
    torch_compile=False,
    try_fuse_additional_kv=False,
)

Computes 3-D neighborhood attention.

Parameters:

query (Tensor), required

6-D query tensor, with the heads-last layout ([batch, X, Y, Z, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y, Z).

key (Tensor), required

6-D key tensor, with the heads-last layout ([batch, X, Y, Z, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y, Z).

value (Tensor), required

6-D value tensor, with the heads-last layout ([batch, X, Y, Z, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y, Z).

kernel_size (Tuple[int, int, int] | int), required

Neighborhood window (kernel) size/shape. If an integer, it is repeated for all three dimensions; for example, kernel_size=3 is reinterpreted as kernel_size=(3, 3, 3).

Note: kernel_size must be smaller than or equal to the token layout shape ((X, Y, Z)) along every dimension.

stride (Tuple[int, int, int] | int), default: 1

Sliding window step size/shape. Defaults to 1 (standard sliding window). If an integer, it is repeated for all three dimensions; for example, stride=2 is reinterpreted as stride=(2, 2, 2).

Note: stride must be smaller than or equal to kernel_size along every dimension. When stride == kernel_size, there is no overlap between sliding windows, which is equivalent to blocked attention (a.k.a. window self attention).

dilation (Tuple[int, int, int] | int), default: 1

Dilation step size/shape. Defaults to 1 (standard sliding window). If an integer, it is repeated for all three dimensions; for example, dilation=4 is reinterpreted as dilation=(4, 4, 4).

Note: The product of dilation and kernel_size must be smaller than or equal to the token layout shape ((X, Y, Z)) along every dimension.

is_causal (Tuple[bool, bool, bool] | bool), default: False

Toggle causal masking. Defaults to False (bi-directional). If a boolean, it is repeated for all three dimensions; for example, is_causal=True is reinterpreted as is_causal=(True, True, True).

scale (float), default: None

Attention scale. Defaults to head_dim ** -0.5.

additional_keys (Optional[Tensor]), default: None

None, or a 4-D key tensor with the heads-last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to key tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood as well as some fixed additional set of tokens.

additional_values (Optional[Tensor]), default: None

None, or a 4-D value tensor with the heads-last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to value tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood as well as some fixed additional set of tokens.

Note: additional_keys and additional_values must either both be Tensors or both be None, and must match in shape.

Other Parameters:

backend (str)

Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fna", "hopper-fna", "blackwell-fna", "flex-fna". Refer to backends for more information.

q_tile_shape (Tuple[int, int, int])

3-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

kv_tile_shape (Tuple[int, int, int])

3-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_q_tile_shape (Tuple[int, int, int])

3-D tile shape for the query token layout in the backward pass kernel. Only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_tile_shape (Tuple[int, int, int])

3-D tile shape for the key/value token layout in the backward pass kernel. Only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_splits (Tuple[int, int, int])

Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, with its size matching the number of dimensions (rank) of the token layout. Only respected by the "cutlass-fna" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool)

Whether to use PyTorch eager mode instead of the CUTLASS kernel for computing the dO * O product required by the backward pass. Only applies to the "cutlass-fna" backend.

run_persistent_kernel (bool)

Whether to use persistent tile scheduling in the forward pass kernel. Only applies to the "blackwell-fna" backend.

kernel_schedule (Optional[str])

Kernel schedule (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool)

Whether to JIT-compile the attention kernel. Applies only to the "flex-fna" backend. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

attention_kwargs (Optional[Dict])

Arguments to the attention operator, if it is used to implement neighborhood cross-attention, or self attention as a fast path for neighborhood attention.

If additional_{keys,values} are specified, NATTEN usually performs a separate cross-attention using our attention operator, and merges the results.

If for a given use case the neighborhood attention problem is equivalent to self attention (not causal along any dims, kernel_size == (X, Y, Z)), NATTEN will also attempt to directly use attention.

You can override arguments to attention by passing a dictionary here.

Example

out = na3d(
    q, k, v, kernel_size=kernel_size,
    ...,
    attention_kwargs={
        "backend": "blackwell-fmha",
        "run_persistent_kernel": True,
    }
)

try_fuse_additional_kv (bool)

Some backends can fuse cross-attention over the additional KV into the FNA kernel, instead of running a separate attention and then merging. For now this is only possible in backends that use Token Permutation, which means that when dilation is used, the fusion may incur additional memory operations and memory usage. Currently, only the "blackwell-fna" backend supports this. We recommend using the profiler to check whether this option is suitable for your use case before enabling it.

Returns:

output (Tensor)

6-D output tensor, with the heads-last layout ([batch, X, Y, Z, heads, head_dim]).
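
Example: a minimal sketch of calling na3d on a randomly initialized 3-D token layout, with per-dimension kernel_size, dilation, and causal masking. The layout shape, head count, head_dim, dtype, and CUDA device are illustrative assumptions.

import torch
from natten import na3d

# Illustrative (8, 16, 16) token layout (e.g. a video latent): batch=1, heads=4, head_dim=32.
q = torch.randn(1, 8, 16, 16, 4, 32, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Per-dimension window shape, causal only along the first (temporal) dimension.
out = na3d(
    q, k, v,
    kernel_size=(4, 7, 7),
    dilation=(1, 2, 2),  # dilation * kernel_size = (4, 14, 14) fits within (8, 16, 16)
    is_causal=(True, False, False),
)  # [1, 8, 16, 16, 4, 32]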

Standard Attention

natten.attention

attention(
    query,
    key,
    value,
    scale=None,
    backend=None,
    q_tile_size=None,
    kv_tile_size=None,
    backward_q_tile_size=None,
    backward_kv_tile_size=None,
    backward_kv_splits=None,
    backward_use_pt_reduction=False,
    run_persistent_kernel=True,
    kernel_schedule=None,
    torch_compile=False,
    return_lse=False,
)

Runs standard dot product attention.

This operation is used to implement neighborhood cross-attention, in which every token is allowed to interact with some additional context (the additional_keys and additional_values tensors in na1d, na2d, and na3d). It is also used as a fast path for cases where neighborhood attention is equivalent to self attention (not causal along any dims, and kernel_size is equal to the number of input tokens).

This operation does not call into PyTorch's SDPA; it only runs one of the NATTEN backends (cutlass-fmha, hopper-fmha, blackwell-fmha, flex-fmha). Reasons for this include the ability to control performance-related arguments, return logsumexp, and more. For more information refer to backends.

Parameters:

query (Tensor), required

4-D query tensor, with the heads-last layout ([batch, seqlen, heads, head_dim]).

key (Tensor), required

4-D key tensor, with the heads-last layout ([batch, seqlen_kv, heads, head_dim]).

value (Tensor), required

4-D value tensor, with the heads-last layout ([batch, seqlen_kv, heads, head_dim]).

scale (float), default: None

Attention scale. Defaults to head_dim ** -0.5.

Other Parameters:

backend (str)

Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fmha", "hopper-fmha", "blackwell-fmha", "flex-fmha". Refer to backends for more information.

q_tile_size (int)

Tile size along the query sequence length in the forward pass kernel. You can use the profiler to find valid choices for your use case.

kv_tile_size (int)

Tile size along the key/value sequence length in the forward pass kernel. You can use the profiler to find valid choices for your use case.

backward_q_tile_size (int)

Tile size along the query sequence length in the backward pass kernel. Only respected by the "cutlass-fmha" backend. You can use the profiler to find valid choices for your use case.

backward_kv_tile_size (int)

Tile size along the key/value sequence length in the backward pass kernel. Only respected by the "cutlass-fmha" backend. You can use the profiler to find valid choices for your use case.

backward_kv_splits (int)

Number of key/value tiles allowed to work in parallel in the backward pass kernel. Only respected by the "cutlass-fmha" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool)

Whether to use PyTorch eager mode instead of the CUTLASS kernel for computing the dO * O product required by the backward pass. Only applies to the "cutlass-fmha" backend.

run_persistent_kernel (bool)

Whether to use persistent tile scheduling in the forward pass kernel. Only applies to the "blackwell-fmha" backend.

kernel_schedule (Optional[str])

Kernel schedule (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool)

Whether to JIT-compile the attention kernel. Applies only to the "flex-fmha" backend. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

return_lse (bool)

Whether to return the logsumexp tensor. logsumexp can be used in the backward pass and for attention merging.

Returns:

output (Tensor)

4-D output tensor, with the heads-last layout ([batch, seqlen, heads, head_dim]).

logsumexp (Tensor)

Only returned when return_lse=True. 3-D logsumexp tensor, with the heads-last layout ([batch, seqlen, heads]).
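
Example: a minimal sketch of running the standalone attention operator, with and without returning logsumexp. Shapes, dtype, and CUDA device are illustrative assumptions; note that the query and key/value sequence lengths may differ.

import torch
from natten import attention

# Illustrative shapes: 128 query tokens attending to 256 key/value tokens.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

out = attention(q, k, v)                        # [2, 128, 8, 64]
out, lse = attention(q, k, v, return_lse=True)  # lse: [2, 128, 8]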

natten.merge_attentions

merge_attentions(outputs, lse_tensors, torch_compile=True)

Takes multiple attention outputs originating from the same query tensor, along with their corresponding logsumexps, and merges them as if their contexts (key/value pairs) had been concatenated.

This operation is used to implement neighborhood cross-attention, and can also be used in distributed setups, such as context parallelism.

It also attempts to use torch.compile to fuse the elementwise operations. This can be disabled by passing torch_compile=False.

Parameters:

outputs (List[Tensor]), required

List of 4-D attention output tensors, with the heads-last layout ([batch, seqlen, heads, head_dim]).

lse_tensors (List[Tensor]), required

List of 3-D logsumexp tensors, with the heads-last layout ([batch, seqlen, heads]).

torch_compile (bool), default: True

Attempt to use torch.compile to fuse the underlying elementwise operations.

Returns:

output (Tensor)

Merged attention output.
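
Example: a minimal sketch of merging two attention outputs computed over disjoint key/value contexts, which the description above says is equivalent to a single attention over the concatenated context. Shapes, dtype, and CUDA device are illustrative assumptions.

import torch
from natten import attention, merge_attentions

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# Two separate key/value contexts, attended to independently.
k1 = torch.randn(2, 64, 8, 64, device="cuda", dtype=torch.float16)
v1 = torch.randn_like(k1)
k2 = torch.randn_like(k1)
v2 = torch.randn_like(k1)

out1, lse1 = attention(q, k1, v1, return_lse=True)
out2, lse2 = attention(q, k2, v2, return_lse=True)

# Merge as if attention had been computed over the concatenation of both contexts.
merged = merge_attentions([out1, out2], [lse1, lse2])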