Skip to content

Commit 9943c89

Browse files
committed
autotune fp8 matmul reduce scatter in unit test
1 parent ee9724e commit 9943c89

1 file changed

Lines changed: 45 additions & 12 deletions

File tree

test/test_distributed.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
from datetime import timedelta
5+
import functools
56
import os
67
import unittest
78

@@ -23,6 +24,7 @@
2324
from helion._dist_utils import sync_seed
2425
from helion._testing import EXAMPLES_DIR
2526
from helion._testing import TestCase
27+
from helion._testing import assert_close_with_mismatch_tolerance
2628
from helion._testing import import_path
2729
from helion._testing import onlyBackends
2830
from helion._testing import skipIfRocm
@@ -429,14 +431,44 @@ def do_test_matmul_reduce_scatter(self, kernel, ref_kernel):
429431
@skipIfRocm("Distributed example requires CUDA/NCCL")
430432
@skipIfXPU("Distributed operations require CCL, not yet fully integrated")
431433
@skip_if_lt_x_gpu(4)
432-
def test_fp8_matmul_reduce_scatter(self):
434+
@parametrize("autotuner", ["fixed", "LFBOTreeSearch"])
435+
def test_fp8_matmul_reduce_scatter(self, autotuner):
433436
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
434437
self.skipTest("FP8 requires CUDA compute capability >= 9.0")
435438
self._init_process()
436439

437440
mod = import_path(EXAMPLES_DIR / "distributed" / "fp8_matmul_reduce_scatter.py")
438441

442+
kernel = mod.fp8_matmul_reduce_scatter_kernel.fn
439443
_SymmetricMemory.signal_pad_size = 1024 * 1024 * 16
444+
445+
accuracy_check_fn = functools.partial(
446+
assert_close_with_mismatch_tolerance, **mod.tolerance
447+
)
448+
449+
if autotuner == "fixed":
450+
kernel = helion.kernel(
451+
config=helion.Config(
452+
block_sizes=[64, 64, 32],
453+
num_warps=8,
454+
num_stages=3,
455+
),
456+
static_shapes=True,
457+
ignore_warnings=[helion.exc.TensorOperationInWrapper],
458+
autotune_baseline_accuracy_check_fn=accuracy_check_fn,
459+
)(kernel)
460+
context = contextlib.nullcontext()
461+
else:
462+
kernel = helion.kernel(
463+
kernel,
464+
static_shapes=True,
465+
ignore_warnings=[helion.exc.TensorOperationInWrapper],
466+
autotune_baseline_accuracy_check_fn=accuracy_check_fn,
467+
)
468+
context = unittest.mock.patch.dict(
469+
os.environ, {"HELION_AUTOTUNER": autotuner}
470+
)
471+
440472
M, N, K = 512, 768, 1024
441473

442474
torch.manual_seed(42 + self.rank)
@@ -457,17 +489,18 @@ def test_fp8_matmul_reduce_scatter(self):
457489
symm_mem_buffer = symm_mem.empty(M, N, dtype=torch.bfloat16, device=self.device)
458490
symm_mem_hdl = symm_mem.rendezvous(symm_mem_buffer, dist.group.WORLD.group_name)
459491

460-
result = mod.fp8_matmul_reduce_scatter_kernel(
461-
a,
462-
b,
463-
scale_a,
464-
scale_b,
465-
symm_mem_buffer,
466-
symm_mem_hdl.signal_pad_ptrs_dev,
467-
RANK=symm_mem_hdl.rank,
468-
WORLD_SIZE=symm_mem_hdl.world_size,
469-
GROUP_NAME=dist.group.WORLD.group_name,
470-
)
492+
with context:
493+
result = kernel(
494+
a,
495+
b,
496+
scale_a,
497+
scale_b,
498+
symm_mem_buffer,
499+
symm_mem_hdl.signal_pad_ptrs_dev,
500+
RANK=symm_mem_hdl.rank,
501+
WORLD_SIZE=symm_mem_hdl.world_size,
502+
GROUP_NAME=dist.group.WORLD.group_name,
503+
)
471504

472505
expected = mod.reference_fp8_matmul_reduce_scatter(a, b, scale_a, scale_b)
473506

0 commit comments

Comments
 (0)