Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
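To make the ~4x figure concrete, here is a quick back-of-the-envelope check (an illustration, not part of the tutorial's code; M and N are arbitrary example sizes):

# Memory traffic of the naive vs. fused softmax, counted in elements.
M, N = 4096, 4096
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes
fused_traffic = 2 * M * N  # read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.0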
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally "pad" each row and guard the memory operations properly if we want to handle arbitrary input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than or equal to n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
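To see why padding with -inf is safe, here is a small PyTorch emulation of one padded row (an illustration only; n_cols = 3 and BLOCK_SIZE = 4 are made-up numbers):

# One row with n_cols = 3, padded up to BLOCK_SIZE = 4 with -inf.
row = torch.tensor([1.0, 2.0, 3.0, float('-inf')])
z = row - row.max()         # the padded lane never wins the max
num = torch.exp(z)          # exp(-inf) == 0, so padding adds nothing to the sum
print(num[:3] / num.sum())  # matches torch.softmax(row[:3], dim=0)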
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # Pre-compile the kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of general-purpose registers. On CDNA architectures this is half of
        # all registers available; however, this is not always the case: in most cases all registers can be
        # used as general-purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # When we divide this number by WARP_SIZE we get the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
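As a concrete illustration of the CUDA-side occupancy arithmetic above (the numbers below are hypothetical and will differ on your GPU and kernel):

# Hypothetical values: 64K registers per SM, 32-thread warps,
# 32 registers per thread, 8 warps per program.
NUM_REGS_EX, WARP_SIZE_EX, n_regs_ex, num_warps_ex = 65536, 32, 32, 8
occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)
print(occupancy_ex)  # 8 programs per SM, before the shared-memory cap is applied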
Unit Test
We test our kernel on a matrix with an irregular number of rows and columns, which lets us verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
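If you want more confidence in the padding logic, a quick sweep over other irregular shapes is cheap (an optional extra check, not part of the original tutorial):

# Optional: exercise a few more irregular shapes, including single-element rows.
for shape in [(1, 1), (7, 33), (1023, 1025)]:
    x = torch.randn(*shape, device=DEVICE)
    torch.testing.assert_close(softmax(x), torch.softmax(x, axis=1))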
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against torch.softmax (see the end of this section for a sketch that also benchmarks the naive_softmax defined above).
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    # GB/s: softmax reads MN elements and writes MN elements, i.e. 2 * M * N * element_size bytes in total.
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch
0 256.0 469.697136 688.319248
1 384.0 628.812511 826.644799
2 512.0 804.595591 933.205860
3 640.0 809.771939 963.139548
4 768.0 889.579084 1031.323887
5 896.0 944.537557 1074.134385
6 1024.0 1004.870579 1122.966470
7 1152.0 1101.538865 1034.879969
8 1280.0 1148.411183 1077.517402
9 1408.0 1161.396480 1100.572410
10 1536.0 1189.490197 1135.772700
11 1664.0 1216.264218 1160.863384
12 1792.0 1228.121168 1197.499376
13 1920.0 1245.861816 1196.804317
14 2048.0 1279.838097 1226.110890
15 2176.0 1242.860693 956.098392
16 2304.0 1242.118111 1000.915071
17 2432.0 1275.613185 1034.135528
18 2560.0 1287.040811 1066.653997
19 2688.0 1289.996590 1099.383944
20 2816.0 1303.440625 1121.102949
21 2944.0 1312.063477 1143.861414
22 3072.0 1329.883776 1170.285724
23 3200.0 1328.549284 1171.700235
24 3328.0 1338.101650 1201.544294
25 3456.0 1349.523849 1218.866268
26 3584.0 1349.112664 1245.453025
27 3712.0 1359.664612 1262.972284
28 3840.0 1369.848872 1288.029302
29 3968.0 1369.903427 1294.958319
30 4096.0 1375.766213 1316.342264
31 4224.0 1334.696732 1289.585144
32 4352.0 1336.197202 1317.966872
33 4480.0 1356.622263 1334.300695
34 4608.0 1363.732193 1353.878101
35 4736.0 1361.486021 1366.682155
36 4864.0 1372.661050 1381.218673
37 4992.0 1373.830479 1394.761234
38 5120.0 1373.513819 1404.539783
39 5248.0 1377.568032 1362.510288
40 5376.0 1377.975309 1384.241684
41 5504.0 1378.305178 1392.962110
42 5632.0 1384.524974 1410.763241
43 5760.0 1394.514086 1422.415574
44 5888.0 1389.353069 1430.205356
45 6016.0 1402.396835 1434.333874
46 6144.0 1406.177543 1441.079219
47 6272.0 1415.074474 1405.632928
48 6400.0 1418.687998 1421.571952
49 6528.0 1411.890275 1435.847123
50 6656.0 1424.944828 1442.628937
51 6784.0 1410.322918 1441.879852
52 6912.0 1424.585726 1450.853843
53 7040.0 1420.959905 1465.340695
54 7168.0 1426.473500 1468.992544
55 7296.0 1430.081395 1084.499392
56 7424.0 1428.341106 1099.225235
57 7552.0 1423.703201 1112.998444
58 7680.0 1432.499378 1126.336971
59 7808.0 1432.974643 1134.395507
60 7936.0 1439.432487 1144.513782
61 8064.0 1433.683263 1151.117364
62 8192.0 1435.340739 1154.389079
63 8320.0 1392.287298 1112.348747
64 8448.0 1383.807745 1124.001226
65 8576.0 1401.545571 1125.143858
66 8704.0 1391.062780 1129.837110
67 8832.0 1385.717057 1130.808579
68 8960.0 1404.307036 1136.186474
69 9088.0 1415.958051 1132.773515
70 9216.0 1409.812062 1129.336089
71 9344.0 1405.168869 1423.485398
72 9472.0 1404.010263 1429.460112
73 9600.0 1398.362765 1429.409929
74 9728.0 1405.574437 1441.049044
75 9856.0 1419.891808 1440.501682
76 9984.0 1403.158652 1448.278508
77 10112.0 1417.004838 1452.231094
78 10240.0 1424.144140 1461.975164
79 10368.0 1418.775178 1463.827823
80 10496.0 1424.094031 1464.213243
81 10624.0 1417.327356 1466.698451
82 10752.0 1411.369908 1467.447355
83 10880.0 1406.965475 1483.324515
84 11008.0 1425.727083 1474.598888
85 11136.0 1428.216019 1485.369408
86 11264.0 1431.153553 1486.886098
87 11392.0 1419.146323 1490.577739
88 11520.0 1428.036088 1493.336872
89 11648.0 1426.728588 1498.774329
90 11776.0 1436.659873 1503.285958
91 11904.0 1445.867536 1508.460496
92 12032.0 1429.844757 1513.721323
93 12160.0 1424.588961 1513.084016
94 12288.0 1442.863905 1425.624490
95 12416.0 1456.343198 1395.479358
96 12544.0 1444.649256 1397.384016
97 12672.0 1456.040882 1394.542284
In the above plot, we can see that:

- A fused kernel avoids the extra memory traffic of the naive implementation, which our estimate above suggests is worth about 4x; torch.jit.script does not perform this kind of fusion automatically.
- Triton is competitive with torch.softmax, and noticeably faster for many column counts, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
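As mentioned above, the naive_softmax baseline can also be added as a third line in the report. The sketch below shows one way to do it (benchmark_with_naive is a name introduced here, and its output is not included in the table above):

# Sketch: extend the benchmark with the naive_softmax baseline defined earlier.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[128 * i for i in range(2, 100)],
        line_arg='provider',
        line_vals=['triton', 'torch', 'naive'],
        line_names=["Triton", "Torch", "Naive"],
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],
        ylabel="GB/s",
        plot_name="softmax-performance-with-naive",
        args={'M': 4096},
    ))
def benchmark_with_naive(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    # Same bandwidth model as above: one read plus one write of the tensor.
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark_with_naive.run(show_plots=True, print_data=True)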