@@ -351,6 +351,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
351351 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
352352 self ._precompile_args_path : str | None = None
353353 self ._precompile_result_counter = count ()
354+ self ._crashed_config_strs : set [str ] = set ()
354355
355356 def _prepare (self ) -> None :
356357 """Some initialization deferred until autotuning actually runs.
@@ -495,6 +496,32 @@ def _try_load_checkpoint(self) -> bool:
495496
496497 def _recompile_after_checkpoint (self ) -> None :
497498 """Recompile after loading a checkpoint. Override in subclasses."""
499+
500+ def _load_crashed_configs (self ) -> None :
501+ """Load crashed configs from {hash}.crashed_configs (written by crash-recovery script)."""
502+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
503+ if checkpoint_dir_str is None :
504+ return
505+ crashed_configs_path = (
506+ Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .crashed_configs"
507+ )
508+ if crashed_configs_path .exists ():
509+ self ._crashed_config_strs |= {
510+ line .strip ()
511+ for line in crashed_configs_path .read_text ().splitlines ()
512+ if line .strip ()
513+ }
514+ if self ._crashed_config_strs :
515+ self .log (
516+ f"Loaded { len (self ._crashed_config_strs )} crashed config(s) to skip"
517+ )
518+
519+ def _get_pending_config_path (self ) -> Path | None :
520+ """Get path for pending-config sentinel, or None if checkpointing disabled."""
521+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
522+ if checkpoint_dir_str is None :
523+ return None
524+ return Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .pending_config"
498525 def _compute_baseline (
499526 self ,
500527 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -717,6 +744,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
717744 Returns:
718745 The performance of the configuration in ms.
719746 """
747+ # Skip configs that previously crashed the subprocess
748+ config_str = str (config )
749+ if config_str in self ._crashed_config_strs :
750+ self .log .warning (f"Skipping known-crashed config: { config } " )
751+ return inf
752+
720753 self ._autotune_metrics .num_configs_tested += 1
721754 self .counters ["benchmark" ] += 1
722755 self .log .debug (lambda : f"Running benchmark for { config !r} " )
@@ -997,13 +1030,36 @@ def _benchmark(
9971030 A list of BenchmarkResult entries containing the configuration, compiled
9981031 callable, measured performance, status, and compilation time.
9991032 """
1033+ # Filter out known-crashed configs before compilation
1034+ if self ._crashed_config_strs :
1035+ original_len = len (configs )
1036+ configs = [c for c in configs if str (c ) not in self ._crashed_config_strs ]
1037+ skipped = original_len - len (configs )
1038+ if skipped :
1039+ self .log .warning (
1040+ f"Skipped { skipped } known-crashed config(s) before compilation"
1041+ )
1042+ if not configs :
1043+ return []
1044+
10001045 fns : list [Callable [..., object ]] = []
10011046 valid_configs : list [Config ] = []
10021047 futures : list [PrecompileFuture ] | None = None
1048+ pending_path = self ._get_pending_config_path ()
10031049 for i , config in enumerate (configs ):
1050+ # Write sentinel before compile so a hard crash (SIGKILL /
1051+ # CUDA IMA) leaves a trace the crash recovery script can find.
1052+ if pending_path is not None :
1053+ pending_path .write_text (str (config ))
10041054 try :
10051055 fn = self .kernel .compile_config (config , allow_print = False )
1006- except Exception :
1056+ except Exception as e :
1057+ if match_unrecoverable_runtime_error (e ):
1058+ # Leave sentinel for crash recovery — CUDA context is
1059+ # corrupted and the process cannot continue.
1060+ raise
1061+ if pending_path is not None :
1062+ pending_path .unlink (missing_ok = True )
10071063 # If all configs failed, raise error
10081064 if not valid_configs and i == len (configs ) - 1 :
10091065 raise
@@ -1013,9 +1069,14 @@ def _benchmark(
10131069 exc_info = True ,
10141070 )
10151071 continue
1072+ if pending_path is not None :
1073+ pending_path .unlink (missing_ok = True )
10161074 fns .append (fn )
10171075 valid_configs .append (config )
10181076 configs = valid_configs
1077+ # NOTE: precompile runs in separate subprocesses with isolated CUDA
1078+ # contexts; crashes there are caught via is_working checks, not
1079+ # sentinels.
10191080 if self .settings .autotune_precompile :
10201081 futures = list (
10211082 starmap (
@@ -1077,7 +1138,14 @@ def _benchmark(
10771138 )
10781139 )
10791140 # benchmark one-by-one to avoid noisy results
1141+ # Write pending-config sentinel; cleared after benchmark.
1142+ # On crash the file stays so the crash recovery script can
1143+ # detect which config caused the failure.
1144+ if pending_path is not None :
1145+ pending_path .write_text (str (config ))
10801146 perf = self .benchmark_function (config , fn )
1147+ if pending_path is not None :
1148+ pending_path .unlink (missing_ok = True )
10811149 status = "ok" if math .isfinite (perf ) else "error"
10821150 # Log completion after benchmarking
10831151 self .log .record_autotune_entry (
@@ -1182,6 +1250,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11821250
11831251 if not self ._try_load_checkpoint ():
11841252 self ._init_search ()
1253+ self ._load_crashed_configs ()
11851254 try :
11861255 best = self ._autotune ()
11871256 self ._cleanup_checkpoint ()
0 commit comments