ODML Transformer Layers

Common PyTorch building blocks to re-author transformer models.

Attention layers

attention.py and attention_utils.py contain common building blocks for the attention calculation, which is the key part of transformer models. You can use the abstractions provided here to compose your transformer model.

These two files provide the following common Python helper functions, with a conceptual sketch after the list:

  • scaled_dot_product_attention: helper function to compute scaled dot product attention on query, key and value tensors.
  • scaled_dot_product_attention_with_hlfb: same as scaled_dot_product_attention with the addition of HLFB (high-level function boundary) for improved performance.
  • build_rope_cache: pre-computes the sin and cos values for Rotary Positional Embedding (RoPE).
  • build_causal_mask_cache: builds a cache for the causal self-attention mask.
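
The two cache helpers are pure precomputation: build_rope_cache fills the sin/cos tables used by rotary embeddings, and build_causal_mask_cache builds a lower-triangular additive mask, so neither has to be recomputed on every forward pass. The sketch below illustrates that idea in a self-contained way; the function names mirror the helpers above, but the signatures and defaults are assumptions rather than the library's exact API.

import torch


def build_rope_cache_sketch(seq_len: int, head_dim: int, base: float = 10000.0):
  # Illustrative only: precompute the cos/sin tables for rotary embeddings.
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
  positions = torch.arange(seq_len).float()
  angles = torch.outer(positions, inv_freq)  # [seq_len, head_dim // 2]
  return torch.cos(angles), torch.sin(angles)


def build_causal_mask_cache_sketch(seq_len: int) -> torch.Tensor:
  # Illustrative only: additive mask with -inf above the diagonal so each
  # token can attend only to itself and earlier positions.
  mask = torch.full((seq_len, seq_len), float("-inf"))
  return torch.triu(mask, diagonal=1)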

They also provide the following nn.Module classes; a composition sketch follows the list:

  • TransformerBlock
  • CausalSelfAttention
  • SelfAttention
  • CrossAttention
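
These modules are designed to be composed: a TransformerBlock wires a self-attention module and a feed forward module together with normalization and residual connections. The pre-norm sketch below illustrates only the composition pattern; the real classes are configured through the config objects described later and differ from this simplified version.

import torch
from torch import nn


class TransformerBlockSketch(nn.Module):
  """Illustrative pre-norm block, not the library's TransformerBlock."""

  def __init__(self, attention: nn.Module, feed_forward: nn.Module, dim: int):
    super().__init__()
    self.attn_norm = nn.LayerNorm(dim)
    self.attention = attention
    self.ff_norm = nn.LayerNorm(dim)
    self.feed_forward = feed_forward

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # norm -> attention -> residual, then norm -> feed forward -> residual.
    x = x + self.attention(self.attn_norm(x))
    x = x + self.feed_forward(self.ff_norm(x))
    return x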

Builder class for common layers

builder.py provides the following helper functions, with a dispatch sketch after the list:

  • build_norm: constructs different kinds of normalizers based on a config.
  • build_ff: constructs different kinds of feed forward layers, either Sequential or Gated.
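
Both helpers follow the same pattern: inspect a config value and return the matching nn.Module. The snippet below sketches that config-driven dispatch for build_norm with assumed enum and option names; the real supported options live in the library's config classes.

import enum

from torch import nn


class NormType(enum.Enum):
  # Assumed names for illustration; the library's config enums may differ.
  NONE = enum.auto()
  LAYER_NORM = enum.auto()
  RMS_NORM = enum.auto()


def build_norm_sketch(dim: int, norm_type: NormType) -> nn.Module:
  # Illustrative dispatch only, not the library's build_norm.
  if norm_type == NormType.NONE:
    return nn.Identity()
  if norm_type == NormType.LAYER_NORM:
    return nn.LayerNorm(dim)
  if norm_type == NormType.RMS_NORM:
    # nn.RMSNorm exists in newer PyTorch releases; the library ships its own
    # RMSNorm module (see the Normalization layer section below).
    return nn.RMSNorm(dim)
  raise ValueError(f"Unsupported normalization type: {norm_type}")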

Feed forward layer

The library provides the following nn.Modules to represent feed forward layers; their difference is sketched after the list:

  • SequentialFeedForward
  • GatedFeedForward
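
The two variants differ in how the hidden projection is computed: a sequential feed forward applies an activation between two linear layers, while a gated feed forward multiplies an activated gate projection with an up projection before the output projection. The sketch below shows the math only; the layer and argument names are assumptions, and the library's classes are configured differently.

import torch
from torch import nn
import torch.nn.functional as F


class SequentialFeedForwardSketch(nn.Module):
  # Illustrative: out = w2(act(w1(x))).
  def __init__(self, dim: int, hidden_dim: int):
    super().__init__()
    self.w1 = nn.Linear(dim, hidden_dim)
    self.w2 = nn.Linear(hidden_dim, dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.w2(F.silu(self.w1(x)))


class GatedFeedForwardSketch(nn.Module):
  # Illustrative: out = w2(act(w_gate(x)) * w_up(x)).
  def __init__(self, dim: int, hidden_dim: int):
    super().__init__()
    self.w_gate = nn.Linear(dim, hidden_dim)
    self.w_up = nn.Linear(dim, hidden_dim)
    self.w2 = nn.Linear(hidden_dim, dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.w2(F.silu(self.w_gate(x)) * self.w_up(x))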

KV cache layer

We provide an nn.Module, KVCache, to express the logic that updates the cache. It also applies HLFB internally to ensure high performance at runtime.
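
Conceptually, the cache holds key and value tensors for every position seen so far and writes new entries at the current input positions during decode. The sketch below shows that update step in isolation, using index_copy_ along the sequence dimension; it omits the HLFB annotation and does not mirror the real KVCache interface.

import torch
from torch import nn


class KVCacheSketch(nn.Module):
  """Illustrative cache update only, not the library's KVCache module."""

  def __init__(self, batch_size: int, max_seq_len: int, num_kv_heads: int,
               head_dim: int):
    super().__init__()
    shape = (batch_size, max_seq_len, num_kv_heads, head_dim)
    self.register_buffer("k_cache", torch.zeros(shape), persistent=False)
    self.register_buffer("v_cache", torch.zeros(shape), persistent=False)

  def update(self, input_pos: torch.Tensor, k_new: torch.Tensor,
             v_new: torch.Tensor):
    # Write the new K/V slices at the given positions (dim 1 is the sequence
    # axis) and return the full caches for attention over all positions.
    self.k_cache.index_copy_(1, input_pos, k_new)
    self.v_cache.index_copy_(1, input_pos, v_new)
    return self.k_cache, self.v_cache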

Model Configuration class

Currently, the library provides the following configuration classes for customizing the transformer model, with a nesting sketch after the list:

  • AttentionConfig
  • FeedForwardConfig
  • NormalizationConfig
  • ModelConfig
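
A ModelConfig typically aggregates the lower-level configs so a whole model can be described declaratively and handed to the building blocks above. The dataclass sketch below shows that nesting idea; every field name here is an assumption for illustration, so consult the library's config module for the real schema.

import dataclasses


@dataclasses.dataclass
class AttentionConfigSketch:
  # Assumed fields for illustration only.
  num_heads: int
  num_query_groups: int


@dataclasses.dataclass
class FeedForwardConfigSketch:
  intermediate_size: int
  gated: bool = True


@dataclasses.dataclass
class NormalizationConfigSketch:
  type: str = "rms_norm"
  epsilon: float = 1e-6


@dataclasses.dataclass
class ModelConfigSketch:
  vocab_size: int
  embedding_dim: int
  num_layers: int
  attn_config: AttentionConfigSketch
  ff_config: FeedForwardConfigSketch
  norm_config: NormalizationConfigSketch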

Normalization layer

normalization.py provides normalization modules not currently offered by PyTorch, such as RMSNorm; a reference sketch follows the list:

  • RMSNorm
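
RMSNorm normalizes by the root mean square of the activations and applies a learned scale, skipping the mean subtraction and bias of LayerNorm. The following is a minimal reference implementation of that formula, not the library's RMSNorm class:

import torch
from torch import nn


class RMSNormSketch(nn.Module):
  # y = x / sqrt(mean(x^2) + eps) * weight
  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
    return x * rms * self.weight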

RoPE Embedding

rotary_position_embedding.py contains helper functions for applying RoPE to tensors.
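
Given the precomputed cos/sin tables, applying RoPE amounts to rotating each pair of channels in the query and key projections by a position-dependent angle. The helper below is an illustrative sketch of that rotation under an interleaved channel pairing; the library's actual apply function may pair and lay out channels differently.

import torch


def apply_rope_sketch(x: torch.Tensor, cos: torch.Tensor,
                      sin: torch.Tensor) -> torch.Tensor:
  # x: [..., seq_len, head_dim]; cos/sin: [seq_len, head_dim // 2].
  # Rotate each (even, odd) channel pair by the position-dependent angle.
  x_even, x_odd = x[..., 0::2], x[..., 1::2]
  rotated_even = x_even * cos - x_odd * sin
  rotated_odd = x_even * sin + x_odd * cos
  return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)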

High-Level function boundary for performance

We introduce High-Level Function Boundary (HLFB) as a way of annotating performance-critical pieces of the model (e.g. scaled_dot_product_attention or KVCache). HLFB allows the converter to lower the annotated blocks to performant TFLite custom ops. The following is an example of applying HLFB to SDPA:

import math
from typing import Optional

import torch
import torch.nn.functional as F

# StableHLOCompositeBuilder ships with ai_edge_torch; the exact import path
# may vary between releases.
from ai_edge_torch.hlfb import StableHLOCompositeBuilder


def scaled_dot_product_attention_with_hlfb(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    head_size: int,
    mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
  """Scaled dot product attention with high-level function boundary enabled.

  Args:
    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
    head_size (int): head dimension.
    mask (torch.Tensor): the optional mask tensor.
    scale (float): the optional scale factor; defaults to 1 / sqrt(head_size).

  Returns:
    The output tensor of scaled_dot_product_attention.
  """
  if scale is None:
    scale = 1.0 / math.sqrt(head_size)

  # Mark the boundary of the composite op so the converter can lower the
  # annotated region to a performant TFLite custom op.
  builder = StableHLOCompositeBuilder(
      name="odml.scaled_dot_product_attention", attr={"scale": scale}
  )
  q, k, v, mask = builder.mark_inputs(q, k, v, mask)

  # Move the head dimension in front of the sequence dimension, as expected
  # by F.scaled_dot_product_attention.
  q = q.transpose(1, 2)
  k = k.transpose(1, 2)
  v = v.transpose(1, 2)
  if q.size() != k.size():
    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
  y = F.scaled_dot_product_attention(
      q,
      k,
      v,
      attn_mask=mask,
      dropout_p=0.0,
      is_causal=mask is None,
      scale=scale,
  )

  result = y.transpose(1, 2)
  result = builder.mark_outputs(result)
  return result
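
As a usage sketch, the function above can be called on plain tensors. The shapes follow the docstring's [B, T, N, H] layout, with fewer KV heads than query heads so the GQA branch is exercised; depending on the ai-edge-torch release, the HLFB markers may act as identity operations in eager mode and only become meaningful during conversion.

batch, seq_len, q_heads, kv_heads, head_dim = 1, 16, 8, 2, 64

q = torch.randn(batch, seq_len, q_heads, head_dim)
k = torch.randn(batch, seq_len, kv_heads, head_dim)
v = torch.randn(batch, seq_len, kv_heads, head_dim)

# Additive causal mask, broadcast across batch and heads.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

out = scaled_dot_product_attention_with_hlfb(q, k, v, head_size=head_dim, mask=mask)
print(out.shape)  # torch.Size([1, 16, 8, 64])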