22
33import contextlib
44from datetime import timedelta
5+ import functools
56import os
67import unittest
78
2324from helion ._dist_utils import sync_seed
2425from helion ._testing import EXAMPLES_DIR
2526from helion ._testing import TestCase
27+ from helion ._testing import assert_close_with_mismatch_tolerance
2628from helion ._testing import import_path
2729from helion ._testing import onlyBackends
2830from helion ._testing import skipIfRocm
@@ -429,14 +431,44 @@ def do_test_matmul_reduce_scatter(self, kernel, ref_kernel):
429431 @skipIfRocm ("Distributed example requires CUDA/NCCL" )
430432 @skipIfXPU ("Distributed operations require CCL, not yet fully integrated" )
431433 @skip_if_lt_x_gpu (4 )
432- def test_fp8_matmul_reduce_scatter (self ):
434+ @parametrize ("autotuner" , ["fixed" , "LFBOTreeSearch" ])
435+ def test_fp8_matmul_reduce_scatter (self , autotuner ):
433436 if not torch .cuda .is_available () or torch .cuda .get_device_capability ()[0 ] < 9 :
434437 self .skipTest ("FP8 requires CUDA compute capability >= 9.0" )
435438 self ._init_process ()
436439
437440 mod = import_path (EXAMPLES_DIR / "distributed" / "fp8_matmul_reduce_scatter.py" )
438441
442+ kernel = mod .fp8_matmul_reduce_scatter_kernel .fn
439443 _SymmetricMemory .signal_pad_size = 1024 * 1024 * 16
444+
445+ accuracy_check_fn = functools .partial (
446+ assert_close_with_mismatch_tolerance , ** mod .tolerance
447+ )
448+
449+ if autotuner == "fixed" :
450+ kernel = helion .kernel (
451+ config = helion .Config (
452+ block_sizes = [64 , 64 , 32 ],
453+ num_warps = 8 ,
454+ num_stages = 3 ,
455+ ),
456+ static_shapes = True ,
457+ ignore_warnings = [helion .exc .TensorOperationInWrapper ],
458+ autotune_baseline_accuracy_check_fn = accuracy_check_fn ,
459+ )(kernel )
460+ context = contextlib .nullcontext ()
461+ else :
462+ kernel = helion .kernel (
463+ kernel ,
464+ static_shapes = True ,
465+ ignore_warnings = [helion .exc .TensorOperationInWrapper ],
466+ autotune_baseline_accuracy_check_fn = accuracy_check_fn ,
467+ )
468+ context = unittest .mock .patch .dict (
469+ os .environ , {"HELION_AUTOTUNER" : autotuner }
470+ )
471+
440472 M , N , K = 512 , 768 , 1024
441473
442474 torch .manual_seed (42 + self .rank )
@@ -457,17 +489,18 @@ def test_fp8_matmul_reduce_scatter(self):
457489 symm_mem_buffer = symm_mem .empty (M , N , dtype = torch .bfloat16 , device = self .device )
458490 symm_mem_hdl = symm_mem .rendezvous (symm_mem_buffer , dist .group .WORLD .group_name )
459491
460- result = mod .fp8_matmul_reduce_scatter_kernel (
461- a ,
462- b ,
463- scale_a ,
464- scale_b ,
465- symm_mem_buffer ,
466- symm_mem_hdl .signal_pad_ptrs_dev ,
467- RANK = symm_mem_hdl .rank ,
468- WORLD_SIZE = symm_mem_hdl .world_size ,
469- GROUP_NAME = dist .group .WORLD .group_name ,
470- )
492+ with context :
493+ result = kernel (
494+ a ,
495+ b ,
496+ scale_a ,
497+ scale_b ,
498+ symm_mem_buffer ,
499+ symm_mem_hdl .signal_pad_ptrs_dev ,
500+ RANK = symm_mem_hdl .rank ,
501+ WORLD_SIZE = symm_mem_hdl .world_size ,
502+ GROUP_NAME = dist .group .WORLD .group_name ,
503+ )
471504
472505 expected = mod .reference_fp8_matmul_reduce_scatter (a , b , scale_a , scale_b )
473506
0 commit comments