Skip to content

[autotune] Feat: Overlapping GPU benchmarking and CPU compilation#1416

Draft
hinriksnaer wants to merge 25 commits into
pytorch:mainfrom
hinriksnaer:overlap-optimization
Draft

[autotune] Feat: Overlapping GPU benchmarking and CPU compilation#1416
hinriksnaer wants to merge 25 commits into
pytorch:mainfrom
hinriksnaer:overlap-optimization

Conversation

@hinriksnaer
Copy link
Copy Markdown
Collaborator

related issue #1400

The problem

Autotuning currently wastes a lot of time with idle resources. Here's what happens when we benchmark 200 configs:

Phase 1: Compile all 200 configs (parallel)
├─ CPU: [████████] 100% busy
└─ GPU: [        ] idle

Phase 2: Benchmark all 200 configs (sequential)
├─ CPU: [        ] idle
└─ GPU: [████████] 100% busy

The GPU sits idle for while we compile, then the CPU sits idle for while we benchmark. That's a lot of wasted compute.

The solution

Increase hardware utilization by parallel compiling configurations while using the main thread for sequential benchmarking of compiled kernels. This approach does the following iteratively:

  while queue:
      # Start up to N compilations in parallel (N = CPU cap)
      start_compilations_up_to_cap()
      # Wait for next one to finish
      compiled_kernel = wait_for_next_compilation()
      # Benchmark on main thread (GPU requires this)
      benchmark_kernel(compiled_kernel)

This is a simple approach that allows for overlapping GPU and CPU utilization which reduces unused resources during autotuning. More "clever" scheduling can be added later on such as a dedicated cpu scheduling thread in order to unblock CPU scheduling from GPU benchmarking.

Usage

Enable with environment variable:

export HELION_AUTOTUNE_OVERLAP_COMPILATION=1

Early Benchmark

Baseline

HELION_AUTOTUNE_RANDOM_SEED=42 HELION_AUTOTUNE_OVERLAP_COMPILATION=0 python benchmarks/run.py --kernel jsd
        (B, T, V)    torch_jsd-latency    liger_jsd-latency    torch_compile_jsd-latency    helion_jsd_tritonbench-latency
-----------------  -------------------  -------------------  ---------------------------  --------------------------------
  (4, 2048, 4096)    0.970816 (±0.34%)    0.258048 (±1.51%)            0.083168 (±2.35%)                 0.082112 (±2.46%)
  (4, 2048, 8192)    1.891136 (±0.26%)    0.496576 (±2.20%)            0.151744 (±1.22%)                 0.152800 (±1.13%)
 (4, 2048, 16384)    3.672192 (±0.13%)    6.335680 (±0.12%)            0.282400 (±0.67%)                 0.276256 (±1.09%)
 (4, 2048, 32768)    7.237536 (±0.08%)   13.000832 (±0.19%)            0.543328 (±1.31%)                 0.540032 (±1.78%)
 (4, 2048, 65536)   14.380352 (±0.04%)   26.603489 (±0.02%)            1.065024 (±2.32%)                 1.072160 (±4.51%)
(4, 2048, 131072)   28.714176 (±0.02%)   53.575359 (±0.00%)            2.154912 (±4.45%)                 2.178496 (±1.83%)
          average    9.477701385815939   16.711663981278736           0.7134293342630068                0.7169759844740232

Autotuning times:
  (4, 2048, 4096): 366.2s
  (4, 2048, 8192): 620.8s
  (4, 2048, 16384): 562.0s
  (4, 2048, 32768): 333.7s
  (4, 2048, 65536): 373.6s
  (4, 2048, 131072): 478.9s
  Total: 2735.2s

With Changes

HELION_AUTOTUNE_RANDOM_SEED=42 HELION_AUTOTUNE_OVERLAP_COMPILATION=1 python benchmarks/run.py --kernel jsd
        (B, T, V)    torch_jsd-latency    liger_jsd-latency    torch_compile_jsd-latency    helion_jsd_tritonbench-latency
-----------------  -------------------  -------------------  ---------------------------  --------------------------------
  (4, 2048, 4096)    0.968864 (±0.37%)    0.256608 (±1.50%)            0.082912 (±2.20%)                 0.082240 (±2.18%)
  (4, 2048, 8192)    1.888704 (±0.28%)    0.496800 (±2.24%)            0.151808 (±1.35%)                 0.147456 (±1.28%)
 (4, 2048, 16384)    3.669632 (±0.10%)    6.336160 (±0.11%)            0.282176 (±0.68%)                 0.276704 (±0.65%)
 (4, 2048, 32768)    7.235584 (±0.10%)   13.000544 (±0.14%)            0.542880 (±1.54%)                 0.539776 (±2.38%)
 (4, 2048, 65536)   14.377536 (±0.04%)   26.617311 (±0.13%)            1.077696 (±3.54%)                 1.029952 (±7.33%)
