Skip to content

Commit c924c63

Browse files
committed
mutation: wip
1 parent ade7edf commit c924c63

3 files changed

Lines changed: 816 additions & 64 deletions

File tree

helion/_compiler/_inductor/codegen.py

Lines changed: 32 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from torch._inductor.ir import ComputedBuffer
1919
from torch._inductor.ir import IRNode
2020
from torch._inductor.ir import Pointwise
21+
from torch._inductor.scheduler import BaseSchedulerNode
2122
import torch._inductor.lowering # noqa: F401
2223
from torch._inductor.ops_handler import DefaultHandler
2324
from torch._inductor.virtualized import V
@@ -153,18 +154,41 @@ def capture_epilogue(n: str) -> str:
153154
epilogue_items = []
154155
elif not tb._fusion_store_map:
155156
epilogue_items = list(tb._epilogue_specs.items())
157+
if len(epilogue_items) > 1:
158+
epilogue_items = []
156159
else:
157160
epilogue_items = []
158161

159162
extra_stores: list[ast.expr] = []
160163
for acc_name, nodes in epilogue_items:
161164
if nodes:
162-
acc_map = {acc_name: ast.unparse(value)}
165+
ep_nodes = list(nodes)
166+
if len(ep_nodes) == 1 and isinstance(ep_nodes[0], BaseSchedulerNode):
167+
ep_nodes = list(ep_nodes[0].get_nodes())
168+
if len(ep_nodes) > 1:
169+
graph_outputs = set(V.graph.get_output_names())
170+
filtered = [
171+
n
172+
for n in ep_nodes
173+
if any(o.get_name() in graph_outputs for o in n.get_outputs())
174+
]
175+
if filtered:
176+
ep_nodes = filtered
177+
value_str = ast.unparse(value)
178+
acc_map = {acc_name: value_str}
179+
# Add aliases to accumulator map for epilogue fusion
180+
alias_map = getattr(template_buffer, "_helion_alias_map", None)
181+
if alias_map:
182+
for alias in alias_map:
183+
acc_map.setdefault(alias, value_str)
184+
if template_buffer._output_aliases:
185+
for alias in template_buffer._output_aliases:
186+
acc_map.setdefault(alias, value_str)
163187
value = _invoke_pointwise_with_ops_handler(
164-
nodes, acc_map, subscript_names, capture_epilogue, "epilogue"
188+
ep_nodes, acc_map, subscript_names, capture_epilogue, "epilogue"
165189
)
166190
# If epilogue changes dtype, store to epilogue output buffer instead
167-
epilogue_out = nodes[-1].node
191+
epilogue_out = ep_nodes[-1].node
168192
if isinstance(epilogue_out, ComputedBuffer):
169193
if epilogue_out.get_dtype() != V.graph.get_dtype(acc_name):
170194
param = capture_epilogue(epilogue_out.get_name())
@@ -292,6 +316,11 @@ def codegen_prologue_fusion(
292316
if not nodes:
293317
return value
294318

319+
if input_name in tb._helion_mutated_input_names:
320+
if input_name in tb._prologue_fused_once:
321+
return value
322+
tb._prologue_fused_once.add(input_name)
323+
295324
subscript_names = _get_subscript_names(state, subscript)
296325
if not subscript_names:
297326
return value

0 commit comments

Comments
 (0)