Skip to content

Commit a7affe7

Browse files
committed
[cutedsl] skip static full tma fallback
1 parent 032e34e commit a7affe7

2 files changed

Lines changed: 182 additions & 66 deletions

File tree

helion/_compiler/cute/cute_mma.py

Lines changed: 107 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -905,6 +905,10 @@ class _PerKiterTmaArgs:
905905
# (cute_plan.md §6.12.7). Default 1 preserves byte-identity for the
906906
# validated cluster_m=2 cluster_n=1 path.
907907
cluster_n: int = 1
908+
# Static-full one-CTA pipelined TMA loops can drop the per-K runtime
909+
# full-tile branch and scalar fallback. Non-pipelined/asymmetric or two-CTA
910+
# TMA paths must keep the guarded fallback path.
911+
static_full_tiles: bool = False
908912

909913

910914
def _kloop_tma_copy_a_src(args: _PerKiterTmaArgs, *, k_offset: str) -> str:
@@ -993,8 +997,13 @@ def _build_kloop_pipeline_producer_if(
993997
assert args.use_tma_a and args.use_tma_b, (
994998
"pipelined branch requires both A and B to be TMA-loaded"
995999
)
1000+
assert not (args.static_full_tiles and args.is_two_cta), (
1001+
"static-full fast path is only valid for one-CTA pipelined TMA loops"
1002+
)
9961003
k_offset = f"{args.tma_k_tile} + cutlass.Int32({args.ab_stage_count})"
997-
predicate_terms = [args.tma_full_tile]
1004+
predicate_terms = []
1005+
if not args.static_full_tiles:
1006+
predicate_terms.append(args.tma_full_tile)
9981007
if gate_tma_warp:
9991008
predicate_terms.append(args.tma_warp)
10001009
predicate_terms.append(args.tma_next_full_tile)
@@ -1035,6 +1044,17 @@ def _build_kloop_pipeline_consumer_if(
10351044
sync_before_scalar_fallback: bool = False,
10361045
) -> ast.stmt:
10371046
"""Per-K-iter TMA consumer / scalar-fallback ``if`` for the pipelined branch."""
1047+
if args.static_full_tiles:
1048+
assert not args.is_two_cta, (
1049+
"static-full fast path is only valid for one-CTA pipelined TMA loops"
1050+
)
1051+
assert gate_exec_warp, "static-full fast path requires an exec-warp gate"
1052+
assert not include_scalar_fallback, (
1053+
"static-full fast path has no scalar fallback branch"
1054+
)
1055+
assert not sync_before_scalar_fallback, (
1056+
"static-full fast path has no scalar fallback presync"
1057+
)
10381058
if args.skip_consumer_wait:
10391059
consumer_src = "pass"
10401060
else:
@@ -1056,8 +1076,10 @@ def _build_kloop_pipeline_consumer_if(
10561076
gate_exec_warp=gate_exec_warp,
10571077
cluster_n=args.cluster_n,
10581078
),
1059-
indent=" ",
1079+
indent="" if args.static_full_tiles else " ",
10601080
)
1081+
if args.static_full_tiles:
1082+
return statement_from_string(full_tile_src)
10611083
fallback_src = ""
10621084
if include_scalar_fallback:
10631085
scalar_load_a_src = textwrap.indent(ast.unparse(args.scalar_load_a), " ")
@@ -1117,6 +1139,14 @@ def _build_kloop_pipeline_release_if(
11171139
advance their local consumer state. Peer CTAs participate via the
11181140
multicast mask; separate peer arrivals over-count the empty barrier.
11191141
"""
1142+
if args.static_full_tiles:
1143+
assert not args.is_two_cta, (
1144+
"static-full fast path is only valid for one-CTA pipelined TMA loops"
1145+
)
1146+
assert gate_exec_warp, "static-full fast path requires an exec-warp gate"
1147+
assert not include_scalar_fallback, (
1148+
"static-full fast path has no scalar fallback branch"
1149+
)
11201150
release_src = f"{args.tma_pipeline}.consumer_release({args.tma_consumer_state})"
11211151
release_gate = _tcgen05_two_cta_owner_predicate(
11221152
args.exec_active,
@@ -1125,19 +1155,22 @@ def _build_kloop_pipeline_release_if(
11251155
cluster_n=args.cluster_n,
11261156
)
11271157
advance_src = emit_pipeline_advance(args.tma_consumer_state)
1158+
indent = "" if args.static_full_tiles else " "
11281159
if args.is_two_cta:
11291160
# With gate_exec_warp=False the caller is already inside the
11301161
# role-local exec loop, so every iteration can advance local state.
11311162
advance_gate = args.exec_active if gate_exec_warp else None
11321163
full_tile_src = (
1133-
_tcgen05_emit_optional_gate(release_src, release_gate, indent=" ")
1164+
_tcgen05_emit_optional_gate(release_src, release_gate, indent=indent)
11341165
+ "\n"
1135-
+ _tcgen05_emit_optional_gate(advance_src, advance_gate, indent=" ")
1166+
+ _tcgen05_emit_optional_gate(advance_src, advance_gate, indent=indent)
11361167
)
11371168
else:
11381169
full_tile_src = _tcgen05_emit_optional_gate(
1139-
release_src + "\n" + advance_src, release_gate, indent=" "
1170+
release_src + "\n" + advance_src, release_gate, indent=indent
11401171
)
1172+
if args.static_full_tiles:
1173+
return statement_from_string(full_tile_src)
11411174
fallback_src = (
11421175
"\nelse:\n cute.arch.sync_threads()" if include_scalar_fallback else ""
11431176
)
@@ -1208,6 +1241,9 @@ def _build_kloop_non_pipeline_producer_if(
12081241
offset on the cute.copy, and no ``advance`` here (the release block
12091242
advances both producer and consumer state).
12101243
"""
1244+
assert not args.static_full_tiles, (
1245+
"static-full fast path is only valid for pipelined all-TMA K loops"
1246+
)
12111247
predicate_terms = [args.tma_full_tile]
12121248
if gate_tma_warp:
12131249
predicate_terms.append(args.tma_warp)
@@ -1233,37 +1269,36 @@ def _build_kloop_non_pipeline_consumer_if(args: _PerKiterTmaArgs) -> ast.stmt:
12331269
into the full-tile branch (e.g. A-TMA + B-scalar still loads B
12341270
here on full tiles).
12351271
"""
1236-
scalar_load_a_src = textwrap.indent(ast.unparse(args.scalar_load_a), " ")
1237-
scalar_load_b_src = textwrap.indent(ast.unparse(args.scalar_load_b), " ")
1238-
scalar_load_a_tma_src = (
1239-
textwrap.indent(ast.unparse(args.scalar_load_a), " ") + "\n"
1240-
if not args.use_tma_a
1241-
else ""
1242-
)
1243-
scalar_load_b_tma_src = (
1244-
textwrap.indent(ast.unparse(args.scalar_load_b), " ") + "\n"
1245-
if not args.use_tma_b
1246-
else ""
1272+
assert not args.static_full_tiles, (
1273+
"static-full fast path is only valid for pipelined all-TMA K loops"
12471274
)
1248-
src = (
1249-
f"if {args.tma_full_tile}:\n"
1275+
scalar_load_a_src = ast.unparse(args.scalar_load_a)
1276+
scalar_load_b_src = ast.unparse(args.scalar_load_b)
1277+
scalar_load_a_tma_src = scalar_load_a_src + "\n" if not args.use_tma_a else ""
1278+
scalar_load_b_tma_src = scalar_load_b_src + "\n" if not args.use_tma_b else ""
1279+
full_body = (
12501280
f"{scalar_load_a_tma_src}"
12511281
f"{scalar_load_b_tma_src}"
1252-
f" if {args.exec_active}:\n"
1253-
" cute.arch.sync_warp()\n"
1282+
f"if {args.exec_active}:\n"
1283+
" cute.arch.sync_warp()\n"
12541284
+ (
1255-
" pass\n"
1285+
" pass\n"
12561286
if args.skip_consumer_wait
12571287
else (
1258-
f" {args.tma_pipeline}.consumer_wait("
1288+
f" {args.tma_pipeline}.consumer_wait("
12591289
f"{args.tma_consumer_state}, {args.tma_consumer_try_token})\n"
12601290
)
12611291
)
1262-
+ " cute.arch.sync_threads()\n"
1263-
+ "else:\n"
1264-
+ f"{scalar_load_a_src}\n"
1265-
+ f"{scalar_load_b_src}\n"
1266-
+ " cute.arch.sync_threads()"
1292+
+ "cute.arch.sync_threads()"
1293+
)
1294+
fallback_body = (
1295+
f"{scalar_load_a_src}\n{scalar_load_b_src}\ncute.arch.sync_threads()"
1296+
)
1297+
src = (
1298+
f"if {args.tma_full_tile}:\n"
1299+
f"{textwrap.indent(full_body, ' ')}\n"
1300+
"else:\n"
1301+
f"{textwrap.indent(fallback_body, ' ')}"
12671302
)
12681303
return statement_from_string(src)
12691304

@@ -1276,20 +1311,25 @@ def _build_kloop_non_pipeline_release_if(args: _PerKiterTmaArgs) -> ast.stmt:
12761311
consumer state normally advance here. The producer advance is omitted
12771312
only by the guarded invalid-output bridge diagnostic.
12781313
"""
1314+
assert not args.static_full_tiles, (
1315+
"static-full fast path is only valid for pipelined all-TMA K loops"
1316+
)
12791317
producer_advance_src = (
1280-
emit_pipeline_advance(args.tma_producer_state, indent=" ") + "\n"
1318+
emit_pipeline_advance(args.tma_producer_state, indent="") + "\n"
12811319
if not args.skip_producer_advance
12821320
else ""
12831321
)
1322+
full_body = (
1323+
"cute.arch.sync_threads()\n"
1324+
f"if {args.exec_active}:\n"
1325+
" cute.arch.sync_warp()\n"
1326+
f" {args.tma_pipeline}.consumer_release({args.tma_consumer_state})\n"
1327+
+ producer_advance_src
1328+
+ emit_pipeline_advance(args.tma_consumer_state, indent="")
1329+
)
12841330
src = (
12851331
f"if {args.tma_full_tile}:\n"
1286-
" cute.arch.sync_threads()\n"
1287-
f" if {args.exec_active}:\n"
1288-
" cute.arch.sync_warp()\n"
1289-
f" {args.tma_pipeline}.consumer_release({args.tma_consumer_state})\n"
1290-
+ producer_advance_src
1291-
+ emit_pipeline_advance(args.tma_consumer_state, indent=" ")
1292-
+ "\n"
1332+
f"{textwrap.indent(full_body, ' ')}\n"
12931333
"else:\n"
12941334
" cute.arch.sync_threads()"
12951335
)
@@ -1970,6 +2010,12 @@ def _emit_mma_pipeline(
19702010
)
19712011
# Keep a distinct name so future MMA-exec gating changes are localized.
19722012
tcgen05_use_role_local_mma_exec = tcgen05_use_role_local_tma_producer
2013+
tcgen05_static_full_tma_fast_path = (
2014+
tcgen05_static_full_tiles
2015+
and tcgen05_use_tma_pipeline
2016+
and not tcgen05_is_two_cta
2017+
and not tcgen05_use_role_local_mma_exec
2018+
)
19732019
tcgen05_acc_producer_mode = df.config.get(
19742020
TCGEN05_ACC_PRODUCER_MODE_CONFIG_KEY,
19752021
TCGEN05_ACC_PRODUCER_MODE_NORMAL,
@@ -3954,6 +4000,7 @@ def _tcgen05_tma_tile_predicate(
39544000
mma_stage_stmt: ast.stmt | None = None
39554001
smem_a_mma_stmt: ast.stmt | None = None
39564002
smem_b_mma_stmt: ast.stmt | None = None
4003+
tma_full_tile_predicate_src: str | None = None
39574004
tma_full_tile_stmt: ast.stmt | None = None
39584005
if mma_impl == "tcgen05":
39594006
assert tcgen05_plan is not None
@@ -4001,16 +4048,17 @@ def _tcgen05_tma_tile_predicate(
40014048
tma_k_tile_stmt = statement_from_string(
40024049
f"{tma_k_tile} = {k_offset_var} // cutlass.Int32({bk})"
40034050
)
4051+
tma_full_tile_predicate_src = _tcgen05_tma_tile_predicate(
4052+
k_tile_start_expr=k_offset_var,
4053+
full_tile_end_expr=f"{k_offset_var} + cutlass.Int32({bk})",
4054+
)
40044055
tma_full_tile_stmt = statement_from_string(
4005-
f"{tma_full_tile} = "
4006-
+ _tcgen05_tma_tile_predicate(
4007-
k_tile_start_expr=k_offset_var,
4008-
full_tile_end_expr=f"{k_offset_var} + cutlass.Int32({bk})",
4009-
)
4056+
f"{tma_full_tile} = " + tma_full_tile_predicate_src
40104057
)
40114058
if not tcgen05_use_role_local_mma_exec:
40124059
cg.add_statement(tma_k_tile_stmt)
4013-
cg.add_statement(tma_full_tile_stmt)
4060+
if not tcgen05_static_full_tma_fast_path:
4061+
cg.add_statement(tma_full_tile_stmt)
40144062
smem_a_store = f"{smem_a}[_row, _col]"
40154063
smem_b_store = f"{smem_b}[_row, _col]"
40164064
if mma_impl == "tcgen05":
@@ -4081,19 +4129,21 @@ def _tcgen05_tma_tile_predicate(
40814129
scalar_load_a=scalar_load_a,
40824130
scalar_load_b=scalar_load_b,
40834131
cluster_n=tcgen05_cluster_n,
4132+
static_full_tiles=tcgen05_static_full_tma_fast_path,
40844133
)
40854134
if tcgen05_use_tma_pipeline:
40864135
if tcgen05_use_role_local_tma_producer:
4136+
# The static-full fast path is non-role-local only; the
4137+
# role-local loops keep the full-tile predicate and scalar
4138+
# fallback structure.
4139+
assert not tcgen05_static_full_tma_fast_path
4140+
assert tma_full_tile_predicate_src is not None
40874141
producer_loop_body: list[ast.stmt] = [
40884142
statement_from_string(
40894143
f"{tma_k_tile} = {k_offset_var} // cutlass.Int32({bk})"
40904144
),
40914145
statement_from_string(
4092-
f"{tma_full_tile} = "
4093-
+ _tcgen05_tma_tile_predicate(
4094-
k_tile_start_expr=k_offset_var,
4095-
full_tile_end_expr=f"{k_offset_var} + cutlass.Int32({bk})",
4096-
)
4146+
f"{tma_full_tile} = " + tma_full_tile_predicate_src
40974147
),
40984148
statement_from_string(
40994149
f"{tma_next_full_tile} = "
@@ -4120,6 +4170,7 @@ def _tcgen05_tma_tile_predicate(
41204170
assert smem_a_mma_stmt is not None
41214171
assert smem_b_mma_stmt is not None
41224172
assert tma_full_tile_stmt is not None
4173+
assert not tcgen05_static_full_tma_fast_path
41234174
exec_loop_body: list[ast.stmt] = [
41244175
mma_stage_stmt,
41254176
smem_a_mma_stmt,
@@ -4220,8 +4271,12 @@ def _tcgen05_tma_tile_predicate(
42204271
cg.add_statement(
42214272
_build_kloop_pipeline_consumer_if(
42224273
tma_kloop_args,
4274+
include_scalar_fallback=(
4275+
not tcgen05_static_full_tma_fast_path
4276+
),
42234277
sync_before_scalar_fallback=(
42244278
tcgen05_sync_before_scalar_fallback
4279+
and not tcgen05_static_full_tma_fast_path
42254280
),
42264281
)
42274282
)
@@ -4315,7 +4370,12 @@ def _tcgen05_tma_tile_predicate(
43154370
assert tma_kloop_args is not None
43164371
if tcgen05_use_tma_pipeline:
43174372
cg.add_statement(
4318-
_build_kloop_pipeline_release_if(tma_kloop_args)
4373+
_build_kloop_pipeline_release_if(
4374+
tma_kloop_args,
4375+
include_scalar_fallback=(
4376+
not tcgen05_static_full_tma_fast_path
4377+
),
4378+
)
43194379
)
43204380
else:
43214381
cg.add_statement(

0 commit comments

Comments
 (0)