Modules

We also offer torch modules for easy integration into your neural network:

natten.NeighborhoodAttention1D

1-D Neighborhood Attention torch module.

Performs QKV and output linear projections in addition to the na1d operation.

Parameters:

embed_dim : int, required
    Embedding dimension size (a.k.a. number of channels, latent size).
    Note: this is not head_dim; it is head_dim * num_heads.

num_heads : int, required
    Number of attention heads.

kernel_size : Tuple[int] | int, required
    Neighborhood window (kernel) size.
    Note: kernel_size must be smaller than or equal to seqlen.

stride : Tuple[int] | int, optional
    Sliding window step size. Defaults to 1 (standard sliding window).
    Note: stride must be smaller than or equal to kernel_size. When
    stride == kernel_size, there is no overlap between sliding windows,
    which is equivalent to blocked attention (a.k.a. window self
    attention); see the sketch after this parameter list.

dilation : Tuple[int] | int, optional
    Dilation step size. Defaults to 1 (standard sliding window).
    Note: the product of dilation and kernel_size must be smaller than
    or equal to seqlen.

is_causal : Tuple[bool] | bool, optional
    Toggle causal masking. Defaults to False (bi-directional).

qkv_bias : bool, optional
    Enable bias in the QKV linear projection. Defaults to True.

qk_scale : Optional[float], optional
    Attention scale. Defaults to head_dim ** -0.5.

proj_drop : float, optional
    Dropout rate for the output projection layer. Defaults to 0.0 (no
    dropout).
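
A minimal sketch of the blocked-attention configuration mentioned in the stride note, using only the constructor arguments documented above (the specific sizes are our own arbitrary choices):

import torch
from natten import NeighborhoodAttention1D

# stride == kernel_size: sliding windows do not overlap, so this is
# blocked attention (a.k.a. window self attention) over 256-token blocks.
blocked = NeighborhoodAttention1D(
    embed_dim=512,
    num_heads=4,
    kernel_size=256,
    stride=256,
)

x = torch.randn(1, 1024, 512)  # seqlen (1024) >= kernel_size (256)
y = blocked(x)
assert y.shape == x.shape  # the token layout is unchanged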
Example
import torch
from natten import NeighborhoodAttention1D

num_heads = 4
head_dim = 128
embed_dim = num_heads * head_dim

model = NeighborhoodAttention1D(
    embed_dim=embed_dim,
    num_heads=num_heads,
    kernel_size=2048,
    stride=2,
    dilation=4,
    is_causal=True
)

batch = 1
seqlen = 8192 # (1)!

x = torch.randn(batch, seqlen, embed_dim) # (2)!
y = model(x) # (3)!
  1. Tokens are arranged in a sequential layout of size 8192, to which we apply a kernel size of 2048, stride 2, dilation 4, and causal masking. Note that dilation * kernel_size = 4 * 2048 = 8192, which does not exceed seqlen.

  2. x.shape == [1, 8192, 512]

  3. y.shape == [1, 8192, 512]
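
Since NeighborhoodAttention1D performs the QKV and output projections itself, dropping it into a network is direct. Below is a hypothetical pre-norm transformer block built around it; the block name, MLP sizing, and norm placement are our assumptions, not part of natten:

import torch
import torch.nn as nn
from natten import NeighborhoodAttention1D

class NABlock1D(nn.Module):
    # Hypothetical integration sketch: neighborhood attention followed
    # by an MLP, each with a pre-norm residual connection.
    def __init__(self, embed_dim=512, num_heads=4, kernel_size=256, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = NeighborhoodAttention1D(
            embed_dim=embed_dim, num_heads=num_heads, kernel_size=kernel_size
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_ratio * embed_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * embed_dim, embed_dim),
        )

    def forward(self, x):  # x: (batch, seqlen, embed_dim)
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

x = torch.randn(2, 1024, 512)
assert NABlock1D()(x).shape == x.shape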

natten.NeighborhoodAttention2D

2-D Neighborhood Attention torch module.

Performs QKV and output linear projections in addition to the na2d operation.

Parameters:

embed_dim : int, required
    Embedding dimension size (a.k.a. number of channels, latent size).
    Note: this is not head_dim; it is head_dim * num_heads.

num_heads : int, required
    Number of attention heads.

kernel_size : Tuple[int, int] | int, required
    Neighborhood window (kernel) size/shape. If an integer, it is
    repeated for both dimensions; for example, kernel_size=3 is
    reinterpreted as kernel_size=(3, 3).
    Note: kernel_size must be smaller than or equal to the token layout
    shape ((X, Y)) along every dimension.

stride : Tuple[int, int] | int, optional
    Sliding window step size/shape. Defaults to 1 (standard sliding
    window). If an integer, it is repeated for both dimensions; for
    example, stride=2 is reinterpreted as stride=(2, 2).
    Note: stride must be smaller than or equal to kernel_size along
    every dimension. When stride == kernel_size, there is no overlap
    between sliding windows, which is equivalent to blocked attention
    (a.k.a. window self attention).

dilation : Tuple[int, int] | int, optional
    Dilation step size/shape. Defaults to 1 (standard sliding window).
    If an integer, it is repeated for both dimensions; for example,
    dilation=4 is reinterpreted as dilation=(4, 4).
    Note: the product of dilation and kernel_size must be smaller than
    or equal to the token layout shape ((X, Y)) along every dimension;
    the sketch after this parameter list restates these constraints as
    assertions.

is_causal : Tuple[bool, bool] | bool, optional
    Toggle causal masking. Defaults to False (bi-directional). If a
    boolean, it is repeated for both dimensions; for example,
    is_causal=True is reinterpreted as is_causal=(True, True).

qkv_bias : bool, optional
    Enable bias in the QKV linear projection. Defaults to True.

qk_scale : Optional[float], optional
    Attention scale. Defaults to head_dim ** -0.5.

proj_drop : float, optional
    Dropout rate for the output projection layer. Defaults to 0.0 (no
    dropout).
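
The scalar-to-tuple expansion and the per-dimension constraints above are easy to get wrong. Here is a small hypothetical helper, not part of natten, that restates the documented constraints as assertions; it accepts the same int-or-tuple arguments as the module:

def check_na2d_params(token_layout, kernel_size, stride=1, dilation=1):
    # Hypothetical validator; restates the constraints documented above.
    def to2(v):
        return (v, v) if isinstance(v, int) else tuple(v)

    kernel_size, stride, dilation = to2(kernel_size), to2(stride), to2(dilation)
    for extent, k, s, d in zip(token_layout, kernel_size, stride, dilation):
        assert k <= extent, "kernel_size must be <= token layout shape"
        assert s <= k, "stride must be <= kernel_size"
        assert d * k <= extent, "dilation * kernel_size must be <= token layout shape"
    return kernel_size, stride, dilation

# The example below passes: layout (16, 32), kernel (8, 16),
# stride (1, 2), dilation (2, 1).
check_na2d_params((16, 32), (8, 16), stride=(1, 2), dilation=(2, 1))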
Example
import torch
from natten import NeighborhoodAttention2D

num_heads = 4
head_dim = 128
embed_dim = num_heads * head_dim

model = NeighborhoodAttention2D(
    embed_dim=embed_dim,
    num_heads=num_heads,
    kernel_size=(8, 16),
    stride=(1, 2),
    dilation=(2, 1),
    is_causal=False
)

batch = 1
token_layout_shape = (16, 32) # (1)!

x = torch.randn(batch, *token_layout_shape, embed_dim) # (2)!
y = model(x) # (3)!
  1. Tokens are arranged in a 16 x 32 layout, to which we apply a kernel size of 8 x 16, stride 1 x 2, and dilation 2 x 1.

  2. x.shape == [1, 16, 32, 512]

  3. y.shape == [1, 16, 32, 512]

natten.NeighborhoodAttention3D

3-D Neighborhood Attention torch module.

Performs QKV and output linear projections in addition to the na3d operation.

Parameters:

embed_dim : int, required
    Embedding dimension size (a.k.a. number of channels, latent size).
    Note: this is not head_dim; it is head_dim * num_heads.

num_heads : int, required
    Number of attention heads.

kernel_size : Tuple[int, int, int] | int, required
    Neighborhood window (kernel) size/shape. If an integer, it is
    repeated for all 3 dimensions; for example, kernel_size=3 is
    reinterpreted as kernel_size=(3, 3, 3).
    Note: kernel_size must be smaller than or equal to the token layout
    shape ((X, Y, Z)) along every dimension.

stride : Tuple[int, int, int] | int, optional
    Sliding window step size/shape. Defaults to 1 (standard sliding
    window). If an integer, it is repeated for all 3 dimensions; for
    example, stride=2 is reinterpreted as stride=(2, 2, 2).
    Note: stride must be smaller than or equal to kernel_size along
    every dimension. When stride == kernel_size, there is no overlap
    between sliding windows, which is equivalent to blocked attention
    (a.k.a. window self attention).

dilation : Tuple[int, int, int] | int, optional
    Dilation step size/shape. Defaults to 1 (standard sliding window).
    If an integer, it is repeated for all 3 dimensions; for example,
    dilation=4 is reinterpreted as dilation=(4, 4, 4).
    Note: the product of dilation and kernel_size must be smaller than
    or equal to the token layout shape ((X, Y, Z)) along every
    dimension.

is_causal : Tuple[bool, bool, bool] | bool, optional
    Toggle causal masking. Defaults to False (bi-directional). If a
    boolean, it is repeated for all 3 dimensions; for example,
    is_causal=True is reinterpreted as is_causal=(True, True, True).

qkv_bias : bool, optional
    Enable bias in the QKV linear projection. Defaults to True.

qk_scale : Optional[float], optional
    Attention scale. Defaults to head_dim ** -0.5.

proj_drop : float, optional
    Dropout rate for the output projection layer. Defaults to 0.0 (no
    dropout).
Example
import torch
from natten import NeighborhoodAttention3D

num_heads = 4
head_dim = 128
embed_dim = num_heads * head_dim

model = NeighborhoodAttention3D(
    embed_dim=embed_dim,
    num_heads=num_heads,
    kernel_size=(4, 8, 12),
    stride=(1, 1, 4),
    dilation=(1, 2, 1),
    is_causal=(True, False, False)
)

batch = 1
token_layout_shape = (12, 16, 20) # (1)!

x = torch.randn(batch, *token_layout_shape, embed_dim) # (2)!
y = model(x) # (3)!
  1. Tokens are arranged in a 12 x 16 x 20 layout, to which we apply a kernel size of 4 x 8 x 12, stride 1 x 1 x 4, dilation 1 x 2 x 1, and causal masking along the left-most dimension (of size 12).

  2. x.shape == [1, 12, 16, 20, 512]

  3. y.shape == [1, 12, 16, 20, 512]
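
As with the other modules, NeighborhoodAttention3D can be stacked like any torch attention layer. Below is a hypothetical sketch of a tiny frame-causal video encoder; the class, its depth, and the residual/LayerNorm wiring are our assumptions, not part of natten:

import torch
import torch.nn as nn
from natten import NeighborhoodAttention3D

class TinyVideoEncoder(nn.Module):
    # Hypothetical sketch: two stacked 3-D neighborhood attention
    # layers, causal along time only, with pre-norm residuals.
    def __init__(self, embed_dim=512, num_heads=4, depth=2):
        super().__init__()
        self.norms = nn.ModuleList(nn.LayerNorm(embed_dim) for _ in range(depth))
        self.attns = nn.ModuleList(
            NeighborhoodAttention3D(
                embed_dim=embed_dim,
                num_heads=num_heads,
                kernel_size=(4, 8, 8),
                is_causal=(True, False, False),  # causal along time only
            )
            for _ in range(depth)
        )

    def forward(self, x):  # x: (batch, T, H, W, embed_dim)
        for norm, attn in zip(self.norms, self.attns):
            x = x + attn(norm(x))
        return x

x = torch.randn(1, 8, 16, 16, 512)  # 8 frames of 16 x 16 tokens
assert TinyVideoEncoder()(x).shape == x.shape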