# benchmark_uintx.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy

import torch

from torchao.prototype.uintx import (
    uintx_affine_weight_only,
    unpack_cpu,
)
from torchao.quantization.quant_api import quantize_


class Linear16(torch.nn.Module):
    """Small fp16 MLP used as the full-precision baseline for the uintx benchmarks."""

    def __init__(self, scale):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(scale * 2, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale // 2, bias=True, dtype=torch.float16).cuda(),
        )

    def forward(self, x):
        return self.net(x)


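# A minimal usage sketch (not part of the original benchmark, and it assumes a
# CUDA device is available): quantize a Linear16 instance to 4-bit weights with
# the same torchao API exercised in uintx_vs_fp16() below. The scale and input
# shape here are arbitrary illustrative choices.
def _example_quantize_linear16():
    model = Linear16(256)
    quantize_(model, uintx_affine_weight_only(4))  # swap fp16 weights for 4-bit uintx
    x = torch.randn(512, dtype=torch.float16, device="cuda")  # input dim == scale * 2
    return model(x)

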
def benchmark(function, args, num_runs):
    """Return the average wall-clock time per call in milliseconds, measured
    with CUDA events after a 100-iteration warmup."""
    # warmup
    torch._dynamo.reset()
    for _ in range(100):
        function(*args)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_runs):
        function(*args)
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs


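# Hedged standalone example of the benchmark() helper above (assumptions: CUDA
# is available, and the 1024x1024 shapes are illustrative, not from the
# original script): timing a compiled fp16 matmul.
def _example_time_matmul():
    a = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    b = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    mm = torch.compile(torch.matmul, fullgraph=True)
    ms = benchmark(mm, [a, b], num_runs=50)
    print(f"compiled fp16 matmul: {ms:.3f} ms/iter")

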
def profile_bitpack():
    from torch.profiler import ProfilerActivity, profile

    fake_tensor = [torch.randint(2**8, (512, 512), dtype=torch.uint8).cuda()]
    func = torch.compile(unpack_cpu, fullgraph=True)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for _ in range(1000):
            func(fake_tensor, 4)

    # Print a summary and export a Chrome trace
    with open("profile-bitpack.txt", "a") as f:
        print(f"{func}", file=f)
        print(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f
        )
    prof.export_chrome_trace("trace.json")
"""
CPU perf:
unpack_gpu
Self CPU time total: 602.501ms
unpack_cpu
Self CPU time total: 415.469ms
GPU perf:
unpack_gpu on gpu:
Self CPU time total: 58.512ms
Self CUDA time total: 5.083ms
unpack_cpu:
Self CPU time total: 96.947ms
Self CUDA time total: 5.253ms
"""
def uintx_vs_fp16(nbits=[1, 2, 3, 4, 5, 6, 7], scales=[256, 512, 1024], repeats=30):
    results = []
    # sort copies rather than sorting in place, so the shared default lists
    # are not mutated across calls
    nbits = sorted(nbits)
    scales = sorted(scales)
    for scale in scales:
        test_input = torch.randn(scale * 2, dtype=torch.float16).cuda()
        forward_args = [test_input]
        times = [scale]

        # fp16 baseline
        fp16 = Linear16(scale)
        fp16c = torch.compile(fp16, fullgraph=True)
        fp16_time = benchmark(fp16c.forward, forward_args, repeats)
        times.append(fp16_time)

        # uintx variants, one per bit width
        for bit_size in nbits:
            m = deepcopy(fp16)
            quantize_(m, uintx_affine_weight_only(bit_size))
            m = torch.compile(m, fullgraph=True)
            uintx_time = benchmark(m.forward, forward_args, repeats)
            times.append(uintx_time)
        print(f"scale={scale} done")

        results.append(times)

    print("----------- benchmark results -----------")
    for result in results:
        print(f"scale: {result[0]} fp16 time: {result[1]:.2f}ms speedups:")
        for i in range(2, len(result)):
            print(f"uint{nbits[i - 2]}: {result[1] / result[i]:.2f}x")


if __name__ == "__main__":
    uintx_vs_fp16(nbits=[4, 7])
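    # Optional extras (not in the original script; the full sweep is slower).
    # Uncomment to sweep every supported bit width, or to profile the unpack kernel:
    # uintx_vs_fp16(nbits=[1, 2, 3, 4, 5, 6, 7], scales=[256, 512, 1024], repeats=30)
    # profile_bitpack()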