From f3316d8763ab368f9cee34af3f5455f5df13b870 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 12 Sep 2025 14:13:45 -0700
Subject: [PATCH] Delete FlexAttention + NJT composition from tutorial

---
 .../transformer_building_blocks.py | 61 -------------------
 1 file changed, 61 deletions(-)

diff --git a/intermediate_source/transformer_building_blocks.py b/intermediate_source/transformer_building_blocks.py
index df2fb90f96..decaf0602f 100644
--- a/intermediate_source/transformer_building_blocks.py
+++ b/intermediate_source/transformer_building_blocks.py
@@ -564,7 +564,6 @@ def benchmark(func, *args, **kwargs):
 #
 # * Cross Attention
 # * Fully masked rows no longer cause NaNs
-# * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection
 
 ###############################################################################
@@ -668,66 +667,6 @@ def benchmark(func, *args, **kwargs):
 # appropriately makes it possible to properly express empty sequences.
 
 
-################################################################################
-# FlexAttention + NJT
-# ---------------------------------------------------------------------
-# NJT also composes with the ``FlexAttention`` module. This is a generalization
-# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
-# to the attention score. The example below takes the ``alibi_mod``
-# that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
-# `attention gym <https://github.com/pytorch-labs/attention-gym>`_ and uses it
-# with nested input tensors.
-
-from torch.nn.attention.flex_attention import flex_attention
-
-
-def generate_alibi_bias(H: int):
-    """Returns an alibi bias score_mod given the number of heads H
-    Args:
-        H: number of heads
-    Returns:
-        alibi_bias: alibi bias score_mod
-    """
-
-    def alibi_mod(score, b, h, q_idx, kv_idx):
-        scale = torch.exp2(-((h + 1) * 8.0 / H))
-        bias = (q_idx - kv_idx) * scale
-        return score + bias
-
-    return alibi_mod
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-n_heads, D = 8, E_q // 8
-alibi_score_mod = generate_alibi_bias(n_heads)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
-
-###############################################################################
-# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
-# with NJTs via the ``create_nested_block_mask`` function. This is useful for
-# taking advantage of the sparsity of the mask to speed up the attention computation.
-# In particular, the function creates a sparse block mask for a "stacked sequence" of all
-# the variable length sequences in the NJT combined into one, while properly masking out
-# inter-sequence attention. In the following example, we show how to create a
-# causal block mask using this utility.
-
-from torch.nn.attention.flex_attention import create_nested_block_mask
-
-
-def causal_mask(b, h, q_idx, kv_idx):
-    return q_idx >= kv_idx
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex = flex_attention(query, key, value, block_mask=block_mask)
-
 ###############################################################################
 # Packed Projection
 # -----------------