(4, 2048, 131072)   28.712128 (±0.02%)   53.658783 (±0.00%)            2.151968 (±1.93%)                 2.177408 (±1.62%)
          average    9.475407868623734   16.727701038122177           0.7149066577355067                0.7089226792256037

Autotuning times:
  (4, 2048, 4096): 299.9s
  (4, 2048, 8192): 492.0s
  (4, 2048, 16384): 317.6s
  (4, 2048, 32768): 251.9s
  (4, 2048, 65536): 322.2s
  (4, 2048, 131072): 485.3s
  Total: 2168.8s

Next steps

I think we should benchmark this change and from there we can add tests and determine if we would like to ship this or add additional layers of complexity for some additional performance gain.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 10, 2026
@hinriksnaer hinriksnaer marked this pull request as draft February 10, 2026 20:58
@oulgen
Copy link
Copy Markdown
Contributor

oulgen commented Feb 10, 2026

I love the idea. Would be curious to see whether this overlapping causes any consistency/measurement problems.

@jansel
Copy link
Copy Markdown
Contributor

jansel commented Feb 13, 2026

My main concern with this would be how stable the perf measurements are. We have had a lot of issues with unstable benchmark results leading to the autotuner making incorrect decisions.

Could we benchmark the stddev/min/max/mean/median of the raw perf measurements of a single kernel with and without this change. Right now the autotuner takes the median. We could also possibly offset more noise by running additional trials.

@hinriksnaer
Copy link
Copy Markdown
Collaborator Author

@jansel Currently working on some more standardized benchmarking for autotune to make sure we can better quantify meaningful metrics that address these kinds of concerns. I can tag you in a draft PR tomorrow once it's ready. That way we can get a discussion going to scope out some valuable metrics that could help make these contributions a clearer yes or no.

@ethche
Copy link
Copy Markdown
Contributor

ethche commented Feb 19, 2026

@hinriksnaer @jansel I think right now a lot of the CI autotuning jobs are not working well, so I'll do my own benchmarking. My take is that if we see the autotuner get identical results faster, we should enable this -- although we prob want to check for smaller shapes and on alternative hardware.

I'll also do the data analysis around whether this introduces bias into perf measurements. I have some code set up to extract all perf quantiles. I'll take a look and share the results.

@ethche
Copy link
Copy Markdown
Contributor

ethche commented Feb 24, 2026

Quick update on this: I ran the B200 benchmarks and found that the overlapping reduces gives an improvement in wall clock time from 0.43x geomean wall-clock reduction compared to pattern search to 0.48x geomean wall-clock time reduction. @hinriksnaer you can run the CI job with your new logging to confirm.

The good news is that I don't see any discernible loss in performance due to overlapping. I'll work on the data analysis to see at the config level whether this introduces variance into measurement.

b200_multi_method_comparison_relative_comparison

@ethche
Copy link
Copy Markdown
Contributor

ethche commented Feb 24, 2026

@hinriksnaer @jansel Just as in the adaptive compile time pr #1384, I wonder if there is a simple check we can do in the initial population benchmarking to see whether compile time overlapping will introduce bias (as this could also vary depending on the user's CPU).

Could we introduce a check the compares the benchmark vs the rebenchmark results and if we see that the difference is within some tolerance we keep overlapping (and disable it if not)?

@hinriksnaer hinriksnaer marked this pull request as ready for review February 25, 2026 18:58
@hinriksnaer
Copy link
Copy Markdown
Collaborator Author

hinriksnaer commented Feb 25, 2026

Made some new changes that hopefully address your concerns @jansel

There is now a HELION_OVERLAP_STABILITY_THRESHOLD that is 2.0 by default. This parameter is used during initial result verification but instead of re-benchmarking a subset of configs based on deviation from current best performance, we re-benchmark all of them. We then determine the average performance deviation and if we want to disable the overlap based on the threshold.

re-benchmarking everything might introduce too much of a bottleneck, so we might want to introduce some heuristic that samples a subset of the configurations instead.

@hinriksnaer
Copy link
Copy Markdown
Collaborator Author

@jansel last test failure was autograd test timeout, don't think it is related to these changes.

@jansel
Copy link
Copy Markdown
Contributor

jansel commented Mar 3, 2026

I'm looking for more data here on how this effects measurement noise. Not end-to-end autotuning perf, but individual measurement noise.

@hinriksnaer
Copy link
Copy Markdown
Collaborator Author

Converting this to draft for now. Hoping to land some updates to the search abstraction that streamlines the process of swapping out and testing various different compilation/benchmarking strategies. I'll revisit this once apples to apples comparisons and noise measurement is more natively supported.

cc: @jansel @ethche

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants