Operations
This page lists NATTEN's PyTorch Autograd-compatible operations. These operations come with performance knobs (configurations), some of which are specific to certain backends.
Changing those knobs is completely optional, and NATTEN will remain functionally correct in all cases. However, to squeeze out the maximum achievable performance, we highly recommend reading about backends, or using our profiler toolkit and its dry-run feature to browse the available backends and their valid configurations for your specific use case and GPU architecture. You can also use the profiler's optimize feature to search for the best configuration.
Neighborhood Attention
natten.na1d
```python
na1d(
query,
key,
value,
kernel_size,
stride=1,
dilation=1,
is_causal=False,
scale=None,
additional_keys=None,
additional_values=None,
attention_kwargs=None,
backend=None,
q_tile_shape=None,
kv_tile_shape=None,
backward_q_tile_shape=None,
backward_kv_tile_shape=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
try_fuse_additional_kv=False,
)
```
Computes 1-D neighborhood attention.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`query` | Tensor | 4-D query tensor, with the heads-last layout (`[batch, seqlen, heads, head_dim]`). | required |
`key` | Tensor | 4-D key tensor, with the heads-last layout (`[batch, seqlen, heads, head_dim]`). | required |
`value` | Tensor | 4-D value tensor, with the heads-last layout (`[batch, seqlen, heads, head_dim]`). | required |
`kernel_size` | Tuple[int] \| int | Neighborhood window (kernel) size. Must be smaller than or equal to the number of input tokens. | required |
`stride` | Tuple[int] \| int | Sliding window step size. Must be smaller than or equal to `kernel_size`; when equal to `kernel_size`, neighborhood attention becomes blocked (window self) attention. Defaults to `1` (standard sliding window). | `1` |
`dilation` | Tuple[int] \| int | Dilation step size. The product of `dilation` and `kernel_size` must be smaller than or equal to the number of input tokens. Defaults to `1` (standard neighborhood attention). | `1` |
`is_causal` | Tuple[bool] \| bool | Toggle causal masking. Defaults to `False` (bi-directional). | `False` |
`scale` | float | Attention scale. Defaults to `head_dim ** -0.5`. | `None` |
`additional_keys` | Optional[Tensor] | Key tensor for additional (cross-attention) context, which every query token attends to in addition to its neighborhood. Must be specified together with `additional_values`. | `None` |
`additional_values` | Optional[Tensor] | Value tensor for the additional (cross-attention) context, corresponding to `additional_keys`. Must be specified together with `additional_keys`. | `None` |
Other Parameters:
Name | Type | Description |
---|---|---|
`backend` | str | Backend implementation to run with. Refer to backends for available choices and their requirements. |
`q_tile_shape` | Tuple[int] | 1-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination. |
`kv_tile_shape` | Tuple[int] | 1-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination. |
`backward_q_tile_shape` | Tuple[int] | 1-D tile shape for the query token layout in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_tile_shape` | Tuple[int] | 1-D tile shape for the key/value token layout in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_splits` | Tuple[int] | Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, and the size of the tuple corresponds to the number of dimensions / rank of the token layout. Only respected by certain backends (refer to backends). |
`backward_use_pt_reduction` | bool | Whether to use PyTorch eager mode for the reduction step in the backward pass instead of the backend's native kernel. Defaults to `False`. |
`run_persistent_kernel` | bool | Whether to use persistent tile scheduling in the forward pass kernel. Only applies to backends that support it (refer to backends). |
`kernel_schedule` | Optional[str] | Kernel schedule (Hopper architecture only). Refer to backends for available schedules. Defaults to `None`. |
`torch_compile` | bool | Applies only to the Flex Attention backend: whether to JIT-compile its kernel with `torch.compile`. |
`attention_kwargs` | Optional[Dict] | Arguments passed to the attention operator when it is used to implement neighborhood cross-attention (additional KV), or self attention as a fast path for neighborhood attention. If, for a given use case, the neighborhood attention problem is equivalent to self attention (not causal, and `kernel_size` equal to the number of input tokens), NATTEN dispatches directly to attention; you can override its arguments by passing a dictionary here. |
`try_fuse_additional_kv` | bool | Some backends may support fusing cross-attention (additional KV) into the FNA kernel, instead of having to do a separate attention and then merge. For now, this can only be supported by backends using Token Permutation, which means that when there is dilation, the fusion may incur additional memory operations and memory usage. Refer to backends for support. |
Returns:
Name | Type | Description |
---|---|---|
`output` | Tensor | 4-D output tensor, with the heads-last layout (`[batch, seqlen, heads, head_dim]`). |
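Below is a minimal usage sketch. Shapes, dtype, and the CUDA device are illustrative; the heads-last layout and argument semantics are as described above.

```python
import torch
from natten import na1d

# Heads-last layout: [batch, seqlen, heads, head_dim]
batch, seqlen, heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Every query attends to a 127-token neighborhood; causal masking enabled.
out = na1d(q, k, v, kernel_size=127, is_causal=True)
print(out.shape)  # torch.Size([2, 1024, 8, 64])
```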
natten.na2d
```python
na2d(
query,
key,
value,
kernel_size,
stride=1,
dilation=1,
is_causal=False,
scale=None,
additional_keys=None,
additional_values=None,
attention_kwargs=None,
backend=None,
q_tile_shape=None,
kv_tile_shape=None,
backward_q_tile_shape=None,
backward_kv_tile_shape=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
try_fuse_additional_kv=False,
)
```
Computes 2-D neighborhood attention.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`query` | Tensor | 5-D query tensor, with the heads-last layout (`[batch, X, Y, heads, head_dim]`). | required |
`key` | Tensor | 5-D key tensor, with the heads-last layout (`[batch, X, Y, heads, head_dim]`). | required |
`value` | Tensor | 5-D value tensor, with the heads-last layout (`[batch, X, Y, heads, head_dim]`). | required |
`kernel_size` | Tuple[int, int] \| int | Neighborhood window (kernel) size/shape. If an integer, it is repeated for both dimensions (e.g. `kernel_size=3` is equivalent to `kernel_size=(3, 3)`). Must be smaller than or equal to the token layout shape (`(X, Y)`) along every dimension. | required |
`stride` | Tuple[int, int] \| int | Sliding window step size/shape. If an integer, it is repeated for both dimensions. Must be smaller than or equal to `kernel_size` along every dimension; when equal to `kernel_size`, neighborhood attention becomes blocked (window self) attention. Defaults to `1` (standard sliding window). | `1` |
`dilation` | Tuple[int, int] \| int | Dilation step size/shape. If an integer, it is repeated for both dimensions. The product of `dilation` and `kernel_size` must be smaller than or equal to the token layout shape along every dimension. Defaults to `1` (standard neighborhood attention). | `1` |
`is_causal` | Tuple[bool, bool] \| bool | Toggle causal masking along each dimension. Defaults to `False` (bi-directional along all dims). | `False` |
`scale` | float | Attention scale. Defaults to `head_dim ** -0.5`. | `None` |
`additional_keys` | Optional[Tensor] | Key tensor for additional (cross-attention) context, which every query token attends to in addition to its neighborhood. Must be specified together with `additional_values`. | `None` |
`additional_values` | Optional[Tensor] | Value tensor for the additional (cross-attention) context, corresponding to `additional_keys`. Must be specified together with `additional_keys`. | `None` |
Other Parameters:
Name | Type | Description |
---|---|---|
`backend` | str | Backend implementation to run with. Refer to backends for available choices and their requirements. |
`q_tile_shape` | Tuple[int, int] | 2-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination. |
`kv_tile_shape` | Tuple[int, int] | 2-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination. |
`backward_q_tile_shape` | Tuple[int, int] | 2-D tile shape for the query token layout in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_tile_shape` | Tuple[int, int] | 2-D tile shape for the key/value token layout in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_splits` | Tuple[int, int] | Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, and the size of the tuple corresponds to the number of dimensions / rank of the token layout. Only respected by certain backends (refer to backends). |
`backward_use_pt_reduction` | bool | Whether to use PyTorch eager mode for the reduction step in the backward pass instead of the backend's native kernel. Defaults to `False`. |
`run_persistent_kernel` | bool | Whether to use persistent tile scheduling in the forward pass kernel. Only applies to backends that support it (refer to backends). |
`kernel_schedule` | Optional[str] | Kernel schedule (Hopper architecture only). Refer to backends for available schedules. Defaults to `None`. |
`torch_compile` | bool | Applies only to the Flex Attention backend: whether to JIT-compile its kernel with `torch.compile`. |
`attention_kwargs` | Optional[Dict] | Arguments passed to the attention operator when it is used to implement neighborhood cross-attention (additional KV), or self attention as a fast path for neighborhood attention. If, for a given use case, the neighborhood attention problem is equivalent to self attention (not causal along any dims, and `kernel_size` covers the entire token layout), NATTEN dispatches directly to attention; you can override its arguments by passing a dictionary here. |
`try_fuse_additional_kv` | bool | Some backends may support fusing cross-attention (additional KV) into the FNA kernel, instead of having to do a separate attention and then merge. For now, this can only be supported by backends using Token Permutation, which means that when there is dilation, the fusion may incur additional memory operations and memory usage. Refer to backends for support. |
Returns:
Name | Type | Description |
---|---|---|
`output` | Tensor | 5-D output tensor, with the heads-last layout (`[batch, X, Y, heads, head_dim]`). |
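Below is a minimal usage sketch showing per-dimension parameters. Shapes and values are illustrative; make sure your choices satisfy the constraints above (`stride` no larger than `kernel_size`, and `dilation * kernel_size` no larger than the input size along each dimension).

```python
import torch
from natten import na2d

# Heads-last layout: [batch, X, Y, heads, head_dim]
batch, X, Y, heads, head_dim = 1, 64, 64, 4, 64
q = torch.randn(batch, X, Y, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = na2d(
    q, k, v,
    kernel_size=(7, 7),   # 7 x 7 neighborhood
    stride=(2, 2),        # strided sliding window (stride <= kernel_size)
    dilation=(4, 4),      # dilated neighborhoods (4 * 7 <= 64)
)
print(out.shape)  # torch.Size([1, 64, 64, 4, 64])
```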
natten.na3d
```python
na3d(
query,
key,
value,
kernel_size,
stride=1,
dilation=1,
is_causal=False,
scale=None,
additional_keys=None,
additional_values=None,
attention_kwargs=None,
backend=None,
q_tile_shape=None,
kv_tile_shape=None,
backward_q_tile_shape=None,
backward_kv_tile_shape=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
try_fuse_additional_kv=False,
)
```
Computes 3-D neighborhood attention.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`query` | Tensor | 6-D query tensor, with the heads-last layout (`[batch, X, Y, Z, heads, head_dim]`). | required |
`key` | Tensor | 6-D key tensor, with the heads-last layout (`[batch, X, Y, Z, heads, head_dim]`). | required |
`value` | Tensor | 6-D value tensor, with the heads-last layout (`[batch, X, Y, Z, heads, head_dim]`). | required |
`kernel_size` | Tuple[int, int, int] \| int | Neighborhood window (kernel) size/shape. If an integer, it is repeated for all 3 dimensions (e.g. `kernel_size=3` is equivalent to `kernel_size=(3, 3, 3)`). Must be smaller than or equal to the token layout shape (`(X, Y, Z)`) along every dimension. | required |
`stride` | Tuple[int, int, int] \| int | Sliding window step size/shape. If an integer, it is repeated for all 3 dimensions. Must be smaller than or equal to `kernel_size` along every dimension; when equal to `kernel_size`, neighborhood attention becomes blocked (window self) attention. Defaults to `1` (standard sliding window). | `1` |
`dilation` | Tuple[int, int, int] \| int | Dilation step size/shape. If an integer, it is repeated for all 3 dimensions. The product of `dilation` and `kernel_size` must be smaller than or equal to the token layout shape along every dimension. Defaults to `1` (standard neighborhood attention). | `1` |
`is_causal` | Tuple[bool, bool, bool] \| bool | Toggle causal masking along each dimension. Defaults to `False` (bi-directional along all dims). | `False` |
`scale` | float | Attention scale. Defaults to `head_dim ** -0.5`. | `None` |
`additional_keys` | Optional[Tensor] | Key tensor for additional (cross-attention) context, which every query token attends to in addition to its neighborhood. Must be specified together with `additional_values`. | `None` |
`additional_values` | Optional[Tensor] | Value tensor for the additional (cross-attention) context, corresponding to `additional_keys`. Must be specified together with `additional_keys`. | `None` |
Other Parameters:
Name | Type | Description |
---|---|---|
`backend` | str | Backend implementation to run with. Refer to backends for available choices and their requirements. |
`q_tile_shape` | Tuple[int, int, int] | 3-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination. |
`kv_tile_shape` | Tuple[int, int, int] | 3-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination. |
`backward_q_tile_shape` | Tuple[int, int, int] | 3-D tile shape for the query token layout in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_tile_shape` | Tuple[int, int, int] | 3-D tile shape for the key/value token layout in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_splits` | Tuple[int, int, int] | Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, and the size of the tuple corresponds to the number of dimensions / rank of the token layout. Only respected by certain backends (refer to backends). |
`backward_use_pt_reduction` | bool | Whether to use PyTorch eager mode for the reduction step in the backward pass instead of the backend's native kernel. Defaults to `False`. |
`run_persistent_kernel` | bool | Whether to use persistent tile scheduling in the forward pass kernel. Only applies to backends that support it (refer to backends). |
`kernel_schedule` | Optional[str] | Kernel schedule (Hopper architecture only). Refer to backends for available schedules. Defaults to `None`. |
`torch_compile` | bool | Applies only to the Flex Attention backend: whether to JIT-compile its kernel with `torch.compile`. |
`attention_kwargs` | Optional[Dict] | Arguments passed to the attention operator when it is used to implement neighborhood cross-attention (additional KV), or self attention as a fast path for neighborhood attention. If, for a given use case, the neighborhood attention problem is equivalent to self attention (not causal along any dims, and `kernel_size` covers the entire token layout), NATTEN dispatches directly to attention; you can override its arguments by passing a dictionary here. |
`try_fuse_additional_kv` | bool | Some backends may support fusing cross-attention (additional KV) into the FNA kernel, instead of having to do a separate attention and then merge. For now, this can only be supported by backends using Token Permutation, which means that when there is dilation, the fusion may incur additional memory operations and memory usage. Refer to backends for support. |
Returns:
Name | Type | Description |
---|---|---|
`output` | Tensor | 6-D output tensor, with the heads-last layout (`[batch, X, Y, Z, heads, head_dim]`). |
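Below is a minimal usage sketch with additional (cross-attention) context. The exact layout assumed here for `additional_keys` / `additional_values` (a flat token sequence in the heads-last layout) is an assumption for illustration; shapes and values are likewise illustrative.

```python
import torch
from natten import na3d

# Heads-last layout: [batch, X, Y, Z, heads, head_dim]
batch, X, Y, Z, heads, head_dim = 1, 16, 32, 32, 4, 64
q = torch.randn(batch, X, Y, Z, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Additional (cross-attention) context, e.g. text embeddings, passed as a
# flat token sequence in the heads-last layout (assumed shape).
num_extra = 77
extra_k = torch.randn(batch, num_extra, heads, head_dim, device="cuda", dtype=torch.bfloat16)
extra_v = torch.randn_like(extra_k)

out = na3d(
    q, k, v,
    kernel_size=(7, 9, 9),
    additional_keys=extra_k,
    additional_values=extra_v,
)
print(out.shape)  # torch.Size([1, 16, 32, 32, 4, 64])
```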
Standard Attention
natten.attention
```python
attention(
query,
key,
value,
scale=None,
backend=None,
q_tile_size=None,
kv_tile_size=None,
backward_q_tile_size=None,
backward_kv_tile_size=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
return_lse=False,
)
```
Runs standard dot product attention.
This operation is used to implement neighborhood cross-attention, in which we allow every token to interact with some additional context (the `additional_keys` and `additional_values` tensors in `na1d`, `na2d`, and `na3d`).
This operator is also used as a fast path for cases where neighborhood attention is equivalent to self attention (not causal along any dims, and `kernel_size` is equal to the number of input tokens).
This operation does not call into PyTorch's SDPA, and only runs one of the NATTEN backends (`cutlass-fmha`, `hopper-fmha`, `blackwell-fmha`, `flex-fmha`). Reasons for that include being able to control performance-related arguments, return logsumexp, and more.
For more information refer to backends.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`query` | Tensor | 4-D query tensor, with the heads-last layout (`[batch, seqlen, heads, head_dim]`). | required |
`key` | Tensor | 4-D key tensor, with the heads-last layout (`[batch, seqlen_kv, heads, head_dim]`). | required |
`value` | Tensor | 4-D value tensor, with the heads-last layout (`[batch, seqlen_kv, heads, head_dim]`). | required |
`scale` | float | Attention scale. Defaults to `head_dim ** -0.5`. | `None` |
Other Parameters:
Name | Type | Description |
---|---|---|
`backend` | str | Backend implementation to run with. Choices are `cutlass-fmha`, `hopper-fmha`, `blackwell-fmha`, and `flex-fmha`. Refer to backends for requirements. |
`q_tile_size` | int | Tile size along query sequence length in the forward pass kernel. You can use the profiler to find valid choices for your use case. |
`kv_tile_size` | int | Tile size along key/value sequence length in the forward pass kernel. You can use the profiler to find valid choices for your use case. |
`backward_q_tile_size` | int | Tile size along query sequence length in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_tile_size` | int | Tile size along key/value sequence length in the backward pass kernel. Only respected by backends that implement their own backward kernel (refer to backends). |
`backward_kv_splits` | int | Number of key/value tiles allowed to work in parallel in the backward pass kernel. Only respected by certain backends (refer to backends). |
`backward_use_pt_reduction` | bool | Whether to use PyTorch eager mode for the reduction step in the backward pass instead of the backend's native kernel. Defaults to `False`. |
`run_persistent_kernel` | bool | Whether to use persistent tile scheduling in the forward pass kernel. Only applies to backends that support it (refer to backends). |
`kernel_schedule` | Optional[str] | Kernel schedule (Hopper architecture only). Refer to backends for available schedules. Defaults to `None`. |
`torch_compile` | bool | Applies only to the `flex-fmha` backend: whether to JIT-compile its kernel with `torch.compile`. |
`return_lse` | bool | Whether or not to return the logsumexp tensor, which is required for merging attention outputs with `merge_attentions`. Defaults to `False`. |
Returns:
Name | Type | Description |
---|---|---|
`output` | Tensor | 4-D output tensor, with the heads-last layout (`[batch, seqlen, heads, head_dim]`). |
`logsumexp` | Tensor | Logsumexp tensor; only returned when `return_lse=True`. |
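Below is a minimal usage sketch of cross-attention with `return_lse=True`. Shapes, dtype, and the CUDA device are illustrative.

```python
import torch
from natten import attention

batch, seqlen, seqlen_kv, heads, head_dim = 2, 4096, 77, 8, 64
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen_kv, heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

# Cross-attention between query tokens and a short context, returning
# logsumexp so the result can later be merged with another attention
# output via merge_attentions.
out, lse = attention(q, k, v, return_lse=True)
print(out.shape)  # torch.Size([2, 4096, 8, 64])
```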
natten.merge_attentions
Takes multiple attention outputs originating from the same query tensor, along with their corresponding logsumexps, and merges them as if their contexts (key/value pairs) had been concatenated.
This operation is used to implement neighborhood cross-attention, and can also be used in distributed settings, such as context parallelism.
This operation also attempts to use `torch.compile` to fuse the elementwise operations. This can be disabled by passing `torch_compile=False`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`outputs` | List[Tensor] | List of 4-D attention output tensors, with the heads-last layout (`[batch, seqlen, heads, head_dim]`). | required |
`lse_tensors` | List[Tensor] | List of the corresponding 3-D logsumexp tensors, with the heads-last layout (`[batch, seqlen, heads]`). | required |
`torch_compile` | bool | Attempt to use `torch.compile` to fuse the elementwise merge operations. | `True` |
Returns:
Name | Type | Description |
---|---|---|
`output` | Tensor | Merged attention output. |
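Below is a minimal sketch of the context-parallel style use case: attending to two key/value shards separately and merging the results. Shapes and the comparison tolerance are illustrative; exact numerics depend on backend and dtype.

```python
import torch
from natten import attention, merge_attentions

batch, seqlen, heads, head_dim = 1, 2048, 8, 64
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)

# Two shards of the key/value context (as in context parallelism).
k0 = torch.randn(batch, 1024, heads, head_dim, device="cuda", dtype=torch.float16)
v0 = torch.randn_like(k0)
k1 = torch.randn(batch, 1024, heads, head_dim, device="cuda", dtype=torch.float16)
v1 = torch.randn_like(k1)

# Attend to each shard separately, keeping logsumexp.
out0, lse0 = attention(q, k0, v0, return_lse=True)
out1, lse1 = attention(q, k1, v1, return_lse=True)

# Merge as if q had attended to the concatenation of both shards.
merged = merge_attentions([out0, out1], [lse0, lse1])

# Reference: attention over the concatenated context (tolerance is illustrative).
ref = attention(q, torch.cat([k0, k1], dim=1), torch.cat([v0, v1], dim=1))
print(torch.allclose(merged, ref, atol=1e-2, rtol=1e-2))
```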