-
Notifications
You must be signed in to change notification settings - Fork 512
/
Copy pathsplash_attention.py
361 lines (327 loc) · 11.8 KB
/
splash_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
import dataclasses
import functools
import json
from dataclasses import asdict
from typing import Any
import torch
import torch_xla.debug.profiler as xp
from torch.library import custom_op
from torch.utils._pytree import tree_flatten
from torch_xla.core.xla_builder import call_jax
from torch_xla.distributed.spmd import Mesh
from torch_xla.experimental.custom_kernel import requires_jax
@dataclasses.dataclass(eq=True, frozen=True)
class SplashAttentionConfig:
### Splash attention block sizes
# These can be tuned for specific hardware generations, and can be set up to
# the model's sequence length.
sa_block_q: int = 2048
sa_block_kv: int = 2048
sa_block_kv_compute: int = 2048
sa_block_q_dkv: int = 2048
sa_block_kv_dkv: int = 2048
sa_block_kv_dkv_compute: int = 2048
sa_block_q_dq: int = 2048
sa_block_kv_dq: int = 2048
sa_use_fused_bwd_kernel: bool = True
sa_q_layout: str = "HEAD_DIM_MINOR"
sa_k_layout: str = "HEAD_DIM_MINOR"
sa_v_layout: str = "HEAD_DIM_MINOR"
mesh: str | None = None
qkv_partition_spec: tuple[tuple[str] | str | None] = (
("data", "fsdp"),
None,
None,
None,
)
segment_ids_partition_spec: tuple[tuple[str] | str | None] = (
("data", "fsdp"),
None,
)
attentiontype_local_sliding: bool = False
slide_window_size: int | None = None
def to_json(self) -> str:
"""Serialize to JSON string"""
return json.dumps(asdict(self))
@classmethod
def from_json(cls, data: str) -> "SplashAttentionConfig":
"""Deserialize from JSON string"""
json_data = json.loads(data)
# Define a function to convert lists to tuples
def list_to_tuple(x):
if isinstance(x, list):
return tuple(list_to_tuple(item) for item in x)
return x
# Apply the conversion to all fields
converted_data = {k: list_to_tuple(v) for k, v in json_data.items()}
return SplashAttentionConfig(**converted_data)
@xp.trace_me("splash_attention_kernel_wrapper")
def splash_attention_jax_wrapper(
    query,
    key,
    value,
    decoder_segment_ids,
    causal: bool,
    config: SplashAttentionConfig,
    attn_logits_soft_cap,
):
    """Run the Pallas TPU splash attention kernel under ``shard_map``.

    Everything inside this function is JAX specific. The torch_xla mesh string
    and the partition specs stored in ``config`` are converted to their JAX
    forms, and the body follows the MaxText attention call from
    https://github.com/AI-Hypercomputer/maxtext/blob/d8ffb5c4fc65e6832976226a8053236c2fe3164e/MaxText/layers/attentions.py#L336-L430.

    Args:
      query, key, value: Inputs laid out as [batch, #head, seq_len, head_dim].
      decoder_segment_ids: Optional segment ids. ``None`` or a value with an
        empty shape means "no segment ids"; otherwise the same array is used
        for both the q and the kv side.
      causal: Selects a causal mask when true, a full mask otherwise.
      config: Block sizes, layouts, mesh and partition specs.
      attn_logits_soft_cap: Forwarded unchanged to ``make_splash_mha``.

    Returns:
      The attention output, shaped [batch, heads, seq_length, head_dim].

    Raises:
      ValueError: If local sliding attention is enabled without a window size.
      AssertionError: If the batch size is not divisible by the number of
        devices on the "data" and "fsdp" mesh axes, or if the per-shard
        sequence length does not match the segment ids.
    """
    import jax
    from jax.experimental import shard_map as shard_map_lib
    from jax.experimental.pallas.ops.tpu.splash_attention import (
        splash_attention_kernel as sa_kernel,
        splash_attention_mask as sa_mask,
    )

    jax_mesh = Mesh.from_str(config.mesh).get_jax_mesh()

    # input q,k,v shape: [batch, #head, seq_len, head_dim]
    if decoder_segment_ids is None or not decoder_segment_ids.shape:
        segment_ids = None
    else:
        segment_ids = sa_kernel.SegmentIds(decoder_segment_ids,
                                           decoder_segment_ids)

    qkv_spec = jax.sharding.PartitionSpec(*config.qkv_partition_spec)
    segment_spec = jax.sharding.PartitionSpec(
        *config.segment_ids_partition_spec)

    def wrap_flash_attention(q, k, v, seg_ids):
        # Shapes seen here are the per-shard shapes produced by shard_map.
        q_len = q.shape[2]
        kv_len = k.shape[2]
        if seg_ids is not None:
            assert (
                q_len == seg_ids.q.shape[1]
            ), "Sharding along sequence dimension not allowed in tpu kernel attention"

        fused_bwd = config.sa_use_fused_bwd_kernel
        if fused_bwd:
            # The dq block sizes are left unset for the fused backward kernel.
            block_q_dq = None
            block_kv_dq = None
        else:
            block_q_dq = min(config.sa_block_q_dq, q_len)
            block_kv_dq = min(config.sa_block_kv_dq, q_len)

        # Every block size is clamped so it never exceeds the sequence length.
        # NOTE(review): block_kv_dkv_compute and block_kv_dq are clamped with
        # the query length rather than the kv length; this matches the
        # referenced MaxText code, which also builds a square
        # (seq_len, seq_len) mask -- confirm before using q_len != kv_len.
        block_sizes = sa_kernel.BlockSizes(
            block_q=min(config.sa_block_q, q_len),
            block_kv=min(config.sa_block_kv, kv_len),
            block_kv_compute=min(config.sa_block_kv_compute, kv_len),
            block_q_dkv=min(config.sa_block_q_dkv, q_len),
            block_kv_dkv=min(config.sa_block_kv_dkv, kv_len),
            block_kv_dkv_compute=min(config.sa_block_kv_dkv_compute, q_len),
            block_q_dq=block_q_dq,
            block_kv_dq=block_kv_dq,
            use_fused_bwd_kernel=fused_bwd,
            q_layout=sa_kernel.QKVLayout[config.sa_q_layout],
            k_layout=sa_kernel.QKVLayout[config.sa_k_layout],
            v_layout=sa_kernel.QKVLayout[config.sa_v_layout],
        )

        mask_shape = (q_len, q_len)
        if causal:
            mask = sa_mask.CausalMask(shape=mask_shape)
        else:
            mask = sa_mask.FullMask(_shape=mask_shape)

        # Apply local masking if local sliding attention is enabled.
        if config.attentiontype_local_sliding:
            window = config.slide_window_size
            if window is None:
                raise ValueError(
                    "Sliding_window_size must be set if Local Sliding attention type")
            mask &= sa_mask.LocalMask(
                shape=mask_shape,
                window_size=(window, window),
                offset=0,
            )

        # Create multi-head mask: the same mask is repeated for every head.
        num_heads = q.shape[1]
        splash_kernel = sa_kernel.make_splash_mha(
            mask=sa_mask.MultiHeadMask(masks=(mask,) * num_heads),
            head_shards=1,
            q_seq_shards=1,
            block_sizes=block_sizes,
            attn_logits_soft_cap=attn_logits_soft_cap,
        )
        # vmap maps the kernel over the leading (batch) dimension.
        return jax.vmap(splash_kernel)(q, k, v, segment_ids=seg_ids)

    sharded_attention = shard_map_lib.shard_map(
        wrap_flash_attention,
        mesh=jax_mesh,
        in_specs=(qkv_spec, qkv_spec, qkv_spec, segment_spec),
        out_specs=qkv_spec,
        check_rep=False,
    )

    devices_in_data_fsdp = jax_mesh.shape["data"] * jax_mesh.shape["fsdp"]
    assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
        "Batch dimension should be shardable among the devices in data and fsdp axis"
    )

    # Output shape: [batch, heads, seq_length, head_dim]
    return sharded_attention(query, key, value, segment_ids)
