@@ -905,6 +905,10 @@ class _PerKiterTmaArgs:
905905 # (cute_plan.md §6.12.7). Default 1 preserves byte-identity for the
906906 # validated cluster_m=2 cluster_n=1 path.
907907 cluster_n : int = 1
908+ # Static-full one-CTA pipelined TMA loops can drop the per-K runtime
909+ # full-tile branch and scalar fallback. Non-pipelined/asymmetric or two-CTA
910+ # TMA paths must keep the guarded fallback path.
911+ static_full_tiles : bool = False
908912
909913
910914def _kloop_tma_copy_a_src (args : _PerKiterTmaArgs , * , k_offset : str ) -> str :
@@ -993,8 +997,13 @@ def _build_kloop_pipeline_producer_if(
993997 assert args .use_tma_a and args .use_tma_b , (
994998 "pipelined branch requires both A and B to be TMA-loaded"
995999 )
1000+ assert not (args .static_full_tiles and args .is_two_cta ), (
1001+ "static-full fast path is only valid for one-CTA pipelined TMA loops"
1002+ )
9961003 k_offset = f"{ args .tma_k_tile } + cutlass.Int32({ args .ab_stage_count } )"
997- predicate_terms = [args .tma_full_tile ]
1004+ predicate_terms = []
1005+ if not args .static_full_tiles :
1006+ predicate_terms .append (args .tma_full_tile )
9981007 if gate_tma_warp :
9991008 predicate_terms .append (args .tma_warp )
10001009 predicate_terms .append (args .tma_next_full_tile )
@@ -1035,6 +1044,17 @@ def _build_kloop_pipeline_consumer_if(
10351044 sync_before_scalar_fallback : bool = False ,
10361045) -> ast .stmt :
10371046 """Per-K-iter TMA consumer / scalar-fallback ``if`` for the pipelined branch."""
1047+ if args .static_full_tiles :
1048+ assert not args .is_two_cta , (
1049+ "static-full fast path is only valid for one-CTA pipelined TMA loops"
1050+ )
1051+ assert gate_exec_warp , "static-full fast path requires an exec-warp gate"
1052+ assert not include_scalar_fallback , (
1053+ "static-full fast path has no scalar fallback branch"
1054+ )
1055+ assert not sync_before_scalar_fallback , (
1056+ "static-full fast path has no scalar fallback presync"
1057+ )
10381058 if args .skip_consumer_wait :
10391059 consumer_src = "pass"
10401060 else :
@@ -1056,8 +1076,10 @@ def _build_kloop_pipeline_consumer_if(
10561076 gate_exec_warp = gate_exec_warp ,
10571077 cluster_n = args .cluster_n ,
10581078 ),
1059- indent = " " ,
1079+ indent = "" if args . static_full_tiles else " " ,
10601080 )
1081+ if args .static_full_tiles :
1082+ return statement_from_string (full_tile_src )
10611083 fallback_src = ""
10621084 if include_scalar_fallback :
10631085 scalar_load_a_src = textwrap .indent (ast .unparse (args .scalar_load_a ), " " )
@@ -1117,6 +1139,14 @@ def _build_kloop_pipeline_release_if(
11171139 advance their local consumer state. Peer CTAs participate via the
11181140 multicast mask; separate peer arrivals over-count the empty barrier.
11191141 """
1142+ if args .static_full_tiles :
1143+ assert not args .is_two_cta , (
1144+ "static-full fast path is only valid for one-CTA pipelined TMA loops"
1145+ )
1146+ assert gate_exec_warp , "static-full fast path requires an exec-warp gate"
1147+ assert not include_scalar_fallback , (
1148+ "static-full fast path has no scalar fallback branch"
1149+ )
11201150 release_src = f"{ args .tma_pipeline } .consumer_release({ args .tma_consumer_state } )"
11211151 release_gate = _tcgen05_two_cta_owner_predicate (
11221152 args .exec_active ,
@@ -1125,19 +1155,22 @@ def _build_kloop_pipeline_release_if(
11251155 cluster_n = args .cluster_n ,
11261156 )
11271157 advance_src = emit_pipeline_advance (args .tma_consumer_state )
1158+ indent = "" if args .static_full_tiles else " "
11281159 if args .is_two_cta :
11291160 # With gate_exec_warp=False the caller is already inside the
11301161 # role-local exec loop, so every iteration can advance local state.
11311162 advance_gate = args .exec_active if gate_exec_warp else None
11321163 full_tile_src = (
1133- _tcgen05_emit_optional_gate (release_src , release_gate , indent = " " )
1164+ _tcgen05_emit_optional_gate (release_src , release_gate , indent = indent )
11341165 + "\n "
1135- + _tcgen05_emit_optional_gate (advance_src , advance_gate , indent = " " )
1166+ + _tcgen05_emit_optional_gate (advance_src , advance_gate , indent = indent )
11361167 )
11371168 else :
11381169 full_tile_src = _tcgen05_emit_optional_gate (
1139- release_src + "\n " + advance_src , release_gate , indent = " "
1170+ release_src + "\n " + advance_src , release_gate , indent = indent
11401171 )
1172+ if args .static_full_tiles :
1173+ return statement_from_string (full_tile_src )
11411174 fallback_src = (
11421175 "\n else:\n cute.arch.sync_threads()" if include_scalar_fallback else ""
11431176 )
@@ -1208,6 +1241,9 @@ def _build_kloop_non_pipeline_producer_if(
12081241 offset on the cute.copy, and no ``advance`` here (the release block
12091242 advances both producer and consumer state).
12101243 """
1244+ assert not args .static_full_tiles , (
1245+ "static-full fast path is only valid for pipelined all-TMA K loops"
1246+ )
12111247 predicate_terms = [args .tma_full_tile ]
12121248 if gate_tma_warp :
12131249 predicate_terms .append (args .tma_warp )
@@ -1233,37 +1269,36 @@ def _build_kloop_non_pipeline_consumer_if(args: _PerKiterTmaArgs) -> ast.stmt:
12331269 into the full-tile branch (e.g. A-TMA + B-scalar still loads B
12341270 here on full tiles).
12351271 """
1236- scalar_load_a_src = textwrap .indent (ast .unparse (args .scalar_load_a ), " " )
1237- scalar_load_b_src = textwrap .indent (ast .unparse (args .scalar_load_b ), " " )
1238- scalar_load_a_tma_src = (
1239- textwrap .indent (ast .unparse (args .scalar_load_a ), " " ) + "\n "
1240- if not args .use_tma_a
1241- else ""
1242- )
1243- scalar_load_b_tma_src = (
1244- textwrap .indent (ast .unparse (args .scalar_load_b ), " " ) + "\n "
1245- if not args .use_tma_b
1246- else ""
1272+ assert not args .static_full_tiles , (
1273+ "static-full fast path is only valid for pipelined all-TMA K loops"
12471274 )
1248- src = (
1249- f"if { args .tma_full_tile } :\n "
1275+ scalar_load_a_src = ast .unparse (args .scalar_load_a )
1276+ scalar_load_b_src = ast .unparse (args .scalar_load_b )
1277+ scalar_load_a_tma_src = scalar_load_a_src + "\n " if not args .use_tma_a else ""
1278+ scalar_load_b_tma_src = scalar_load_b_src + "\n " if not args .use_tma_b else ""
1279+ full_body = (
12501280 f"{ scalar_load_a_tma_src } "
12511281 f"{ scalar_load_b_tma_src } "
1252- f" if { args .exec_active } :\n "
1253- " cute.arch.sync_warp()\n "
1282+ f"if { args .exec_active } :\n "
1283+ " cute.arch.sync_warp()\n "
12541284 + (
1255- " pass\n "
1285+ " pass\n "
12561286 if args .skip_consumer_wait
12571287 else (
1258- f" { args .tma_pipeline } .consumer_wait("
1288+ f" { args .tma_pipeline } .consumer_wait("
12591289 f"{ args .tma_consumer_state } , { args .tma_consumer_try_token } )\n "
12601290 )
12611291 )
1262- + " cute.arch.sync_threads()\n "
1263- + "else:\n "
1264- + f"{ scalar_load_a_src } \n "
1265- + f"{ scalar_load_b_src } \n "
1266- + " cute.arch.sync_threads()"
1292+ + "cute.arch.sync_threads()"
1293+ )
1294+ fallback_body = (
1295+ f"{ scalar_load_a_src } \n { scalar_load_b_src } \n cute.arch.sync_threads()"
1296+ )
1297+ src = (
1298+ f"if { args .tma_full_tile } :\n "
1299+ f"{ textwrap .indent (full_body , ' ' )} \n "
1300+ "else:\n "
1301+ f"{ textwrap .indent (fallback_body , ' ' )} "
12671302 )
12681303 return statement_from_string (src )
12691304
@@ -1276,20 +1311,25 @@ def _build_kloop_non_pipeline_release_if(args: _PerKiterTmaArgs) -> ast.stmt:
12761311 consumer state normally advance here. The producer advance is omitted
12771312 only by the guarded invalid-output bridge diagnostic.
12781313 """
1314+ assert not args .static_full_tiles , (
1315+ "static-full fast path is only valid for pipelined all-TMA K loops"
1316+ )
12791317 producer_advance_src = (
1280- emit_pipeline_advance (args .tma_producer_state , indent = " " ) + "\n "
1318+ emit_pipeline_advance (args .tma_producer_state , indent = "" ) + "\n "
12811319 if not args .skip_producer_advance
12821320 else ""
12831321 )
1322+ full_body = (
1323+ "cute.arch.sync_threads()\n "
1324+ f"if { args .exec_active } :\n "
1325+ " cute.arch.sync_warp()\n "
1326+ f" { args .tma_pipeline } .consumer_release({ args .tma_consumer_state } )\n "
1327+ + producer_advance_src
1328+ + emit_pipeline_advance (args .tma_consumer_state , indent = "" )
1329+ )
12841330 src = (
12851331 f"if { args .tma_full_tile } :\n "
1286- " cute.arch.sync_threads()\n "
1287- f" if { args .exec_active } :\n "
1288- " cute.arch.sync_warp()\n "
1289- f" { args .tma_pipeline } .consumer_release({ args .tma_consumer_state } )\n "
1290- + producer_advance_src
1291- + emit_pipeline_advance (args .tma_consumer_state , indent = " " )
1292- + "\n "
1332+ f"{ textwrap .indent (full_body , ' ' )} \n "
12931333 "else:\n "
12941334 " cute.arch.sync_threads()"
12951335 )
@@ -1970,6 +2010,12 @@ def _emit_mma_pipeline(
19702010 )
19712011 # Keep a distinct name so future MMA-exec gating changes are localized.
19722012 tcgen05_use_role_local_mma_exec = tcgen05_use_role_local_tma_producer
2013+ tcgen05_static_full_tma_fast_path = (
2014+ tcgen05_static_full_tiles
2015+ and tcgen05_use_tma_pipeline
2016+ and not tcgen05_is_two_cta
2017+ and not tcgen05_use_role_local_mma_exec
2018+ )
19732019 tcgen05_acc_producer_mode = df .config .get (
19742020 TCGEN05_ACC_PRODUCER_MODE_CONFIG_KEY ,
19752021 TCGEN05_ACC_PRODUCER_MODE_NORMAL ,
@@ -3954,6 +4000,7 @@ def _tcgen05_tma_tile_predicate(
39544000 mma_stage_stmt : ast .stmt | None = None
39554001 smem_a_mma_stmt : ast .stmt | None = None
39564002 smem_b_mma_stmt : ast .stmt | None = None
4003+ tma_full_tile_predicate_src : str | None = None
39574004 tma_full_tile_stmt : ast .stmt | None = None
39584005 if mma_impl == "tcgen05" :
39594006 assert tcgen05_plan is not None
@@ -4001,16 +4048,17 @@ def _tcgen05_tma_tile_predicate(
40014048 tma_k_tile_stmt = statement_from_string (
40024049 f"{ tma_k_tile } = { k_offset_var } // cutlass.Int32({ bk } )"
40034050 )
4051+ tma_full_tile_predicate_src = _tcgen05_tma_tile_predicate (
4052+ k_tile_start_expr = k_offset_var ,
4053+ full_tile_end_expr = f"{ k_offset_var } + cutlass.Int32({ bk } )" ,
4054+ )
40044055 tma_full_tile_stmt = statement_from_string (
4005- f"{ tma_full_tile } = "
4006- + _tcgen05_tma_tile_predicate (
4007- k_tile_start_expr = k_offset_var ,
4008- full_tile_end_expr = f"{ k_offset_var } + cutlass.Int32({ bk } )" ,
4009- )
4056+ f"{ tma_full_tile } = " + tma_full_tile_predicate_src
40104057 )
40114058 if not tcgen05_use_role_local_mma_exec :
40124059 cg .add_statement (tma_k_tile_stmt )
4013- cg .add_statement (tma_full_tile_stmt )
4060+ if not tcgen05_static_full_tma_fast_path :
4061+ cg .add_statement (tma_full_tile_stmt )
40144062 smem_a_store = f"{ smem_a } [_row, _col]"
40154063 smem_b_store = f"{ smem_b } [_row, _col]"
40164064 if mma_impl == "tcgen05" :
@@ -4081,19 +4129,21 @@ def _tcgen05_tma_tile_predicate(
40814129 scalar_load_a = scalar_load_a ,
40824130 scalar_load_b = scalar_load_b ,
40834131 cluster_n = tcgen05_cluster_n ,
4132+ static_full_tiles = tcgen05_static_full_tma_fast_path ,
40844133 )
40854134 if tcgen05_use_tma_pipeline :
40864135 if tcgen05_use_role_local_tma_producer :
4136+ # The static-full fast path is non-role-local only; the
4137+ # role-local loops keep the full-tile predicate and scalar
4138+ # fallback structure.
4139+ assert not tcgen05_static_full_tma_fast_path
4140+ assert tma_full_tile_predicate_src is not None
40874141 producer_loop_body : list [ast .stmt ] = [
40884142 statement_from_string (
40894143 f"{ tma_k_tile } = { k_offset_var } // cutlass.Int32({ bk } )"
40904144 ),
40914145 statement_from_string (
4092- f"{ tma_full_tile } = "
4093- + _tcgen05_tma_tile_predicate (
4094- k_tile_start_expr = k_offset_var ,
4095- full_tile_end_expr = f"{ k_offset_var } + cutlass.Int32({ bk } )" ,
4096- )
4146+ f"{ tma_full_tile } = " + tma_full_tile_predicate_src
40974147 ),
40984148 statement_from_string (
40994149 f"{ tma_next_full_tile } = "
@@ -4120,6 +4170,7 @@ def _tcgen05_tma_tile_predicate(
41204170 assert smem_a_mma_stmt is not None
41214171 assert smem_b_mma_stmt is not None
41224172 assert tma_full_tile_stmt is not None
4173+ assert not tcgen05_static_full_tma_fast_path
41234174 exec_loop_body : list [ast .stmt ] = [
41244175 mma_stage_stmt ,
41254176 smem_a_mma_stmt ,
@@ -4220,8 +4271,12 @@ def _tcgen05_tma_tile_predicate(
42204271 cg .add_statement (
42214272 _build_kloop_pipeline_consumer_if (
42224273 tma_kloop_args ,
4274+ include_scalar_fallback = (
4275+ not tcgen05_static_full_tma_fast_path
4276+ ),
42234277 sync_before_scalar_fallback = (
42244278 tcgen05_sync_before_scalar_fallback
4279+ and not tcgen05_static_full_tma_fast_path
42254280 ),
42264281 )
42274282 )
@@ -4315,7 +4370,12 @@ def _tcgen05_tma_tile_predicate(
43154370 assert tma_kloop_args is not None
43164371 if tcgen05_use_tma_pipeline :
43174372 cg .add_statement (
4318- _build_kloop_pipeline_release_if (tma_kloop_args )
4373+ _build_kloop_pipeline_release_if (
4374+ tma_kloop_args ,
4375+ include_scalar_fallback = (
4376+ not tcgen05_static_full_tma_fast_path
4377+ ),
4378+ )
43194379 )
43204380 else :
43214381 cg .add_statement (
0 commit comments