# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy

import torch

from torchao.prototype.uintx import (
    uintx_affine_weight_only,
    unpack_cpu,
)
from torchao.quantization.quant_api import quantize_


class Linear16(torch.nn.Module):
    """Three-layer fp16 MLP placed on CUDA, used as the benchmark baseline.

    Layer widths are ``scale * 2 -> scale -> scale -> scale // 2``; every
    layer has a bias and is created in ``torch.float16``.
    """

    def __init__(self, scale):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(scale * 2, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale // 2, bias=True, dtype=torch.float16).cuda(),
        )

    def forward(self, x):
        return self.net(x)


def benchmark(function, args, num_runs, warmup_runs=100):
    """Return the average per-call time of ``function(*args)`` in milliseconds.

    Args:
        function: Callable to time.
        args: Positional arguments passed to ``function`` on every call.
        num_runs: Number of timed calls; the result is averaged over these.
        warmup_runs: Untimed calls made before timing starts, so that
            compilation and other first-call costs are excluded.

    Returns:
        Mean milliseconds per call, measured between two CUDA timing events.
    """
    # Reset Dynamo so each benchmarked callable starts from a clean compile
    # state; the warmup loop below absorbs any recompilation.
    torch._dynamo.reset()
    for _ in range(warmup_runs):
        function(*args)
    # Drain queued warmup work so it does not leak into the timed region.
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_runs):
        function(*args)
    end_event.record()
    # Both events must have completed before elapsed_time can be queried.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs


def profile_bitpack():
    """Profile the ``torch.compile``-d ``unpack_cpu`` function.

    Runs the compiled function 1000 times under ``torch.profiler`` with CPU and
    CUDA activities enabled. It appends the compiled function's repr, followed
    by the 10 ops with the highest total CUDA time, to ``profile-bitpack.txt``.
    It also writes a Chrome trace to ``trace.json``.

    Previously recorded results:

    CPU perf:
        unpack_gpu
            Self CPU time total: 602.501ms
        unpack_cpu
            Self CPU time total: 415.469ms
    GPU perf:
        unpack_gpu on gpu:
            Self CPU time total: 58.512ms
            Self CUDA time total: 5.083ms
        unpack_cpu:
            Self CPU time total: 96.947ms
            Self CUDA time total: 5.253ms
    """
    from torch.profiler import ProfilerActivity, profile

    # NOTE(review): unpack_cpu is called with a one-element list of uint8
    # tensors and the integer 4 -- presumably packed data and a bit width;
    # confirm against unpack_cpu's signature.
    fake_tensor = [torch.randint(2**8, (512, 512), dtype=torch.uint8).cuda()]
    func = torch.compile(unpack_cpu, fullgraph=True)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for _ in range(1000):
            func(fake_tensor, 4)

    # Print a summary
    with open("profile-bitpack.txt", "a") as f:
        print(f"{func}", file=f)
        print(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f
        )
    prof.export_chrome_trace("trace.json")


def uintx_vs_fp16(
    nbits=(1, 2, 3, 4, 5, 6, 7), scales=(256, 512, 1024), repeats=30
):
    """Benchmark uintx weight-only quantization against an fp16 baseline.

    For every ``scale`` (in ascending order), the function builds and compiles
    a :class:`Linear16` model and times its forward pass. For each requested
    bit width (in ascending order), it then quantizes a deep copy with
    ``uintx_affine_weight_only(bit_size)``, compiles it, and times it. Speedups
    relative to fp16 are printed after all scales have run.

    Args:
        nbits: Bit widths to benchmark. Any iterable of ints; it is not
            modified (a sorted copy is used).
        scales: Model sizes passed to :class:`Linear16`. Any iterable of ints;
            it is not modified (a sorted copy is used).
        repeats: Number of timed forward calls per configuration.
    """
    # Sort copies rather than calling ``.sort()`` on the arguments: sorting in
    # place mutated the caller's lists (and the shared default lists).
    nbits = sorted(nbits)
    scales = sorted(scales)

    # Each entry is (scale, fp16_time_ms, [uintx_time_ms per bit width]).
    results = []
    for scale in scales:
        test_input = torch.randn(scale * 2, dtype=torch.float16).cuda()
        forward_args = [test_input]

        fp16 = Linear16(scale)
        fp16c = torch.compile(fp16, fullgraph=True)
        fp16_time = benchmark(fp16c.forward, forward_args, repeats)

        uintx_times = []
        for bit_size in nbits:
            # quantize_ modifies its module in place, so quantize a copy and
            # keep the fp16 baseline intact for the next bit width.
            m = deepcopy(fp16)
            quantize_(m, uintx_affine_weight_only(bit_size))
            m = torch.compile(m, fullgraph=True)
            uintx_times.append(benchmark(m.forward, forward_args, repeats))

        print(f"scale={scale} done")
        results.append((scale, fp16_time, uintx_times))

    print("----------- benchmark results -----------")
    for scale, fp16_time, uintx_times in results:
        print(f"scale: {scale} fp16 time:{fp16_time: .2f}ms speedups:")
        for bit_size, uintx_time in zip(nbits, uintx_times):
            print(f"int{bit_size}: {fp16_time / uintx_time: .2f}x")


if __name__ == "__main__":
    uintx_vs_fp16(nbits=[4, 7])