@requires_jax
def _jax_grad_f(query, key, value, decoder_segment_ids, causal, config,
                attn_logits_soft_cap, grad_output):
    """Compute the gradients of splash attention w.r.t. query, key and value.

    The forward function is differentiated with ``jax.vjp`` with respect to
    ``query``, ``key`` and ``value`` only; every other argument is held fixed.

    Args:
      query, key, value: Primal inputs of the forward pass.
      decoder_segment_ids, causal, config, attn_logits_soft_cap: Passed through
        unchanged to ``splash_attention_jax_wrapper``.
      grad_output: Cotangent of the attention output.

    Returns:
      The result of the VJP function applied to ``grad_output`` (one cotangent
      per differentiated input).
    """
    import jax

    def attention_of_qkv(q, k, v):
        # Only q/k/v are positional, so only they are differentiated.
        return splash_attention_jax_wrapper(
            q,
            k,
            v,
            decoder_segment_ids=decoder_segment_ids,
            causal=causal,
            config=config,
            attn_logits_soft_cap=attn_logits_soft_cap,
        )

    # The primal output is not needed here, only the pullback.
    _, pullback = jax.vjp(attention_of_qkv, query, key, value)
    return pullback(grad_output)
@xp.trace_me("tpu_splash_attention_jax_call_wrapper")
def tpu_splash_attention_jax_call_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    config: SplashAttentionConfig,
    decoder_segment_ids: torch.Tensor | None,
    causal: bool,
    attn_logits_soft_cap: float | None = None,
    is_forward: bool = True,
    grad_output: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Dispatch the forward or backward splash attention computation to JAX.

    Args:
      query, key, value: Attention inputs; made contiguous before the call.
      config: Parsed kernel configuration.
      decoder_segment_ids: Optional segment ids, forwarded as-is.
      causal: Forwarded as-is to the JAX function.
      attn_logits_soft_cap: Forwarded as-is to the JAX function.
      is_forward: Run the forward kernel when true, the VJP otherwise.
      grad_output: Gradient of the attention output; required when
        ``is_forward`` is false.

    Returns:
      Always a 3-tuple so forward and backward share one output arity:
      ``(output, None, None)`` for the forward pass and
      ``(q_grad, k_grad, v_grad)`` for the backward pass.

    Raises:
      ValueError: If the backward pass is requested without ``grad_output``.
    """
    # Fail fast with a clear message instead of handing None to JAX as the
    # cotangent.
    if not is_forward and grad_output is None:
        raise ValueError("grad_output must be provided when is_forward=False")

    jax_args = [
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        decoder_segment_ids,
        causal,
        config,
        attn_logits_soft_cap,
    ]

    if is_forward:
        output = call_jax(splash_attention_jax_wrapper, jax_args, {},
                          "splash_attention_jax_wrapper_fw")
        # return tuple to fit for the output num for both fwd and bwd
        return (output, None, None)

    # TODO: find out a way to skip grad computation for decoder_segment_ids
    q_grad, k_grad, v_grad, *_ = call_jax(
        _jax_grad_f,
        jax_args + [grad_output],
        {},
        "splash_attention_jax_wrapper_bw",
    )
    return (q_grad, k_grad, v_grad)
@custom_op("xla::sa_custom_forward", mutates_args=())
def sa_custom_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    config: str,
    decoder_segment_ids: torch.Tensor | None,
    causal: bool | None,
    attn_logits_soft_cap: float | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom op running the splash attention forward pass.

    Args:
      q, k, v: Attention inputs.
      config: JSON string, decoded with ``SplashAttentionConfig.from_json``.
      decoder_segment_ids: Optional segment ids.
      causal: Forwarded to the kernel wrapper.
      attn_logits_soft_cap: Forwarded to the kernel wrapper.

    Returns:
      The 3-tuple produced by ``tpu_splash_attention_jax_call_wrapper`` in
      forward mode; only the first element carries the attention output.
    """
    parsed_config = SplashAttentionConfig.from_json(config)
    return tpu_splash_attention_jax_call_wrapper(
        query=q,
        key=k,
        value=v,
        config=parsed_config,
        decoder_segment_ids=decoder_segment_ids,
        causal=causal,
        attn_logits_soft_cap=attn_logits_soft_cap,
        is_forward=True,
        grad_output=None,
    )
