Modules
We also offer torch modules for easy integration into your neural network:
natten.NeighborhoodAttention1D
1-D Neighborhood Attention torch module.
Performs QKV and output linear projections in addition to the na1d operation.
Parameters:

Name | Type | Description | Default
---|---|---|---
`embed_dim` | `int` | Embedding dimension size (a.k.a. number of channels, latent size). Note: this is not the per-head dimension; `embed_dim = num_heads * head_dim`. | required
`num_heads` | `int` | Number of attention heads. | required
`kernel_size` | `Tuple[int]` or `int` | Neighborhood window (kernel) size. Note: must not exceed the number of input tokens. | required
`stride` | `Tuple[int]` or `int` | Sliding window step size. Note: must not exceed `kernel_size`. | `1`
`dilation` | `Tuple[int]` or `int` | Dilation step size. Note: the product of `dilation` and `kernel_size` must not exceed the number of input tokens. | `1`
`is_causal` | `Tuple[bool]` or `bool` | Toggle causal masking. | `False`
`qkv_bias` | `bool` | Enable bias in the QKV linear projection. | `True`
`qk_scale` | `Optional[float]` | Attention scale. If `None`, uses `head_dim ** -0.5`. | `None`
`proj_drop` | `float` | Dropout rate for the output projection layer. | `0.0`
Example
```python
import torch
from natten import NeighborhoodAttention1D

num_heads = 4
head_dim = 128
embed_dim = num_heads * head_dim  # 512

model = NeighborhoodAttention1D(
    embed_dim=embed_dim,
    num_heads=num_heads,
    kernel_size=2048,
    stride=2,
    dilation=2,  # kernel_size * dilation (2048 * 2) must not exceed seqlen (4096)
    is_causal=True,
)

batch = 1
seqlen = 4096  # tokens arranged in a sequential layout of size 4096
x = torch.randn(batch, seqlen, embed_dim)  # x.shape == [1, 4096, 512]
y = model(x)                               # y.shape == [1, 4096, 512]
```
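Since the module maps a channels-last input of shape `(batch, seqlen, embed_dim)` to an output of the same shape, it can drop in as the attention layer of a standard pre-norm transformer block. Below is a minimal sketch of such a block; the block structure (LayerNorm placement, MLP ratio) and the name `NeighborhoodTransformerBlock1D` are our illustration, not part of NATTEN's API:

```python
import torch
import torch.nn as nn
from natten import NeighborhoodAttention1D

class NeighborhoodTransformerBlock1D(nn.Module):
    """Hypothetical pre-norm transformer block built around NeighborhoodAttention1D."""

    def __init__(self, embed_dim: int, num_heads: int, kernel_size: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = NeighborhoodAttention1D(
            embed_dim=embed_dim,
            num_heads=num_heads,
            kernel_size=kernel_size,
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_ratio * embed_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * embed_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seqlen, embed_dim), channels-last, as in the example above
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

block = NeighborhoodTransformerBlock1D(embed_dim=512, num_heads=4, kernel_size=7)
y = block(torch.randn(1, 4096, 512))  # y.shape == [1, 4096, 512]
```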
natten.NeighborhoodAttention2D
2-D Neighborhood Attention torch module.
Performs QKV and output linear projections in addition to the na2d operation.
Parameters:

Name | Type | Description | Default
---|---|---|---
`embed_dim` | `int` | Embedding dimension size (a.k.a. number of channels, latent size). Note: this is not the per-head dimension; `embed_dim = num_heads * head_dim`. | required
`num_heads` | `int` | Number of attention heads. | required
`kernel_size` | `Tuple[int, int]` or `int` | Neighborhood window (kernel) size/shape. If an integer, it is repeated for both dimensions; for example, `kernel_size=8` is equivalent to `kernel_size=(8, 8)`. Note: must not exceed the number of tokens along the corresponding dimension. | required
`stride` | `Tuple[int, int]` or `int` | Sliding window step size/shape. Note: must not exceed `kernel_size` along each dimension. | `1`
`dilation` | `Tuple[int, int]` or `int` | Dilation step size/shape. Note: the product of `dilation` and `kernel_size` must not exceed the number of tokens along the corresponding dimension. | `1`
`is_causal` | `Tuple[bool, bool]` or `bool` | Toggle causal masking. | `False`
`qkv_bias` | `bool` | Enable bias in the QKV linear projection. | `True`
`qk_scale` | `Optional[float]` | Attention scale. If `None`, uses `head_dim ** -0.5`. | `None`
`proj_drop` | `float` | Dropout rate for the output projection layer. | `0.0`
Example
```python
import torch
from natten import NeighborhoodAttention2D

num_heads = 4
head_dim = 128
embed_dim = num_heads * head_dim  # 512

model = NeighborhoodAttention2D(
    embed_dim=embed_dim,
    num_heads=num_heads,
    kernel_size=(8, 16),
    stride=(1, 2),
    dilation=(2, 1),
    is_causal=False,
)

batch = 1
token_layout_shape = (16, 32)  # tokens arranged in a 16 x 32 layout; kernel 8 x 16, stride 1 x 2, dilation 2 x 1
x = torch.randn(batch, *token_layout_shape, embed_dim)  # x.shape == [1, 16, 32, 512]
y = model(x)                                            # y.shape == [1, 16, 32, 512]
```
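As the example shows, the module expects a channels-last layout, `(batch, height, width, embed_dim)`, whereas convolutional feature maps are usually channels-first, `(batch, channels, height, width)`. A minimal sketch of bridging the two; the permutes are our assumption about how one would adapt a conv-style tensor, not a NATTEN requirement beyond the channels-last input layout:

```python
import torch
from natten import NeighborhoodAttention2D

model = NeighborhoodAttention2D(embed_dim=512, num_heads=4, kernel_size=7)

# Hypothetical conv-style feature map in channels-first layout: (B, C, H, W)
feature_map = torch.randn(1, 512, 16, 32)

x = feature_map.permute(0, 2, 3, 1)      # -> (B, H, W, C), the layout the module consumes
y = model(x)                             # y.shape == [1, 16, 32, 512]
y = y.permute(0, 3, 1, 2).contiguous()   # back to (B, C, H, W) for downstream conv layers
```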
natten.NeighborhoodAttention3D
3-D Neighborhood Attention torch module.
Performs QKV and output linear projections in addition to the na3d operation.
Parameters:

Name | Type | Description | Default
---|---|---|---
`embed_dim` | `int` | Embedding dimension size (a.k.a. number of channels, latent size). Note: this is not the per-head dimension; `embed_dim = num_heads * head_dim`. | required
`num_heads` | `int` | Number of attention heads. | required
`kernel_size` | `Tuple[int, int, int]` or `int` | Neighborhood window (kernel) size/shape. If an integer, it is repeated for all 3 dimensions; for example, `kernel_size=8` is equivalent to `kernel_size=(8, 8, 8)`. Note: must not exceed the number of tokens along the corresponding dimension. | required
`stride` | `Tuple[int, int, int]` or `int` | Sliding window step size/shape. Note: must not exceed `kernel_size` along each dimension. | `1`
`dilation` | `Tuple[int, int, int]` or `int` | Dilation step size/shape. Note: the product of `dilation` and `kernel_size` must not exceed the number of tokens along the corresponding dimension. | `1`
`is_causal` | `Tuple[bool, bool, bool]` or `bool` | Toggle causal masking. | `False`
`qkv_bias` | `bool` | Enable bias in the QKV linear projection. | `True`
`qk_scale` | `Optional[float]` | Attention scale. If `None`, uses `head_dim ** -0.5`. | `None`
`proj_drop` | `float` | Dropout rate for the output projection layer. | `0.0`
Example
```python
import torch
from natten import NeighborhoodAttention3D

num_heads = 4
head_dim = 128
embed_dim = num_heads * head_dim  # 512

model = NeighborhoodAttention3D(
    embed_dim=embed_dim,
    num_heads=num_heads,
    kernel_size=(4, 8, 12),
    stride=(1, 1, 4),
    dilation=(1, 2, 1),
    is_causal=(True, False, False),  # causal masking along the left-most dimension only
)

batch = 1
token_layout_shape = (12, 16, 20)  # tokens arranged in a 12 x 16 x 20 layout; kernel 4 x 8 x 12, stride 1 x 1 x 4, dilation 1 x 2 x 1
x = torch.randn(batch, *token_layout_shape, embed_dim)  # x.shape == [1, 12, 16, 20, 512]
y = model(x)                                            # y.shape == [1, 12, 16, 20, 512]
```
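A common use for the 3-D module is video, where the left-most (temporal) dimension is masked causally, as in the example above. Below is a minimal sketch; it assumes frames arrive channels-first as `(batch, channels, time, height, width)`, and the permute to channels-last is our adaptation, following the input layout shown in the example:

```python
import torch
from natten import NeighborhoodAttention3D

model = NeighborhoodAttention3D(
    embed_dim=512,
    num_heads=4,
    kernel_size=(4, 7, 7),
    is_causal=(True, False, False),  # attend only to current and past frames along time
)

# Hypothetical video feature tensor: (batch, channels, time, height, width)
video = torch.randn(1, 512, 8, 16, 16)

x = video.permute(0, 2, 3, 4, 1)  # -> (B, T, H, W, C), channels-last as the module expects
y = model(x)                      # y.shape == [1, 8, 16, 16, 512]
```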