Global context

Certain features in NATTEN are guarded by a global context. Those features are:

  • KV parallelism in FNA backward: necessary for speeding up training, but can introduce non-deterministic behavior and an increased memory footprint. This is standard in almost all fused attention implementations. Read more below.
  • Flex Attention + torch.compile: heavily experimental and may lead to incorrect behavior, but users can choose to allow it at their own risk. Read more below.

KV parallelism in FNA / FMHA

FNA (like most FMHA) backward pass implementations need to parallelize across the KV sequence, but this creates a race condition on the query gradient tiles. The race is avoided with a mutex lock, which makes the order of writes, and therefore the computation, non-deterministic. In addition, some scratch space may be required, the size of which is a function of the degree of parallelism.

In this document, we outline how you can specify your preference for using KV parallelism.

Controlling KV parallelism

KV parallelism is enabled by default, but you can disable it explicitly using the following operations, or simply enable PyTorch's deterministic mode, which disables KV parallelism as well.

natten.use_kv_parallelism_in_fused_na

use_kv_parallelism_in_fused_na(mode=True)

Sets guards for using KV Parallelism in backpropagation in "cutlass-fna"/"cutlass-fmha" backends.

Warning

Disabling KV parallelism can significantly slow down training, particularly for problems with small batch size / few heads and large token counts.

Parameters:

  • mode (bool, default: True): If True, allows KV parallelism (the default setting); otherwise disables it.

You can additionally check the current setting:

natten.is_kv_parallelism_in_fused_na_enabled

is_kv_parallelism_in_fused_na_enabled()

Returns whether KV parallelism in "cutlass-fna" and "cutlass-fmha" backends is enabled.
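For example, here is a minimal sketch of toggling KV parallelism and checking the setting (torch.use_deterministic_algorithms is the standard PyTorch API for deterministic mode; everything else is the NATTEN API documented above):

    import torch
    import natten

    # Option 1: disable KV parallelism explicitly through NATTEN's global context,
    # e.g. for a deterministic backward pass while debugging.
    natten.use_kv_parallelism_in_fused_na(False)
    assert not natten.is_kv_parallelism_in_fused_na_enabled()

    # Re-enable it afterwards for maximum training throughput.
    natten.use_kv_parallelism_in_fused_na(True)
    assert natten.is_kv_parallelism_in_fused_na_enabled()

    # Option 2: enable PyTorch's deterministic mode, which also disables
    # KV parallelism in NATTEN.
    torch.use_deterministic_algorithms(True)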

Memory usage preference

In addition to the global context flag controlling whether KV parallelism is enabled, NATTEN also offers "memory usage preferences", which control the upper bound on parallelism so as to limit the memory footprint.

Presently there are 3 modes, but we plan to improve this interface in the future by giving more fine-grained control and improving the heuristic:

  1. Default
  2. Strict
  3. Unrestricted

Default and strict limit the upper bound on KV parallelism by factoring in how much parallelism is already gained across batch size and attention heads. Unrestricted does not limit the upper bound on KV parallelism and uses the maximum parallelism possible; it therefore gives the best performance, but also has the highest memory footprint.

However, in practice we haven't seen any cases run out of memory while using the unrestricted setting, so we recommend trying that setting first and downgrading only if it does not fit your use case.

To change memory preferences, use the following function:

natten.set_memory_usage_preference

set_memory_usage_preference(pref='default')

Sets memory usage preference for KV parallelism in "cutlass-fna" and "cutlass-fmha" backends.

Parameters:

  • pref (str, default: "default"): Choices are "default", "strict", and "unrestricted".

You can additionally check the current setting:

natten.is_memory_usage_default

is_memory_usage_default()

Returns whether memory usage preference for KV parallelism in "cutlass-fna" and "cutlass-fmha" backends is the default setting.

natten.is_memory_usage_strict

is_memory_usage_strict()

Returns whether memory usage preference for KV parallelism in "cutlass-fna" and "cutlass-fmha" backends is the strict setting.

natten.is_memory_usage_unrestricted

is_memory_usage_unrestricted()

Returns whether memory usage preference for KV parallelism in "cutlass-fna" and "cutlass-fmha" backends is the unrestricted setting.
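For example, a minimal sketch of switching to the unrestricted preference and verifying the change:

    import natten

    # Try the unrestricted setting first for the best performance; fall back to
    # "default" or "strict" only if you run into memory issues.
    natten.set_memory_usage_preference("unrestricted")
    assert natten.is_memory_usage_unrestricted()

    # Revert to the default heuristic.
    natten.set_memory_usage_preference("default")
    assert natten.is_memory_usage_default()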

Flex Attention + torch.compile

We have been unable to verify the correctness of our Flex backend with compilation enabled across all of our use cases. We believe this may be a PyTorch bug, because:

  1. everything works as expected without torch.compile,
  2. some failures are intermittent, and changing the order of the tests makes them disappear,
  3. in some cases the forward pass is correct but the backward pass fails, regardless of order.

We are working on raising this issue with PyTorch directly, but until it is resolved, we strongly recommend exercising caution when using this feature.

Due to this, Flex + compilation is guarded with global context variables, which you can control using the following functions.

natten.allow_flex_compile

allow_flex_compile(mode=True, backprop=False)

Sets guards for Flex Attention + torch.compile.

Allows using our Flex FNA / Flex FMHA backends with torch.compile: you can pass torch_compile=True to the na{1,2,3}d or attention operations, along with backend="flex-fna" / backend="flex-fmha", and NATTEN will compile the block-sparse mask, as well as the attention operation itself, with torch.compile for you.

Warning

We have been unable to verify the correctness of this setting under all of our use cases. We are working on raising this issue with PyTorch directly, but until then we strongly recommend exercising caution when using this feature.

backprop=True is strongly discouraged!

Allowing torch.compile for backpropagation (detected by checking tensor.requires_grad) is guarded separately. We strongly recommend NOT using this setting, as it can impact your training results.

Parameters:

  • mode (bool, default: True): If True, enables compilation for the forward pass; otherwise disables it.
  • backprop (bool, default: False): If True, and assuming compilation for the forward pass is allowed, enables compilation for the backward pass; otherwise disables it.
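As an illustration, here is a hedged sketch of enabling forward-pass compilation and calling the 2-D neighborhood attention operation through the Flex backend. The torch_compile and backend arguments are the ones documented above; the tensor layout, shapes, and kernel_size argument in the na2d call are assumptions for illustration and may differ across NATTEN versions:

    import torch
    import natten

    # Opt in to Flex + torch.compile for the forward pass only; backprop stays
    # disabled, which is the recommended setting.
    natten.allow_flex_compile(True, backprop=False)

    # Hypothetical 2-D inputs; layout and shapes are assumptions for illustration.
    q = torch.randn(1, 32, 32, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 32, 32, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 32, 32, 8, 64, device="cuda", dtype=torch.float16)

    out = natten.na2d(
        q, k, v,
        kernel_size=(7, 7),
        backend="flex-fna",   # select the Flex FNA backend
        torch_compile=True,   # compile the block-sparse mask and the attention op
    )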

natten.allow_flex_compile_backprop

allow_flex_compile_backprop(mode=True)

Sets guards for Flex Attention + torch.compile for backpropagation only.

Parameters:

  • mode (bool, default: True): If True, enables compilation for backpropagation (assuming forward-pass compilation is already allowed); otherwise disables it.

natten.disable_flex_compile

disable_flex_compile()

Disallows Flex Attention + torch.compile entirely.

natten.disable_flex_compile_backprop

disable_flex_compile_backprop()

Disallows Flex Attention + torch.compile for backpropagation entirely.

You can additionally check the current settings:

natten.is_flex_compile_allowed

is_flex_compile_allowed()

Returns whether compilation is allowed in "flex-fna" and "flex-fmha" backends.

natten.is_flex_compile_backprop_allowed

is_flex_compile_backprop_allowed()

Returns whether compilation for backpropagation is allowed in "flex-fna" and "flex-fmha" backends.
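For instance, a quick sketch that inspects the current guards and then turns the feature off again:

    import natten

    if natten.is_flex_compile_allowed():
        print("Flex + torch.compile: forward pass compilation allowed")
    if natten.is_flex_compile_backprop_allowed():
        print("Flex + torch.compile: backprop compilation allowed (not recommended)")

    # Turn the feature off again, e.g. before a correctness-sensitive run.
    natten.disable_flex_compile_backprop()
    natten.disable_flex_compile()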