@sa_custom_forward.register_fake
def sa_custom_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    config: str,
    decoder_segment_ids: torch.Tensor | None,
    causal: bool | None,
    attn_logits_soft_cap: float | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake implementation of ``sa_custom_forward``.

    Returns an uninitialized tensor with the shape and dtype of ``q`` in the
    first slot and ``None`` in the other two, mirroring the forward op's
    3-tuple result.
    """
    output_like_q = torch.empty_like(q)
    return (output_like_q, None, None)
@custom_op("xla::sa_custom_backward", mutates_args=())
def sa_custom_backward(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    config: str,
    decoder_segment_ids: torch.Tensor | None,
    causal: bool | None,
    attn_logits_soft_cap: float | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom op running the splash attention backward pass.

    Args:
      grad_output: Gradient of the attention output.
      q, k, v: Inputs of the forward pass.
      config: JSON string, decoded with ``SplashAttentionConfig.from_json``.
      decoder_segment_ids: Optional segment ids.
      causal: Forwarded to the kernel wrapper.
      attn_logits_soft_cap: Forwarded to the kernel wrapper.

    Returns:
      ``(q_grad, k_grad, v_grad)`` as produced by
      ``tpu_splash_attention_jax_call_wrapper`` in backward mode.
    """
    parsed_config = SplashAttentionConfig.from_json(config)
    return tpu_splash_attention_jax_call_wrapper(
        query=q,
        key=k,
        value=v,
        config=parsed_config,
        decoder_segment_ids=decoder_segment_ids,
        causal=causal,
        attn_logits_soft_cap=attn_logits_soft_cap,
        is_forward=False,
        grad_output=grad_output,
    )
