Global context
Certain features in NATTEN are guarded by a global context. Those features are:
- KV parallelism in FNA backward: necessary for speeding up training, but it can introduce non-deterministic behavior and increase the memory footprint. This is standard in almost all fused attention implementations. Read more.
- Flex Attention + `torch.compile`: heavily experimental, and may lead to incorrect behavior, but users can choose to allow it at their own risk. Read more.
KV parallelism in FNA / FMHA
FNA (as well as most FMHA) backward pass implementations need to parallelize across the KV sequence, but this creates a race condition on the query gradient tiles. The race is avoided with a mutex lock, which results in a non-deterministic order of writes, and therefore makes the computation non-deterministic. In addition, some scratch space may be required, the size of which is a function of the degree of parallelism.
In this document, we outline how you can specify your preference for using KV parallelism.
Controlling KV parallelism
KV parallelism is enabled by default, but you can choose to disable it explicitly using the following operations, or just enable PyTorch's deterministic mode to disable KV parallelism.
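For example, turning on PyTorch's deterministic mode is enough to switch KV parallelism off:

```python
import torch

# NATTEN will not use KV parallelism in the backward pass while
# PyTorch's global deterministic mode is enabled.
torch.use_deterministic_algorithms(True)
```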
natten.use_kv_parallelism_in_fused_na
Sets guards for using KV parallelism in backpropagation in the `"cutlass-fna"` / `"cutlass-fmha"` backends.
Warning
Disabling KV parallelism can significantly slow down training, particularly in problems with small batch sizes and few attention heads, but many tokens.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`mode` | `bool` | If `True`, enables KV parallelism in the backward pass of the `"cutlass-fna"` / `"cutlass-fmha"` backends. | `True` |
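For example, a minimal sketch of toggling the setting with the setter documented above:

```python
import natten

# Disable KV parallelism in the fused neighborhood attention backward pass
# (deterministic, but typically slower).
natten.use_kv_parallelism_in_fused_na(False)

# Re-enable it later (the default).
natten.use_kv_parallelism_in_fused_na(True)
```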
You can additionally check the current setting:
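A minimal sketch, assuming the corresponding getter is named `natten.is_kv_parallelism_in_fused_na_enabled`:

```python
import natten

# Assumed getter name; returns the current global setting as a bool.
print(natten.is_kv_parallelism_in_fused_na_enabled())
```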
Memory usage preference
In addition to a global context flag for whether or not KV parallelism is enabled, NATTEN also offers "memory usage preferences", which control the upper bound for parallelism, and with it the memory footprint.
Presently there are 3 modes, but we plan to improve this interface in the future by giving more fine-grained control and improving the heuristic:
- Default
- Strict
- Unrestricted
Default and strict limit the upper bound for KV parallelism by factoring in how much parallelism is already gained across batch size and attention heads. Unrestricted does not limit the upper bound of KV parallelism and uses the maximum parallelism possible; it therefore gives the best performance, but also the largest memory footprint.
However, we note that in practice we haven't seen any cases run out of memory while using the unrestricted setting, so we recommend trying that setting first to see if it fits your use case, and downgrading only if it does not.
To change memory preferences, use the following function:
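A minimal sketch, assuming the setter is `natten.set_memory_usage_preference` and takes the mode name as a lowercase string:

```python
import natten

# Assumed setter name and string values: "default", "strict", or "unrestricted".
natten.set_memory_usage_preference("unrestricted")
```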
You can additionally check what the current setting is:
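A minimal sketch, assuming per-mode getters along these lines:

```python
import natten

# Assumed getter names; each returns a bool for the corresponding mode.
print(natten.is_memory_usage_default())
print(natten.is_memory_usage_strict())
print(natten.is_memory_usage_unrestricted())
```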
Flex Attention + torch.compile
We have been unable to verify the correctness of our Flex backend under all of our use cases when compilation is enabled. We believe this may be a PyTorch bug, because:
- everything works as expected without `torch.compile`,
- some failures are intermittent, and changing the order of the tests fixes them,
- in some cases, the forward pass is correct, but the backward pass fails, regardless of order.
We are working on raising this issue with PyTorch directly, but until it is resolved, we strongly recommend exercising caution when using this feature.
Due to this, Flex + compilation is guarded with global context variables, which you can control using the following functions.
natten.allow_flex_compile
Sets guards for Flex Attention + `torch.compile`.
Allows using our Flex FNA / Flex FMHA backends with `torch.compile`, meaning you can pass `torch_compile=True` to the `na{1,2,3}d` or `attention` operation, along with `backend="flex-fna"` / `backend="flex-fmha"`, and NATTEN will compile the block-sparse mask, as well as the attention operation, using `torch.compile` for you.
Warning
We have been unable to verify the correctness of this setting under all of our use cases. We are working on raising this issue with PyTorch directly, but until then we strongly recommend exercising caution when using this feature.
`backprop=True` is strongly discouraged!
Allowing `torch.compile` for backpropagation (detected by checking `tensor.requires_grad`) is guarded separately. We strongly recommend NOT using this setting, as it can impact your training results.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`mode` | `bool` | If `True`, allows using the Flex FNA / Flex FMHA backends with `torch.compile`. | `True` |
`backprop` | `bool` | If `True`, additionally allows `torch.compile` in backpropagation (strongly discouraged). | `False` |
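A minimal sketch of opting in and running a compiled Flex op; the tensor layout and exact `na2d` signature are assumptions, so refer to the op reference for precise usage:

```python
import torch
import natten

# Allow Flex FNA / Flex FMHA to be compiled with torch.compile
# (forward pass only; backpropagation remains guarded separately).
natten.allow_flex_compile(True)

# Hypothetical 2-D neighborhood attention call; the
# (batch, height, width, heads, head_dim) layout is an assumption.
q = torch.randn(1, 32, 32, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = natten.na2d(
    q, k, v,
    kernel_size=(7, 7),
    backend="flex-fna",
    torch_compile=True,  # compiles the block-sparse mask and the attention op
)
```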
natten.allow_flex_compile_backprop
Sets guards for Flex Attention + `torch.compile` for backpropagation only.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`mode` | `bool` | If `True`, allows `torch.compile` for Flex Attention in backpropagation (strongly discouraged). | `True` |
You can additionally check the current settings:
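A minimal sketch, assuming getters along these lines (the names are assumptions):

```python
import natten

# Assumed getter names for the Flex Attention + torch.compile guards.
print(natten.is_flex_compile_allowed())
print(natten.is_flex_compile_backprop_allowed())
```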