61 changes: 0 additions & 61 deletions intermediate_source/transformer_building_blocks.py
@@ -564,7 +564,6 @@ def benchmark(func, *args, **kwargs):
#
# * Cross Attention
# * Fully masked rows no longer cause NaNs
# * Modifying attention score: ALiBi with FlexAttention and NJT
# * Packed Projection

###############################################################################
@@ -668,66 +667,6 @@ def benchmark(func, *args, **kwargs):
# appropriately makes it possible to properly express empty sequences.


################################################################################
# FlexAttention + NJT
# ---------------------------------------------------------------------
# NJT also composes with ``FlexAttention``, which generalizes scaled dot
# product attention by allowing arbitrary, user-defined modifications to the
# attention scores. The example below takes the ``alibi_mod`` score
# modification that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
# `attention gym <https://github.com/meta-pytorch/attention-gym>`_ and uses it
# with nested input tensors.

from torch.nn.attention.flex_attention import flex_attention


def generate_alibi_bias(H: int):
    """Returns an ALiBi bias score_mod given the number of heads H.

    Args:
        H: number of heads

    Returns:
        alibi_bias: ALiBi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        # Geometrically decreasing slope per head (the standard ALiBi slopes).
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        # Bias grows linearly with the query/key relative position.
        bias = (q_idx - kv_idx) * scale
        return score + bias

    return alibi_mod


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
# Reshape the NJTs from (N, L, E) to the (N, H, L, D) layout expected by
# ``flex_attention``, where L is the jagged sequence-length dimension.
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
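
###############################################################################
# The ``gen_batch`` helper used above is defined earlier in this tutorial. As
# a rough, minimal sketch (the real helper may differ), it can be thought of
# as returning query/key/value NJTs with a random length per sequence:


def gen_batch_sketch(N, E_q, E_k, E_v, device):
    # Hypothetical stand-in for the tutorial's ``gen_batch`` helper.
    seq_lens = torch.randint(low=1, high=64, size=(N,))

    def njt(E):
        return torch.nested.nested_tensor(
            [torch.randn(int(l), E) for l in seq_lens],
            layout=torch.jagged,
            device=device,
        )

    return njt(E_q), njt(E_k), njt(E_v), seq_lens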

###############################################################################
# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
# with NJTs via the ``create_nested_block_mask`` function. This is useful for
# taking advantage of the sparsity of the mask to speed up the attention computation.
# In particular, the function creates a sparse block mask for a "stacked sequence" of all
# the variable length sequences in the NJT combined into one, while properly masking out
# inter-sequence attention. In the following example, we show how to create a
# causal block mask using this utility.

from torch.nn.attention.flex_attention import create_nested_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    # Each query position may only attend to key positions at or before it.
    return q_idx >= kv_idx


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
# Build the block mask from the NJT query; B=1 and H=1 broadcast it across
# the batch and head dimensions.
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex = flex_attention(query, key, value, block_mask=block_mask)
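
###############################################################################
# ``flex_attention`` accepts a ``score_mod`` and a ``block_mask`` at the same
# time, so as a quick sketch the ALiBi bias and the causal block mask built
# above can also be combined in a single call:

out_flex_alibi_causal = flex_attention(
    query, key, value, score_mod=alibi_score_mod, block_mask=block_mask
)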

###############################################################################
# Packed Projection
# -----------------