@sa_custom_backward.register_fake
def sa_custom_backward_fake(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    config: str,
    decoder_segment_ids: torch.Tensor | None,
    causal: bool | None,
    attn_logits_soft_cap: float | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake implementation of ``sa_custom_backward``.

    Each gradient has the shape and dtype of the input it belongs to, so the
    result is three uninitialized tensors shaped like ``q``, ``k`` and ``v``.
    """
    grad_q_like = torch.empty_like(q)
    grad_k_like = torch.empty_like(k)
    grad_v_like = torch.empty_like(v)
    return (grad_q_like, grad_k_like, grad_v_like)
class SplashAttention(torch.autograd.Function):
    """Autograd function tying the forward and backward custom ops together."""

    @staticmethod
    @requires_jax
    def forward(ctx, q, k, v, config, decoder_segment_ids, causal,
                attn_logits_soft_cap):
        """Run the forward custom op and stash what backward needs.

        Args:
          ctx: Autograd context.
          q, k, v: Attention inputs.
          config: JSON string form of the splash attention config.
          decoder_segment_ids: Optional segment id tensor.
          causal: Forwarded to the custom op.
          attn_logits_soft_cap: Optional float, forwarded to the custom op.

        Returns:
          The attention output (first element of the op's 3-tuple).
        """
        output = sa_custom_forward(q, k, v, config, decoder_segment_ids,
                                   causal, attn_logits_soft_cap)[0]
        # save_for_backward accepts only tensors (or None). The soft cap is a
        # Python float, so it is kept as a plain ctx attribute like the other
        # non-tensor arguments; passing it to save_for_backward raises
        # TypeError whenever it is not None.
        ctx.save_for_backward(q, k, v, decoder_segment_ids)
        ctx.config = config
        ctx.causal = causal
        ctx.attn_logits_soft_cap = attn_logits_soft_cap
        return output

    @staticmethod
    @requires_jax
    def backward(ctx, grad_output):
        """Run the backward custom op.

        Returns:
          One entry per ``forward`` argument: gradients for q, k and v,
          followed by ``None`` for the four non-differentiable arguments.
        """
        q, k, v, decoder_segment_ids = ctx.saved_tensors
        grad_q, grad_k, grad_v = sa_custom_backward(
            grad_output,
            q,
            k,
            v,
            ctx.config,
            decoder_segment_ids,
            ctx.causal,
            ctx.attn_logits_soft_cap,
        )
        return grad_q, grad_k, grad_v, None, None, None, None
def splash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    config: str,
    decoder_segment_ids: torch.Tensor | None = None,
    causal: bool = True,
    attn_logits_soft_cap: float | None = None,
) -> torch.Tensor:
    """Splash attention function.

    Differentiable entry point; it applies the ``SplashAttention`` autograd
    function to the given arguments.

    Args:
      q: Query tensor. The kernel wrapper in this file documents q/k/v as
        [batch, #head, seq_len, head_dim].
      k: Key tensor.
      v: Value tensor.
      config: JSON string describing the kernel configuration, as produced by
        ``SplashAttentionConfig.to_json``.
      decoder_segment_ids: Optional segment id tensor; the same ids are used
        for the query and the key/value side. A segment id mask is computed
        such that only tokens that have the same segment id can attend to
        each other. This creates a block-sparse pattern along the main
        diagonal.
      causal: Use a causal mask when true, a full mask otherwise.
      attn_logits_soft_cap: The soft clipping value for logits pre softmax.

    Returns:
      The attention output tensor.
    """
    attention_output = SplashAttention.apply(
        q,
        k,
        v,
        config,
        decoder_segment_ids,
        causal,
        attn_logits_soft_cap,
    )
    return attention_output